You've already forked aws-iam-anywhere-refresher
Initial commit and release
This commit is contained in:
336
aws_signing_helper/serve.go
Normal file
336
aws_signing_helper/serve.go
Normal file
@@ -0,0 +1,336 @@
|
||||
package aws_signing_helper
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/arn"
|
||||
)
|
||||
|
||||
const DefaultPort = 9911
|
||||
const LocalHostAddress = "127.0.0.1"
|
||||
|
||||
var RefreshTime = time.Minute * time.Duration(5)
|
||||
|
||||
type RefreshableCred struct {
|
||||
AccessKeyId string
|
||||
SecretAccessKey string
|
||||
Token string
|
||||
Code string
|
||||
Type string
|
||||
Expiration time.Time
|
||||
LastUpdated time.Time
|
||||
}
|
||||
|
||||
type Endpoint struct {
|
||||
PortNum int
|
||||
Server *http.Server
|
||||
TmpCred RefreshableCred
|
||||
}
|
||||
|
||||
type SessionToken struct {
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
const TOKEN_RESOURCE_PATH = "/latest/api/token"
|
||||
const SECURITY_CREDENTIALS_RESOURCE_PATH = "/latest/meta-data/iam/security-credentials/"
|
||||
|
||||
const EC2_METADATA_TOKEN_HEADER = "x-aws-ec2-metadata-token"
|
||||
const EC2_METADATA_TOKEN_TTL_HEADER = "x-aws-ec2-metadata-token-ttl-seconds"
|
||||
const DEFAULT_TOKEN_TTL_SECONDS = "21600"
|
||||
|
||||
const X_FORWARDED_FOR_HEADER = "X-Forwarded-For"
|
||||
|
||||
const REFRESHABLE_CRED_TYPE = "AWS-HMAC"
|
||||
const REFRESHABLE_CRED_CODE = "Success"
|
||||
|
||||
const MAX_TOKENS = 256
|
||||
|
||||
var mutex sync.Mutex
|
||||
var tokenMap = make(map[string]time.Time)
|
||||
|
||||
// Generates a random string with the specified length
|
||||
func GenerateToken(length int) (string, error) {
|
||||
if length < 0 || length >= 128 {
|
||||
msg := "invalid token length"
|
||||
return "", errors.New(msg)
|
||||
}
|
||||
randomBytes := make([]byte, 128)
|
||||
_, err := rand.Read(randomBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(randomBytes)[:length], nil
|
||||
}
|
||||
|
||||
// Removes the token that expires the earliest
|
||||
func InsertToken(token string, expirationTime time.Time) error {
|
||||
mutex.Lock()
|
||||
if len(tokenMap) == MAX_TOKENS {
|
||||
earliestExpirationTime := time.Unix(1<<63-1, 0)
|
||||
var earliestExpiringToken string
|
||||
for key, value := range tokenMap {
|
||||
if earliestExpirationTime.After(value) {
|
||||
earliestExpiringToken = key
|
||||
earliestExpirationTime = value
|
||||
}
|
||||
}
|
||||
|
||||
delete(tokenMap, earliestExpiringToken)
|
||||
log.Printf("evicting earliest expiring token: %s", earliestExpiringToken)
|
||||
}
|
||||
tokenMap[token] = expirationTime
|
||||
mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper function that checks to see whether the token provided in the request is valid
|
||||
func CheckValidToken(w http.ResponseWriter, r *http.Request) error {
|
||||
token := r.Header.Get(EC2_METADATA_TOKEN_HEADER)
|
||||
if token == "" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
msg := "no token provided"
|
||||
io.WriteString(w, msg)
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
mutex.Lock()
|
||||
expiration, ok := tokenMap[token]
|
||||
mutex.Unlock()
|
||||
if ok {
|
||||
if time.Now().After(expiration) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
msg := "invalid token provided"
|
||||
io.WriteString(w, msg)
|
||||
return errors.New(msg)
|
||||
}
|
||||
} else {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
msg := "invalid token provided"
|
||||
io.WriteString(w, msg)
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper function that finds a token's TTL in seconds
|
||||
func FindTokenTTLSeconds(r *http.Request) (string, error) {
|
||||
token := r.Header.Get(EC2_METADATA_TOKEN_HEADER)
|
||||
if token == "" {
|
||||
msg := "no token provided"
|
||||
return "", errors.New(msg)
|
||||
}
|
||||
|
||||
mutex.Lock()
|
||||
expiration, ok := tokenMap[token]
|
||||
mutex.Unlock()
|
||||
if ok {
|
||||
tokenTTLFloat := expiration.Sub(time.Now()).Seconds()
|
||||
tokenTTLInt64 := int64(tokenTTLFloat)
|
||||
return strconv.FormatInt(tokenTTLInt64, 10), nil
|
||||
} else {
|
||||
msg := "invalid token provided"
|
||||
return "", errors.New(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *CredentialsOpts, signer Signer, signatureAlgorithm string) (http.HandlerFunc, http.HandlerFunc, http.HandlerFunc) {
|
||||
// Handles PUT requests to /latest/api/token/
|
||||
putTokenHandler := func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "PUT" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for the presence of the X-Forwarded-For header
|
||||
xForwardedForHeader := r.Header.Get(X_FORWARDED_FOR_HEADER) // canonicalized headers are used (casing doesn't matter)
|
||||
if xForwardedForHeader != "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
io.WriteString(w, "unable to process requests with X-Forwarded-For header")
|
||||
return
|
||||
}
|
||||
|
||||
// Obtain the token TTL
|
||||
tokenTTLStr := r.Header.Get(EC2_METADATA_TOKEN_TTL_HEADER)
|
||||
if tokenTTLStr == "" {
|
||||
tokenTTLStr = DEFAULT_TOKEN_TTL_SECONDS
|
||||
}
|
||||
tokenTTL, err := strconv.Atoi(tokenTTLStr)
|
||||
if err != nil || tokenTTL < 1 || tokenTTL > 21600 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
io.WriteString(w, "invalid token TTL")
|
||||
return
|
||||
}
|
||||
|
||||
// Generate token and insert it into map
|
||||
token, err := GenerateToken(100)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
io.WriteString(w, "unable to generate token")
|
||||
return
|
||||
}
|
||||
expirationTime := time.Now().Add(time.Second * time.Duration(tokenTTL))
|
||||
InsertToken(token, expirationTime)
|
||||
|
||||
w.Header().Set(EC2_METADATA_TOKEN_TTL_HEADER, tokenTTLStr)
|
||||
io.WriteString(w, token) // nosemgrep
|
||||
}
|
||||
|
||||
// Handles requests to /latest/meta-data/iam/security-credentials/
|
||||
getRoleNameHandler := func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
err := CheckValidToken(w, r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tokenTTL, err := FindTokenTTLSeconds(r)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w.Header().Set(EC2_METADATA_TOKEN_TTL_HEADER, tokenTTL)
|
||||
io.WriteString(w, roleName) // nosemgrep
|
||||
}
|
||||
|
||||
// Handles GET requests to /latest/meta-data/iam/security-credentials/<ROLE_NAME>
|
||||
getCredentialsHandler := func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
err := CheckValidToken(w, r)
|
||||
if err != nil {
|
||||
log.Printf("Token validation received error: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
var nextRefreshTime = cred.Expiration.Add(-RefreshTime)
|
||||
if time.Until(nextRefreshTime) < RefreshTime {
|
||||
if Debug {
|
||||
log.Println("Generating credentials")
|
||||
}
|
||||
credentialProcessOutput, gcErr := GenerateCredentials(opts, signer, signatureAlgorithm)
|
||||
if gcErr != nil {
|
||||
log.Printf("Error generating credentials: %s\n", gcErr)
|
||||
}
|
||||
cred.AccessKeyId = credentialProcessOutput.AccessKeyId
|
||||
cred.SecretAccessKey = credentialProcessOutput.SecretAccessKey
|
||||
cred.Token = credentialProcessOutput.SessionToken
|
||||
cred.Expiration, _ = time.Parse(time.RFC3339, credentialProcessOutput.Expiration)
|
||||
cred.Code = REFRESHABLE_CRED_CODE
|
||||
cred.LastUpdated = time.Now()
|
||||
cred.Type = REFRESHABLE_CRED_TYPE
|
||||
err := json.NewEncoder(w).Encode(cred)
|
||||
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
io.WriteString(w, "failed to encode credentials")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if Debug {
|
||||
log.Println("Using previously obtained credentials")
|
||||
}
|
||||
err := json.NewEncoder(w).Encode(cred)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
io.WriteString(w, "failed to encode credentials")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tokenTTL, err := FindTokenTTLSeconds(r)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w.Header().Set(EC2_METADATA_TOKEN_TTL_HEADER, tokenTTL)
|
||||
}
|
||||
|
||||
return putTokenHandler, getRoleNameHandler, getCredentialsHandler
|
||||
}
|
||||
|
||||
func Serve(port int, credentialsOptions CredentialsOpts) {
|
||||
var refreshableCred = RefreshableCred{}
|
||||
|
||||
roleArn, err := arn.Parse(credentialsOptions.RoleArn)
|
||||
if err != nil {
|
||||
log.Println("invalid role ARN")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
signer, signatureAlgorithm, err := GetSigner(&credentialsOptions)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer signer.Close()
|
||||
|
||||
credentialProcessOutput, _ := GenerateCredentials(&credentialsOptions, signer, signatureAlgorithm)
|
||||
refreshableCred.AccessKeyId = credentialProcessOutput.AccessKeyId
|
||||
refreshableCred.SecretAccessKey = credentialProcessOutput.SecretAccessKey
|
||||
refreshableCred.Token = credentialProcessOutput.SessionToken
|
||||
refreshableCred.Expiration, _ = time.Parse(time.RFC3339, credentialProcessOutput.Expiration)
|
||||
refreshableCred.Code = REFRESHABLE_CRED_CODE
|
||||
refreshableCred.LastUpdated = time.Now()
|
||||
refreshableCred.Type = REFRESHABLE_CRED_TYPE
|
||||
endpoint := &Endpoint{PortNum: port, TmpCred: refreshableCred}
|
||||
endpoint.Server = &http.Server{}
|
||||
roleResourceParts := strings.Split(roleArn.Resource, "/")
|
||||
roleName := roleResourceParts[len(roleResourceParts)-1] // Find role name without path
|
||||
putTokenHandler, getRoleNameHandler, getCredentialsHandler := AllIssuesHandlers(&endpoint.TmpCred, roleName, &credentialsOptions, signer, signatureAlgorithm)
|
||||
|
||||
http.HandleFunc(TOKEN_RESOURCE_PATH, putTokenHandler)
|
||||
http.HandleFunc(SECURITY_CREDENTIALS_RESOURCE_PATH, getRoleNameHandler)
|
||||
http.HandleFunc(SECURITY_CREDENTIALS_RESOURCE_PATH+roleName, getCredentialsHandler)
|
||||
|
||||
// Background thread that cleans up expired tokens
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
curTime := time.Now()
|
||||
mutex.Lock()
|
||||
for key, value := range tokenMap {
|
||||
if curTime.After(value) {
|
||||
delete(tokenMap, key)
|
||||
log.Printf("removed expired token: %s", key)
|
||||
}
|
||||
}
|
||||
mutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// Start the credentials endpoint
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", LocalHostAddress, endpoint.PortNum))
|
||||
if err != nil {
|
||||
log.Println("failed to create listener")
|
||||
os.Exit(1)
|
||||
}
|
||||
endpoint.PortNum = listener.Addr().(*net.TCPAddr).Port
|
||||
log.Println("Local server started on port:", endpoint.PortNum)
|
||||
log.Println("Make it available to the sdk by running:")
|
||||
log.Printf("export AWS_EC2_METADATA_SERVICE_ENDPOINT=http://%s:%d/", LocalHostAddress, endpoint.PortNum)
|
||||
if err := endpoint.Server.Serve(listener); err != nil {
|
||||
log.Println("Httpserver: ListenAndServe() error")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user