Initial commit and release

This commit is contained in:
2024-07-31 10:08:48 -04:00
commit 360e1cf241
15 changed files with 3053 additions and 0 deletions

1
.dockerignore Normal file
View File

@@ -0,0 +1 @@
.idea/

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
aws-iam-anywhere-refresher
.idea/

20
Dockerfile Normal file
View File

@@ -0,0 +1,20 @@
FROM siteworxpro/golang:1.22.5 AS build
ENV GOPRIVATE=git.s.int
ENV GOPROXY=direct
WORKDIR /app
ADD . .
RUN go mod tidy && CGO_ENABLED=0 GOOS=linux GOARCH=amd64 GO111MODULE=on GOOS=linux go build -o /app/aws-iam-anywhere-refresher
FROM ubuntu AS runtime
WORKDIR /app
COPY --from=build /app/aws-iam-anywhere-refresher aws-iam-anywhere-refresher
RUN apt update && apt install -yqq ca-certificates
ENTRYPOINT ["/app/aws-iam-anywhere-refresher"]

143
README.md Normal file
View File

@@ -0,0 +1,143 @@
# AWS IAM Roles Anywhere Refresher
## Setup
[AWS IAM Roles Anywhere](https://docs.aws.amazon.com/rolesanywhere/latest/userguide/introduction.html)
If you are running workloads outside of AWS it's recommended that you only use short lived IAM credentials.
Because those credentials expire they need to be refreshed on a schedule.
This image runs in a kubernetes cronjob and will create and save new IAM credentials in a secret.
*This container is not designed to run outside of kubernetes!*
## Docker hub
[docker image](https://hub.docker.com/repository/docker/siteworxpro/aws-iam-anywhere/general)
[docker image](https://hub.docker.com/repository/docker/siteworxpro/aws-iam-anywhere/general)
## Environment Variables
- `SECRET`: the name of the secret containing the aws credentials (default=aws-credentials)
- `RESTART_DEPLOYMENTS` : restart deployments on success (default=false)
- `SESSION_DURATION` : how long credentials requested will be valid (default=900)
- `NAMESPACE` ***required*** : the namespace your cron pod is in
- `ROLE_ARN` ***required*** : the role arn to assume
- `PROFILE_ARN` ***required*** : the aim anywhere profile arn
- `TRUSTED_ANCHOR_ARN` ***required*** : the trusted anchor arn
- `PRIVATE_KEY` ***required*** : iam private key base64 encoded
- `CERTIFICATE` ***required*** : iam certificate base64 encoded
```yaml
apiVersion: batch/v1
kind: CronJob
metadata:
name: aws-iam-anywhere
spec:
concurrencyPolicy: Forbid
failedJobsHistoryLimit: 1
jobTemplate:
spec:
template:
spec:
serviceAccountName: aws-iam-anywhere-refresher
restartPolicy: Never
containers:
- name: refresher
image: siteworxpro/aws-iam-anywhere
imagePullPolicy: Always
env:
- name: NAMESPACE
value: default
- name: SECRET
value: aws-credentials
- name: ROLE_ARN
value: arn:aws:iam::12345:role/my-role
- name: PROFILE_ARN
value: arn:aws:rolesanywhere:us-east-1:12345:profile/bdf23662-32fe-482f-98f4-f10ba6afacd8
- name: TRUSTED_ANCHOR_ARN
value: arn:aws:rolesanywhere:us-east-1:3123451:trust-anchor/23692607-2a1e-468d-80d4-dc78ce9d9b1a
- name: CERTIFICATE
value: LS0...S0K
- name: PRIVATE_KEY
value: LS0t...S0K
schedule: 00 * * * *
---
apiVersion: v1
kind: ServiceAccount
metadata:
name: aws-iam-anywhere-refresher
---
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
name: aws-iam-anywhere-role
namespace: aws-iam-anywhere
rules:
- verbs:
- list
- update
resources:
- deployments
apiGroups:
- apps
- verbs:
- create
- update
- get
resources:
- secrets
apiGroups:
-
---
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
name: aws-iam-anywhere
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: Role
name: aws-iam-anywhere-role
subjects:
- kind: ServiceAccount
name: aws-iam-anywhere-refresher
namespace: default
```
resulting secret
```yaml
apiVersion: v1
kind: Secret
metadata:
labels:
managed-by: aws-iam-anywhere-refresher
name: aws-credentials
namespace: default
data:
AWS_ACCESS_KEY_ID: QVN....lE=
AWS_SECRET_ACCESS_KEY: WT...Qw==
AWS_SESSION_TOKEN: SVFv...VzPQ==
```
## Restarting Deployments
You can optionally restart your deployments if needed. If this isn't needed you can exclude the permissions in the role
above.
The process will list all deployments with the label `iam-role-type=aws-iam-anywhere` and restart them.
Be sure, if needed to avoid downtime, to configure your deployments readiness probes.
```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: aws-iam-anywhere
namespace: aws-iam-anywhere
labels:
iam-role-type: aws-iam-anywhere
```

View File

@@ -0,0 +1,128 @@
package aws_signing_helper
import (
"crypto/tls"
"encoding/base64"
"errors"
"github.com/aws/rolesanywhere-credential-helper/rolesanywhere"
"log"
"net/http"
"runtime"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
)
type CredentialsOpts struct {
PrivateKeyId string
CertificateId string
CertificateBundleId string
CertIdentifier CertIdentifier
RoleArn string
ProfileArnStr string
TrustAnchorArnStr string
SessionDuration int
Region string
Endpoint string
NoVerifySSL bool
WithProxy bool
Debug bool
Version string
LibPkcs11 string
ReusePin bool
}
// Function to create session and generate credentials
func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorithm string) (CredentialProcessOutput, error) {
// Assign values to region and endpoint if they haven't already been assigned
trustAnchorArn, err := arn.Parse(opts.TrustAnchorArnStr)
if err != nil {
return CredentialProcessOutput{}, err
}
profileArn, err := arn.Parse(opts.ProfileArnStr)
if err != nil {
return CredentialProcessOutput{}, err
}
if trustAnchorArn.Region != profileArn.Region {
return CredentialProcessOutput{}, errors.New("trust anchor and profile regions don't match")
}
if opts.Region == "" {
opts.Region = trustAnchorArn.Region
}
mySession := session.Must(session.NewSession())
var logLevel aws.LogLevelType
if Debug {
logLevel = aws.LogDebug
} else {
logLevel = aws.LogOff
}
var tr *http.Transport
if opts.WithProxy {
tr = &http.Transport{
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: opts.NoVerifySSL},
Proxy: http.ProxyFromEnvironment,
}
} else {
tr = &http.Transport{
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: opts.NoVerifySSL},
}
}
client := &http.Client{Transport: tr}
config := aws.NewConfig().WithRegion(opts.Region).WithHTTPClient(client).WithLogLevel(logLevel)
if opts.Endpoint != "" {
config.WithEndpoint(opts.Endpoint)
}
rolesAnywhereClient := rolesanywhere.New(mySession, config)
rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler")
rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: request.MakeAddToUserAgentHandler("CredHelper", opts.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH)})
rolesAnywhereClient.Handlers.Sign.Clear()
certificate, err := signer.Certificate()
if err != nil {
return CredentialProcessOutput{}, errors.New("unable to find certificate")
}
certificateChain, err := signer.CertificateChain()
if err != nil {
// If the chain couldn't be found, don't include it in the request
if Debug {
log.Println(err)
}
}
rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: CreateRequestSignFunction(signer, signatureAlgorithm, certificate, certificateChain)})
certificateStr := base64.StdEncoding.EncodeToString(certificate.Raw)
durationSeconds := int64(opts.SessionDuration)
createSessionRequest := rolesanywhere.CreateSessionInput{
Cert: &certificateStr,
ProfileArn: &opts.ProfileArnStr,
TrustAnchorArn: &opts.TrustAnchorArnStr,
DurationSeconds: &(durationSeconds),
InstanceProperties: nil,
RoleArn: &opts.RoleArn,
SessionName: nil,
}
output, err := rolesAnywhereClient.CreateSession(&createSessionRequest)
if err != nil {
return CredentialProcessOutput{}, err
}
if len(output.CredentialSet) == 0 {
msg := "unable to obtain temporary security credentials from CreateSession"
return CredentialProcessOutput{}, errors.New(msg)
}
credentials := output.CredentialSet[0].Credentials
credentialProcessOutput := CredentialProcessOutput{
Version: 1,
AccessKeyId: *credentials.AccessKeyId,
SecretAccessKey: *credentials.SecretAccessKey,
SessionToken: *credentials.SessionToken,
Expiration: *credentials.Expiration,
}
return credentialProcessOutput, nil
}

View File

@@ -0,0 +1,150 @@
package aws_signing_helper
import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"errors"
"io"
"log"
"os"
)
type FileSystemSigner struct {
bundlePath string
certPath string
isPkcs12 bool
privateKeyPath string
}
func (fileSystemSigner *FileSystemSigner) Public() crypto.PublicKey {
privateKey, _, _ := fileSystemSigner.readCertFiles()
{
privateKey, ok := privateKey.(ecdsa.PrivateKey)
if ok {
return &privateKey.PublicKey
}
}
{
privateKey, ok := privateKey.(rsa.PrivateKey)
if ok {
return &privateKey.PublicKey
}
}
return nil
}
func (fileSystemSigner *FileSystemSigner) Close() {}
func (fileSystemSigner *FileSystemSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
privateKey, _, _ := fileSystemSigner.readCertFiles()
var hash []byte
switch opts.HashFunc() {
case crypto.SHA256:
sum := sha256.Sum256(digest)
hash = sum[:]
case crypto.SHA384:
sum := sha512.Sum384(digest)
hash = sum[:]
case crypto.SHA512:
sum := sha512.Sum512(digest)
hash = sum[:]
default:
return nil, ErrUnsupportedHash
}
ecdsaPrivateKey, ok := privateKey.(ecdsa.PrivateKey)
if ok {
sig, err := ecdsa.SignASN1(rand, &ecdsaPrivateKey, hash[:])
if err == nil {
return sig, nil
}
}
rsaPrivateKey, ok := privateKey.(rsa.PrivateKey)
if ok {
sig, err := rsa.SignPKCS1v15(rand, &rsaPrivateKey, opts.HashFunc(), hash[:])
if err == nil {
return sig, nil
}
}
log.Println("unsupported algorithm")
return nil, errors.New("unsupported algorithm")
}
func (fileSystemSigner *FileSystemSigner) Certificate() (*x509.Certificate, error) {
_, cert, _ := fileSystemSigner.readCertFiles()
return cert, nil
}
func (fileSystemSigner *FileSystemSigner) CertificateChain() ([]*x509.Certificate, error) {
_, _, certChain := fileSystemSigner.readCertFiles()
return certChain, nil
}
// GetFileSystemSigner returns a FileSystemSigner, that signs a payload using the private key passed in
func GetFileSystemSigner(privateKeyPath string, certPath string, bundlePath string, isPkcs12 bool) (signer Signer, signingAlgorithm string, err error) {
fsSigner := &FileSystemSigner{bundlePath: bundlePath, certPath: certPath, isPkcs12: isPkcs12, privateKeyPath: privateKeyPath}
privateKey, _, _ := fsSigner.readCertFiles()
// Find the signing algorithm
_, isRsaKey := privateKey.(rsa.PrivateKey)
if isRsaKey {
signingAlgorithm = aws4_x509_rsa_sha256
}
_, isEcKey := privateKey.(ecdsa.PrivateKey)
if isEcKey {
signingAlgorithm = aws4_x509_ecdsa_sha256
}
if signingAlgorithm == "" {
log.Println("unsupported algorithm")
return nil, "", errors.New("unsupported algorithm")
}
return fsSigner, signingAlgorithm, nil
}
func (fileSystemSigner *FileSystemSigner) readCertFiles() (crypto.PrivateKey, *x509.Certificate, []*x509.Certificate) {
if fileSystemSigner.isPkcs12 {
chain, privateKey, err := ReadPKCS12Data(fileSystemSigner.certPath)
if err != nil {
log.Printf("Failed to read PKCS12 certificate: %s\n", err)
os.Exit(1)
}
return privateKey, chain[0], chain
} else {
privateKey, err := ReadPrivateKeyData(fileSystemSigner.privateKeyPath)
if err != nil {
log.Printf("Failed to read private key: %s\n", err)
os.Exit(1)
}
var chain []*x509.Certificate
if fileSystemSigner.bundlePath != "" {
chain, err = GetCertChain(fileSystemSigner.bundlePath)
if err != nil {
privateKey = nil
log.Printf("Failed to read certificate bundle: %s\n", err)
os.Exit(1)
}
}
var cert *x509.Certificate
if fileSystemSigner.certPath != "" {
_, cert, err = ReadCertificateData(fileSystemSigner.certPath)
if err != nil {
privateKey = nil
log.Printf("Failed to read certificate: %s\n", err)
os.Exit(1)
}
} else if len(chain) > 0 {
cert = chain[0]
} else {
log.Println("No certificate path or certificate bundle path provided")
os.Exit(1)
}
return privateKey, cert, chain
}
}

336
aws_signing_helper/serve.go Normal file
View File

@@ -0,0 +1,336 @@
package aws_signing_helper
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws/arn"
)
const DefaultPort = 9911
const LocalHostAddress = "127.0.0.1"
var RefreshTime = time.Minute * time.Duration(5)
type RefreshableCred struct {
AccessKeyId string
SecretAccessKey string
Token string
Code string
Type string
Expiration time.Time
LastUpdated time.Time
}
type Endpoint struct {
PortNum int
Server *http.Server
TmpCred RefreshableCred
}
type SessionToken struct {
Expiration time.Time
}
const TOKEN_RESOURCE_PATH = "/latest/api/token"
const SECURITY_CREDENTIALS_RESOURCE_PATH = "/latest/meta-data/iam/security-credentials/"
const EC2_METADATA_TOKEN_HEADER = "x-aws-ec2-metadata-token"
const EC2_METADATA_TOKEN_TTL_HEADER = "x-aws-ec2-metadata-token-ttl-seconds"
const DEFAULT_TOKEN_TTL_SECONDS = "21600"
const X_FORWARDED_FOR_HEADER = "X-Forwarded-For"
const REFRESHABLE_CRED_TYPE = "AWS-HMAC"
const REFRESHABLE_CRED_CODE = "Success"
const MAX_TOKENS = 256
var mutex sync.Mutex
var tokenMap = make(map[string]time.Time)
// Generates a random string with the specified length
func GenerateToken(length int) (string, error) {
if length < 0 || length >= 128 {
msg := "invalid token length"
return "", errors.New(msg)
}
randomBytes := make([]byte, 128)
_, err := rand.Read(randomBytes)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(randomBytes)[:length], nil
}
// Removes the token that expires the earliest
func InsertToken(token string, expirationTime time.Time) error {
mutex.Lock()
if len(tokenMap) == MAX_TOKENS {
earliestExpirationTime := time.Unix(1<<63-1, 0)
var earliestExpiringToken string
for key, value := range tokenMap {
if earliestExpirationTime.After(value) {
earliestExpiringToken = key
earliestExpirationTime = value
}
}
delete(tokenMap, earliestExpiringToken)
log.Printf("evicting earliest expiring token: %s", earliestExpiringToken)
}
tokenMap[token] = expirationTime
mutex.Unlock()
return nil
}
// Helper function that checks to see whether the token provided in the request is valid
func CheckValidToken(w http.ResponseWriter, r *http.Request) error {
token := r.Header.Get(EC2_METADATA_TOKEN_HEADER)
if token == "" {
w.WriteHeader(http.StatusUnauthorized)
msg := "no token provided"
io.WriteString(w, msg)
return errors.New(msg)
}
mutex.Lock()
expiration, ok := tokenMap[token]
mutex.Unlock()
if ok {
if time.Now().After(expiration) {
w.WriteHeader(http.StatusUnauthorized)
msg := "invalid token provided"
io.WriteString(w, msg)
return errors.New(msg)
}
} else {
w.WriteHeader(http.StatusUnauthorized)
msg := "invalid token provided"
io.WriteString(w, msg)
return errors.New(msg)
}
return nil
}
// Helper function that finds a token's TTL in seconds
func FindTokenTTLSeconds(r *http.Request) (string, error) {
token := r.Header.Get(EC2_METADATA_TOKEN_HEADER)
if token == "" {
msg := "no token provided"
return "", errors.New(msg)
}
mutex.Lock()
expiration, ok := tokenMap[token]
mutex.Unlock()
if ok {
tokenTTLFloat := expiration.Sub(time.Now()).Seconds()
tokenTTLInt64 := int64(tokenTTLFloat)
return strconv.FormatInt(tokenTTLInt64, 10), nil
} else {
msg := "invalid token provided"
return "", errors.New(msg)
}
}
func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *CredentialsOpts, signer Signer, signatureAlgorithm string) (http.HandlerFunc, http.HandlerFunc, http.HandlerFunc) {
// Handles PUT requests to /latest/api/token/
putTokenHandler := func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
// Check for the presence of the X-Forwarded-For header
xForwardedForHeader := r.Header.Get(X_FORWARDED_FOR_HEADER) // canonicalized headers are used (casing doesn't matter)
if xForwardedForHeader != "" {
w.WriteHeader(http.StatusBadRequest)
io.WriteString(w, "unable to process requests with X-Forwarded-For header")
return
}
// Obtain the token TTL
tokenTTLStr := r.Header.Get(EC2_METADATA_TOKEN_TTL_HEADER)
if tokenTTLStr == "" {
tokenTTLStr = DEFAULT_TOKEN_TTL_SECONDS
}
tokenTTL, err := strconv.Atoi(tokenTTLStr)
if err != nil || tokenTTL < 1 || tokenTTL > 21600 {
w.WriteHeader(http.StatusBadRequest)
io.WriteString(w, "invalid token TTL")
return
}
// Generate token and insert it into map
token, err := GenerateToken(100)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, "unable to generate token")
return
}
expirationTime := time.Now().Add(time.Second * time.Duration(tokenTTL))
InsertToken(token, expirationTime)
w.Header().Set(EC2_METADATA_TOKEN_TTL_HEADER, tokenTTLStr)
io.WriteString(w, token) // nosemgrep
}
// Handles requests to /latest/meta-data/iam/security-credentials/
getRoleNameHandler := func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
err := CheckValidToken(w, r)
if err != nil {
return
}
tokenTTL, err := FindTokenTTLSeconds(r)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Header().Set(EC2_METADATA_TOKEN_TTL_HEADER, tokenTTL)
io.WriteString(w, roleName) // nosemgrep
}
// Handles GET requests to /latest/meta-data/iam/security-credentials/<ROLE_NAME>
getCredentialsHandler := func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
err := CheckValidToken(w, r)
if err != nil {
log.Printf("Token validation received error: %s\n", err)
return
}
var nextRefreshTime = cred.Expiration.Add(-RefreshTime)
if time.Until(nextRefreshTime) < RefreshTime {
if Debug {
log.Println("Generating credentials")
}
credentialProcessOutput, gcErr := GenerateCredentials(opts, signer, signatureAlgorithm)
if gcErr != nil {
log.Printf("Error generating credentials: %s\n", gcErr)
}
cred.AccessKeyId = credentialProcessOutput.AccessKeyId
cred.SecretAccessKey = credentialProcessOutput.SecretAccessKey
cred.Token = credentialProcessOutput.SessionToken
cred.Expiration, _ = time.Parse(time.RFC3339, credentialProcessOutput.Expiration)
cred.Code = REFRESHABLE_CRED_CODE
cred.LastUpdated = time.Now()
cred.Type = REFRESHABLE_CRED_TYPE
err := json.NewEncoder(w).Encode(cred)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, "failed to encode credentials")
return
}
} else {
if Debug {
log.Println("Using previously obtained credentials")
}
err := json.NewEncoder(w).Encode(cred)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, "failed to encode credentials")
return
}
}
tokenTTL, err := FindTokenTTLSeconds(r)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Header().Set(EC2_METADATA_TOKEN_TTL_HEADER, tokenTTL)
}
return putTokenHandler, getRoleNameHandler, getCredentialsHandler
}
func Serve(port int, credentialsOptions CredentialsOpts) {
var refreshableCred = RefreshableCred{}
roleArn, err := arn.Parse(credentialsOptions.RoleArn)
if err != nil {
log.Println("invalid role ARN")
os.Exit(1)
}
signer, signatureAlgorithm, err := GetSigner(&credentialsOptions)
if err != nil {
log.Println(err)
os.Exit(1)
}
defer signer.Close()
credentialProcessOutput, _ := GenerateCredentials(&credentialsOptions, signer, signatureAlgorithm)
refreshableCred.AccessKeyId = credentialProcessOutput.AccessKeyId
refreshableCred.SecretAccessKey = credentialProcessOutput.SecretAccessKey
refreshableCred.Token = credentialProcessOutput.SessionToken
refreshableCred.Expiration, _ = time.Parse(time.RFC3339, credentialProcessOutput.Expiration)
refreshableCred.Code = REFRESHABLE_CRED_CODE
refreshableCred.LastUpdated = time.Now()
refreshableCred.Type = REFRESHABLE_CRED_TYPE
endpoint := &Endpoint{PortNum: port, TmpCred: refreshableCred}
endpoint.Server = &http.Server{}
roleResourceParts := strings.Split(roleArn.Resource, "/")
roleName := roleResourceParts[len(roleResourceParts)-1] // Find role name without path
putTokenHandler, getRoleNameHandler, getCredentialsHandler := AllIssuesHandlers(&endpoint.TmpCred, roleName, &credentialsOptions, signer, signatureAlgorithm)
http.HandleFunc(TOKEN_RESOURCE_PATH, putTokenHandler)
http.HandleFunc(SECURITY_CREDENTIALS_RESOURCE_PATH, getRoleNameHandler)
http.HandleFunc(SECURITY_CREDENTIALS_RESOURCE_PATH+roleName, getCredentialsHandler)
// Background thread that cleans up expired tokens
ticker := time.NewTicker(5 * time.Second)
go func() {
for range ticker.C {
curTime := time.Now()
mutex.Lock()
for key, value := range tokenMap {
if curTime.After(value) {
delete(tokenMap, key)
log.Printf("removed expired token: %s", key)
}
}
mutex.Unlock()
}
}()
// Start the credentials endpoint
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", LocalHostAddress, endpoint.PortNum))
if err != nil {
log.Println("failed to create listener")
os.Exit(1)
}
endpoint.PortNum = listener.Addr().(*net.TCPAddr).Port
log.Println("Local server started on port:", endpoint.PortNum)
log.Println("Make it available to the sdk by running:")
log.Printf("export AWS_EC2_METADATA_SERVICE_ENDPOINT=http://%s:%d/", LocalHostAddress, endpoint.PortNum)
if err := endpoint.Server.Serve(listener); err != nil {
log.Println("Httpserver: ListenAndServe() error")
os.Exit(1)
}
}

View File

@@ -0,0 +1,738 @@
package aws_signing_helper
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/asn1"
"encoding/base64"
"encoding/hex"
"encoding/pem"
"errors"
"fmt"
"io"
"log"
"math/big"
"net/http"
"os"
"sort"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"golang.org/x/crypto/pkcs12"
)
type SignerParams struct {
OverriddenDate time.Time
RegionName string
ServiceName string
SigningAlgorithm string
}
type CertIdentifier struct {
Subject string
Issuer string
SerialNumber *big.Int
SystemStoreName string // Only relevant in the case of Windows
}
var (
// ErrUnsupportedHash is returned by Signer.Sign() when the provided hash
// algorithm isn't supported.
ErrUnsupportedHash = errors.New("unsupported hash algorithm")
// Predefined system store names.
// See: https://learn.microsoft.com/en-us/windows/win32/seccrypto/system-store-locations
SystemStoreNames = []string{
"MY",
"Root",
"Trust",
"CA",
}
)
// Interface that all signers will have to implement
// (as a result, they will also implement crypto.Signer)
type Signer interface {
Public() crypto.PublicKey
Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error)
Certificate() (certificate *x509.Certificate, err error)
CertificateChain() (certificateChain []*x509.Certificate, err error)
Close()
}
// Container for certificate data returned to the SDK as JSON.
type CertificateData struct {
// Type for the key contained in the certificate.
// Passed back to the `sign-string` command
KeyType string `json:"keyType"`
// Certificate, as base64-encoded DER; used in the `x-amz-x509`
// header in the API request.
CertificateData string `json:"certificateData"`
// Serial number of the certificate. Used in the credential
// field of the Authorization header
SerialNumber string `json:"serialNumber"`
// Supported signing algorithms based on the KeyType
Algorithms []string `json:"supportedAlgorithms"`
}
// Container that adheres to the format of credential_process output as specified by AWS.
type CredentialProcessOutput struct {
// This field should be hard-coded to 1 for now.
Version int `json:"Version"`
// AWS Access Key ID
AccessKeyId string `json:"AccessKeyId"`
// AWS Secret Access Key
SecretAccessKey string `json:"SecretAccessKey"`
// AWS Session Token for temporary credentials
SessionToken string `json:"SessionToken"`
// ISO8601 timestamp for when the credentials expire
Expiration string `json:"Expiration"`
}
type CertificateContainer struct {
// Certificate data
Cert *x509.Certificate
// Certificate URI (only populated in the case that the certificate is a PKCS#11 object)
Uri string
}
// Define constants used in signing
const (
aws4_x509_rsa_sha256 = "AWS4-X509-RSA-SHA256"
aws4_x509_ecdsa_sha256 = "AWS4-X509-ECDSA-SHA256"
timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102"
x_amz_date = "X-Amz-Date"
x_amz_x509 = "X-Amz-X509"
x_amz_x509_chain = "X-Amz-X509-Chain"
x_amz_content_sha256 = "X-Amz-Content-Sha256"
authorization = "Authorization"
host = "Host"
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
)
// Headers that aren't included in calculating the signature
var ignoredHeaderKeys = map[string]bool{
"Authorization": true,
"User-Agent": true,
"X-Amzn-Trace-Id": true,
}
var Debug bool = false
// Find whether the current certificate matches the CertIdentifier
func certMatches(certIdentifier CertIdentifier, cert x509.Certificate) bool {
if certIdentifier.Subject != "" && certIdentifier.Subject != cert.Subject.String() {
return false
}
if certIdentifier.Issuer != "" && certIdentifier.Issuer != cert.Issuer.String() {
return false
}
if certIdentifier.SerialNumber != nil && certIdentifier.SerialNumber.Cmp(cert.SerialNumber) != 0 {
return false
}
return true
}
// Because of *course* we have to do this for ourselves.
//
// Create the DER-encoded SEQUENCE containing R and S:
//
// Ecdsa-Sig-Value ::= SEQUENCE {
// r INTEGER,
// s INTEGER
// }
//
// This is defined in RFC3279 §2.2.3 as well as SEC.1.
// I can't find anything which mandates DER but I've seen
// OpenSSL refusing to verify it with indeterminate length.
func encodeEcdsaSigValue(signature []byte) (out []byte, err error) {
sigLen := len(signature) / 2
return asn1.Marshal(struct {
R *big.Int
S *big.Int
}{
big.NewInt(0).SetBytes(signature[:sigLen]),
big.NewInt(0).SetBytes(signature[sigLen:])})
}
// GetSigner gets the Signer based on the flags passed in by the user (from which the CredentialsOpts structure is derived)
func GetSigner(opts *CredentialsOpts) (signer Signer, signatureAlgorithm string, err error) {
var (
certificate *x509.Certificate
)
privateKeyId := opts.PrivateKeyId
if privateKeyId == "" {
if opts.CertificateId == "" {
if Debug {
log.Println("attempting to use CertStoreSigner")
}
}
privateKeyId = opts.CertificateId
}
if opts.CertificateId != "" && !strings.HasPrefix(opts.CertificateId, "pkcs11:") {
_, cert, err := ReadCertificateData(opts.CertificateId)
if err == nil {
certificate = cert
} else if opts.PrivateKeyId == "" {
if Debug {
log.Println("not a PEM certificate, so trying PKCS#12")
}
if opts.CertificateBundleId != "" {
return nil, "", errors.New("can't specify certificate chain when" +
" using PKCS#12 files; certificate bundle should be provided" +
" within the PKCS#12 file")
}
// Not a PEM certificate? Try PKCS#12
_, _, err = ReadPKCS12Data(opts.CertificateId)
if err != nil {
return nil, "", err
}
return GetFileSystemSigner(opts.PrivateKeyId, opts.CertificateId, opts.CertificateBundleId, true)
} else {
return nil, "", err
}
}
if opts.CertificateBundleId != "" {
if err != nil {
return nil, "", err
}
}
if strings.HasPrefix(privateKeyId, "pkcs11:") {
if Debug {
log.Println("attempting to use PKCS11Signer")
}
if certificate != nil {
opts.CertificateId = ""
}
} else {
_, err = ReadPrivateKeyData(privateKeyId)
if err != nil {
return nil, "", err
}
if certificate == nil {
return nil, "", errors.New("undefined certificate value")
}
if Debug {
log.Println("attempting to use FileSystemSigner")
}
return GetFileSystemSigner(privateKeyId, opts.CertificateId, opts.CertificateBundleId, false)
}
return nil, "", errors.New("unknown certificate type")
}
// Obtain the date-time, formatted as specified by SigV4
func (signerParams *SignerParams) GetFormattedSigningDateTime() string {
return signerParams.OverriddenDate.UTC().Format(timeFormat)
}
// Obtain the short date-time, formatted as specified by SigV4
func (signerParams *SignerParams) GetFormattedShortSigningDateTime() string {
return signerParams.OverriddenDate.UTC().Format(shortTimeFormat)
}
// Obtain the scope as part of the SigV4-X509 signature
func (signerParams *SignerParams) GetScope() string {
var scopeStringBuilder strings.Builder
scopeStringBuilder.WriteString(signerParams.GetFormattedShortSigningDateTime())
scopeStringBuilder.WriteString("/")
scopeStringBuilder.WriteString(signerParams.RegionName)
scopeStringBuilder.WriteString("/")
scopeStringBuilder.WriteString(signerParams.ServiceName)
scopeStringBuilder.WriteString("/")
scopeStringBuilder.WriteString("aws4_request")
return scopeStringBuilder.String()
}
// Convert certificate to string, so that it can be present in the HTTP request header
func certificateToString(certificate *x509.Certificate) string {
return base64.StdEncoding.EncodeToString(certificate.Raw)
}
// Convert certificate chain to string, so that it can be pressent in the HTTP request header
func certificateChainToString(certificateChain []*x509.Certificate) string {
var x509ChainString strings.Builder
for i, certificate := range certificateChain {
x509ChainString.WriteString(certificateToString(certificate))
if i != len(certificateChain)-1 {
x509ChainString.WriteString(",")
}
}
return x509ChainString.String()
}
func CreateRequestSignFunction(signer crypto.Signer, signingAlgorithm string, certificate *x509.Certificate, certificateChain []*x509.Certificate) func(*request.Request) {
return func(req *request.Request) {
region := req.ClientInfo.SigningRegion
if region == "" {
region = aws.StringValue(req.Config.Region)
}
name := req.ClientInfo.SigningName
if name == "" {
name = req.ClientInfo.ServiceName
}
signerParams := SignerParams{time.Now(), region, name, signingAlgorithm}
// Set headers that are necessary for signing
req.HTTPRequest.Header.Set(host, req.HTTPRequest.URL.Host)
req.HTTPRequest.Header.Set(x_amz_date, signerParams.GetFormattedSigningDateTime())
req.HTTPRequest.Header.Set(x_amz_x509, certificateToString(certificate))
if certificateChain != nil {
req.HTTPRequest.Header.Set(x_amz_x509_chain, certificateChainToString(certificateChain))
}
contentSha256 := calculateContentHash(req.HTTPRequest, req.Body)
if req.HTTPRequest.Header.Get(x_amz_content_sha256) == "required" {
req.HTTPRequest.Header.Set(x_amz_content_sha256, contentSha256)
}
canonicalRequest, signedHeadersString := createCanonicalRequest(req.HTTPRequest, req.Body, contentSha256)
stringToSign := CreateStringToSign(canonicalRequest, signerParams)
signatureBytes, err := signer.Sign(rand.Reader, []byte(stringToSign), crypto.SHA256)
if err != nil {
log.Println(err.Error())
os.Exit(1)
}
signature := hex.EncodeToString(signatureBytes)
req.HTTPRequest.Header.Set(authorization, BuildAuthorizationHeader(req.HTTPRequest, req.Body, signedHeadersString, signature, certificate, signerParams))
req.SignedHeaderVals = req.HTTPRequest.Header
}
}
// Find the SHA256 hash of the provided request body as a io.ReadSeeker
func makeSha256Reader(reader io.ReadSeeker) []byte {
hash := sha256.New()
start, _ := reader.Seek(0, 1)
defer reader.Seek(start, 0)
io.Copy(hash, reader)
return hash.Sum(nil)
}
// Calculate the hash of the request body
func calculateContentHash(r *http.Request, body io.ReadSeeker) string {
hash := r.Header.Get(x_amz_content_sha256)
if hash == "" {
if body == nil {
hash = emptyStringSHA256
} else {
hash = hex.EncodeToString(makeSha256Reader(body))
}
}
return hash
}
// Create the canonical query string.
func createCanonicalQueryString(r *http.Request, body io.ReadSeeker) string {
rawQuery := strings.Replace(r.URL.Query().Encode(), "+", "%20", -1)
return rawQuery
}
// Create the canonical header string.
func createCanonicalHeaderString(r *http.Request) (string, string) {
var headers []string
signedHeaderVals := make(http.Header)
for k, v := range r.Header {
canonicalKey := http.CanonicalHeaderKey(k)
if ignoredHeaderKeys[canonicalKey] {
continue
}
lowerCaseKey := strings.ToLower(k)
if _, ok := signedHeaderVals[lowerCaseKey]; ok {
// include additional values
signedHeaderVals[lowerCaseKey] = append(signedHeaderVals[lowerCaseKey], v...)
continue
}
headers = append(headers, lowerCaseKey)
signedHeaderVals[lowerCaseKey] = v
}
sort.Strings(headers)
headerValues := make([]string, len(headers))
for i, k := range headers {
headerValues[i] = k + ":" + strings.Join(signedHeaderVals[k], ",")
}
stripExcessSpaces(headerValues)
return strings.Join(headerValues, "\n"), strings.Join(headers, ";")
}
const doubleSpace = " "
// stripExcessSpaces will rewrite the passed in slice's string values to not
// contain muliple side-by-side spaces.
func stripExcessSpaces(vals []string) {
var j, k, l, m, spaces int
for i, str := range vals {
// Trim trailing spaces
for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- {
}
// Trim leading spaces
for k = 0; k < j && str[k] == ' '; k++ {
}
str = str[k : j+1]
// Strip multiple spaces.
j = strings.Index(str, doubleSpace)
if j < 0 {
vals[i] = str
continue
}
buf := []byte(str)
for k, m, l = j, j, len(buf); k < l; k++ {
if buf[k] == ' ' {
if spaces == 0 {
// First space.
buf[m] = buf[k]
m++
}
spaces++
} else {
// End of multiple spaces.
spaces = 0
buf[m] = buf[k]
m++
}
}
vals[i] = string(buf[:m])
}
}
// Create the canonical request.
func createCanonicalRequest(r *http.Request, body io.ReadSeeker, contentSha256 string) (string, string) {
var canonicalRequestStrBuilder strings.Builder
canonicalHeaderString, signedHeadersString := createCanonicalHeaderString(r)
canonicalRequestStrBuilder.WriteString("POST")
canonicalRequestStrBuilder.WriteString("\n")
canonicalRequestStrBuilder.WriteString("/sessions")
canonicalRequestStrBuilder.WriteString("\n")
canonicalRequestStrBuilder.WriteString(createCanonicalQueryString(r, body))
canonicalRequestStrBuilder.WriteString("\n")
canonicalRequestStrBuilder.WriteString(canonicalHeaderString)
canonicalRequestStrBuilder.WriteString("\n\n")
canonicalRequestStrBuilder.WriteString(signedHeadersString)
canonicalRequestStrBuilder.WriteString("\n")
canonicalRequestStrBuilder.WriteString(contentSha256)
canonicalRequestString := canonicalRequestStrBuilder.String()
canonicalRequestStringHashBytes := sha256.Sum256([]byte(canonicalRequestString))
return hex.EncodeToString(canonicalRequestStringHashBytes[:]), signedHeadersString
}
// Create the string to sign.
func CreateStringToSign(canonicalRequest string, signerParams SignerParams) string {
var stringToSignStrBuilder strings.Builder
stringToSignStrBuilder.WriteString(signerParams.SigningAlgorithm)
stringToSignStrBuilder.WriteString("\n")
stringToSignStrBuilder.WriteString(signerParams.GetFormattedSigningDateTime())
stringToSignStrBuilder.WriteString("\n")
stringToSignStrBuilder.WriteString(signerParams.GetScope())
stringToSignStrBuilder.WriteString("\n")
stringToSignStrBuilder.WriteString(canonicalRequest)
stringToSign := stringToSignStrBuilder.String()
return stringToSign
}
// Builds the complete authorization header
func BuildAuthorizationHeader(request *http.Request, body io.ReadSeeker, signedHeadersString string, signature string, certificate *x509.Certificate, signerParams SignerParams) string {
signingCredentials := certificate.SerialNumber.String() + "/" + signerParams.GetScope()
credential := "Credential=" + signingCredentials
signerHeaders := "SignedHeaders=" + signedHeadersString
signatureHeader := "Signature=" + signature
var authHeaderStringBuilder strings.Builder
authHeaderStringBuilder.WriteString(signerParams.SigningAlgorithm)
authHeaderStringBuilder.WriteString(" ")
authHeaderStringBuilder.WriteString(credential)
authHeaderStringBuilder.WriteString(", ")
authHeaderStringBuilder.WriteString(signerHeaders)
authHeaderStringBuilder.WriteString(", ")
authHeaderStringBuilder.WriteString(signatureHeader)
authHeaderString := authHeaderStringBuilder.String()
return authHeaderString
}
func encodeDer(der []byte) (string, error) {
var buf bytes.Buffer
encoder := base64.NewEncoder(base64.StdEncoding, &buf)
encoder.Write(der)
encoder.Close()
return buf.String(), nil
}
func parseDERFromPEM(pemDataId string, blockType string) (*pem.Block, error) {
b := []byte(pemDataId)
var block *pem.Block
for len(b) > 0 {
block, b = pem.Decode(b)
if block == nil {
return nil, errors.New("unable to parse PEM data")
}
if block.Type == blockType {
return block, nil
}
}
return nil, errors.New("requested block type could not be found")
}
func ReadCertificateBundleData(certificateBundleId string) ([]*x509.Certificate, error) {
bytes, err := os.ReadFile(certificateBundleId)
if err != nil {
log.Println(err)
return nil, err
}
var derBytes []byte
var block *pem.Block
for len(bytes) > 0 {
block, bytes = pem.Decode(bytes)
if block == nil {
break
}
if block.Type != "CERTIFICATE" {
return nil, errors.New("invalid certificate chain")
}
blockBytes := block.Bytes
derBytes = append(derBytes, blockBytes...)
}
return x509.ParseCertificates(derBytes)
}
func readECPrivateKey(privateKeyId string) (ecdsa.PrivateKey, error) {
block, err := parseDERFromPEM(privateKeyId, "EC PRIVATE KEY")
if err != nil {
return ecdsa.PrivateKey{}, errors.New("could not parse PEM data")
}
privateKey, err := x509.ParseECPrivateKey(block.Bytes)
if err != nil {
return ecdsa.PrivateKey{}, errors.New("could not parse private key")
}
return *privateKey, nil
}
func readRSAPrivateKey(privateKeyId string) (rsa.PrivateKey, error) {
block, err := parseDERFromPEM(privateKeyId, "RSA PRIVATE KEY")
if err != nil {
return rsa.PrivateKey{}, errors.New("could not parse PEM data")
}
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return rsa.PrivateKey{}, errors.New("could not parse private key")
}
return *privateKey, nil
}
func readPKCS8PrivateKey(privateKeyId string) (crypto.PrivateKey, error) {
block, err := parseDERFromPEM(privateKeyId, "PRIVATE KEY")
if err != nil {
return nil, errors.New("could not parse PEM data")
}
privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, errors.New("could not parse private key")
}
rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
if ok {
return *rsaPrivateKey, nil
}
ecPrivateKey, ok := privateKey.(*ecdsa.PrivateKey)
if ok {
return *ecPrivateKey, nil
}
return nil, errors.New("could not parse PKCS#8 private key")
}
// Reads and parses a PKCS#12 file (which should contain an end-entity
// certificate, (optional) certificate chain, and the key associated with the
// end-entity certificate). The end-entity certificate will be the first
// certificate in the returned chain. This method assumes that there is
// exactly one certificate that doesn't issue any others within the container
// and treats that as the end-entity certificate. Also, the order of the other
// certificates in the chain aren't guaranteed (it's also not guaranteed that
// those certificates form a chain with the end-entity certificat either).
func ReadPKCS12Data(certificateId string) (certChain []*x509.Certificate, privateKey crypto.PrivateKey, err error) {
var (
bytes []byte
pemBlocks []*pem.Block
parsedCerts []*x509.Certificate
certMap map[string]*x509.Certificate
endEntityFoundIndex int
)
bytes, err = os.ReadFile(certificateId)
if err != nil {
return nil, nil, nil
}
pemBlocks, err = pkcs12.ToPEM(bytes, "")
if err != nil {
return nil, "", err
}
for _, block := range pemBlocks {
cert, err := x509.ParseCertificate(block.Bytes)
if err == nil {
parsedCerts = append(parsedCerts, cert)
continue
}
privateKeyTmp, err := ReadPrivateKeyDataFromPEMBlock(block)
if err == nil {
privateKey = privateKeyTmp
continue
}
// If neither a certificate nor a private key could be parsed from the
// Block, ignore it and continue.
if Debug {
log.Println("unable to parse PEM block in PKCS#12 file - skipping")
}
}
certMap = make(map[string]*x509.Certificate)
for _, cert := range parsedCerts {
// pkix.Name.String() roughly following the RFC 2253 Distinguished Names
// syntax, so we assume that it's canonical.
issuer := cert.Issuer.String()
certMap[issuer] = cert
}
endEntityFoundIndex = -1
for i, cert := range parsedCerts {
subject := cert.Subject.String()
if _, ok := certMap[subject]; !ok {
certChain = append(certChain, cert)
endEntityFoundIndex = i
break
}
}
if endEntityFoundIndex == -1 {
return nil, "", errors.New("no end-entity certificate found in PKCS#12 file")
}
for i, cert := range parsedCerts {
if i != endEntityFoundIndex {
certChain = append(certChain, cert)
}
}
return certChain, privateKey, nil
}
// Load the private key referenced by `privateKeyId`.
func ReadPrivateKeyData(privateKeyId string) (crypto.PrivateKey, error) {
if key, err := readPKCS8PrivateKey(privateKeyId); err == nil {
return key, nil
}
if key, err := readECPrivateKey(privateKeyId); err == nil {
return key, nil
}
if key, err := readRSAPrivateKey(privateKeyId); err == nil {
return key, nil
}
return nil, errors.New("unable to parse private key")
}
// Reads private key data from a *pem.Block.
func ReadPrivateKeyDataFromPEMBlock(block *pem.Block) (key crypto.PrivateKey, err error) {
key, err = x509.ParseECPrivateKey(block.Bytes)
if err == nil {
return key, nil
}
key, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err == nil {
return key, nil
}
return nil, errors.New("unable to parse private key")
}
// ReadCertificateData loads the certificate referenced by `certificateId` and extracts
// details required by the SDK to construct the StringToSign.
func ReadCertificateData(certificateId string) (CertificateData, *x509.Certificate, error) {
block, err := parseDERFromPEM(certificateId, "CERTIFICATE")
if err != nil {
return CertificateData{}, nil, errors.New("could not parse PEM data")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
log.Println("could not parse certificate", err)
return CertificateData{}, nil, errors.New("could not parse certificate")
}
//extract serial number
serialNumber := cert.SerialNumber.String()
//encode certificate
encodedDer, _ := encodeDer(block.Bytes)
//extract key type
var keyType string
switch cert.PublicKeyAlgorithm {
case x509.RSA:
keyType = "RSA"
case x509.ECDSA:
keyType = "EC"
default:
keyType = ""
}
supportedAlgorithms := []string{
fmt.Sprintf("%sSHA256", keyType),
fmt.Sprintf("%sSHA384", keyType),
fmt.Sprintf("%sSHA512", keyType),
}
//return struct
return CertificateData{keyType, encodedDer, serialNumber, supportedAlgorithms}, cert, nil
}
// GetCertChain reads a certificate bundle and returns a chain of all the certificates it contains
func GetCertChain(certificateBundleId string) ([]*x509.Certificate, error) {
certificateChainPointers, err := ReadCertificateBundleData(certificateBundleId)
var chain []*x509.Certificate
if err != nil {
return nil, err
}
for _, certificate := range certificateChainPointers {
chain = append(chain, certificate)
}
return chain, nil
}

View File

@@ -0,0 +1,913 @@
package aws_signing_helper
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"errors"
"fmt"
"io/ioutil"
"log"
"math/big"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"strings"
"testing"
"time"
"unicode/utf8"
"github.com/aws/aws-sdk-go/aws/request"
)
const TestCredentialsFilePath = "/tmp/credentials"
func setup() error {
generateCredentialProcessDataScript := exec.Command("/bin/bash", "../generate-credential-process-data.sh")
_, err := generateCredentialProcessDataScript.Output()
return err
}
func TestMain(m *testing.M) {
err := setup()
if err != nil {
log.Println(err.Error())
os.Exit(1)
}
code := m.Run()
os.Exit(code)
}
// Simple struct to define fixtures
type CertData struct {
CertPath string
KeyType string
}
// Certificate fixtures should be generated by the script ./generate-certs.sh
// if they do not exist, or need to be updated.
func TestReadCertificateData(t *testing.T) {
fixtures := []CertData{
{"../tst/certs/ec-prime256v1-sha256-cert.pem", "EC"},
{"../tst/certs/rsa-2048-sha256-cert.pem", "RSA"},
}
for _, fixture := range fixtures {
certData, _, err := ReadCertificateData(fixture.CertPath)
if err != nil {
t.Log("Failed to read certificate data")
t.Fail()
}
if certData.KeyType != fixture.KeyType {
t.Logf("Wrong key type. Expected %s, got %s", fixture.KeyType, certData.KeyType)
t.Fail()
}
}
}
func TestReadInvalidCertificateData(t *testing.T) {
_, _, err := ReadCertificateData("../tst/certs/invalid-rsa-cert.pem")
if err == nil || !strings.Contains(err.Error(), "could not parse certificate") {
t.Log("Failed to throw a handled error")
t.Fail()
}
}
func TestReadCertificateBundleData(t *testing.T) {
fixtures := []string{
"../tst/certs/cert-bundle.pem",
"../tst/certs/cert-bundle-with-comments.pem",
}
for _, fixture := range fixtures {
_, err := ReadCertificateBundleData(fixture)
if err != nil {
t.Log("Failed to read certificate bundle data")
t.Fail()
}
}
}
func TestReadPrivateKeyData(t *testing.T) {
fixtures := []string{
"../tst/certs/ec-prime256v1-key.pem",
"../tst/certs/ec-prime256v1-key-pkcs8.pem",
"../tst/certs/rsa-2048-key.pem",
"../tst/certs/rsa-2048-key-pkcs8.pem",
}
for _, fixture := range fixtures {
_, err := ReadPrivateKeyData(fixture)
if err != nil {
t.Log(fixture)
t.Log(err)
t.Log("Failed to read private key data")
t.Fail()
}
}
}
func TestReadInvalidPrivateKeyData(t *testing.T) {
_, err := ReadPrivateKeyData("../tst/certs/invalid-rsa-key.pem")
if err == nil || !strings.Contains(err.Error(), "unable to parse private key") {
t.Log("Failed to throw a handled error")
t.Fail()
}
}
func TestBuildAuthorizationHeader(t *testing.T) {
testRequest, err := http.NewRequest("POST", "https://rolesanywhere.us-west-2.amazonaws.com", nil)
if err != nil {
t.Log(err)
t.Fail()
}
path := "../tst/certs/rsa-2048-sha256-cert.pem"
certificateList1, _ := ReadCertificateBundleData(path)
certificate1 := certificateList1[0]
pkPath := "../tst/certs/rsa-2048-key.pem"
awsRequest := request.Request{HTTPRequest: testRequest}
signer, signingAlgorithm, err := GetFileSystemSigner(pkPath, "", path, false)
if err != nil {
t.Log(err)
t.Fail()
}
certificate, err := signer.Certificate()
if err != nil {
t.Log(err)
t.Fail()
}
if !bytes.Equal(certificate.Raw, certificate1.Raw) {
t.Log("Certificate does not match signer certificate")
t.Fail()
}
certificateChain, err := signer.CertificateChain()
if err != nil {
t.Log(err)
t.Fail()
}
for i, cert := range certificateChain {
if !bytes.Equal(cert.Raw, certificateList1[i].Raw) {
t.Log("Certificate chain does not match signer certificate chain")
t.Fail()
}
}
requestSignFunction := CreateRequestSignFunction(signer, signingAlgorithm, certificate, certificateChain)
requestSignFunction(&awsRequest)
certificateList2, _ := ReadCertificateBundleData("../tst/certs/rsa-2048-2-sha256-cert.pem")
certificate2 := certificateList2[0]
os.Rename("../tst/certs/rsa-2048-sha256-cert.pem", "../tst/certs/rsa-2048-sha256-cert.pem.bak")
os.Rename("../tst/certs/rsa-2048-2-sha256-cert.pem", "../tst/certs/rsa-2048-sha256-cert.pem")
certificate, err = signer.Certificate()
if err != nil {
t.Log(err)
t.Fail()
}
if !bytes.Equal(certificate.Raw, certificate2.Raw) {
t.Log("Certificate does not match signer certificate after update")
t.Fail()
}
certificateChain, err = signer.CertificateChain()
if err != nil {
t.Log(err)
t.Fail()
}
for i, cert := range certificateChain {
if !bytes.Equal(cert.Raw, certificateList2[i].Raw) {
t.Log("Certificate chain does not match signer certificate chain after update")
t.Fail()
}
}
os.Rename("../tst/certs/rsa-2048-sha256-cert.pem", "../tst/certs/rsa-2048-2-sha256-cert.pem")
os.Rename("../tst/certs/rsa-2048-sha256-cert.pem.bak", "../tst/certs/rsa-2048-sha256-cert.pem")
requestSignFunction2 := CreateRequestSignFunction(signer, signingAlgorithm, certificate, certificateChain)
requestSignFunction2(&awsRequest)
}
// Verify that the provided payload was signed correctly with the provided options.
// This function is specifically used for unit testing.
func Verify(payload []byte, publicKey crypto.PublicKey, digest crypto.Hash, sig []byte) (bool, error) {
var hash []byte
switch digest {
case crypto.SHA256:
sum := sha256.Sum256(payload)
hash = sum[:]
case crypto.SHA384:
sum := sha512.Sum384(payload)
hash = sum[:]
case crypto.SHA512:
sum := sha512.Sum512(payload)
hash = sum[:]
default:
log.Fatal("unsupported digest")
return false, errors.New("unsupported digest")
}
{
publicKey, ok := publicKey.(*ecdsa.PublicKey)
if ok {
valid := ecdsa.VerifyASN1(publicKey, hash, sig)
return valid, nil
}
}
{
publicKey, ok := publicKey.(*rsa.PublicKey)
if ok {
err := rsa.VerifyPKCS1v15(publicKey, digest, hash, sig)
return err == nil, nil
}
}
return false, nil
}
func TestSign(t *testing.T) {
msg := "test message"
testTable := []CredentialsOpts{}
// TODO: Include tests for PKCS#12 containers, once fixtures are created
// with end-entity certificates.
ec_digests := []string{"sha1", "sha256", "sha384", "sha512"}
ec_curves := []string{"prime256v1", "secp384r1"}
for _, digest := range ec_digests {
for _, curve := range ec_curves {
cert := fmt.Sprintf("../tst/certs/ec-%s-%s-cert.pem",
curve, digest)
key := fmt.Sprintf("../tst/certs/ec-%s-key.pem", curve)
testTable = append(testTable, CredentialsOpts{
CertificateId: cert,
PrivateKeyId: key,
})
key = fmt.Sprintf("../tst/certs/ec-%s-key-pkcs8.pem", curve)
testTable = append(testTable, CredentialsOpts{
CertificateId: cert,
PrivateKeyId: key,
})
}
}
rsa_digests := []string{"md5", "sha1", "sha256", "sha384", "sha512"}
rsa_key_lengths := []string{"1024", "2048", "4096"}
for _, digest := range rsa_digests {
for _, keylen := range rsa_key_lengths {
cert := fmt.Sprintf("../tst/certs/rsa-%s-%s-cert.pem",
keylen, digest)
key := fmt.Sprintf("../tst/certs/rsa-%s-key.pem", keylen)
testTable = append(testTable, CredentialsOpts{
CertificateId: cert,
PrivateKeyId: key,
})
key = fmt.Sprintf("../tst/certs/rsa-%s-key-pkcs8.pem", keylen)
testTable = append(testTable, CredentialsOpts{
CertificateId: cert,
PrivateKeyId: key,
})
}
}
pkcs11_objects := []string{"rsa-2048", "ec-prime256v1"}
for _, object := range pkcs11_objects {
base_pkcs11_uri := "pkcs11:token=credential-helper-test?pin-value=1234"
basic_pkcs11_uri := fmt.Sprintf("pkcs11:token=credential-helper-test;object=%s?pin-value=1234", object)
always_auth_pkcs11_uri := fmt.Sprintf("pkcs11:token=credential-helper-test;object=%s-always-auth?pin-value=1234", object)
cert_file := fmt.Sprintf("../tst/certs/%s-sha256-cert.pem", object)
testTable = append(testTable, CredentialsOpts{
CertificateId: basic_pkcs11_uri,
})
testTable = append(testTable, CredentialsOpts{
PrivateKeyId: basic_pkcs11_uri,
})
testTable = append(testTable, CredentialsOpts{
CertificateId: basic_pkcs11_uri,
PrivateKeyId: basic_pkcs11_uri,
})
testTable = append(testTable, CredentialsOpts{
CertificateId: cert_file,
PrivateKeyId: basic_pkcs11_uri,
})
testTable = append(testTable, CredentialsOpts{
CertificateId: basic_pkcs11_uri,
PrivateKeyId: always_auth_pkcs11_uri,
ReusePin: true,
})
testTable = append(testTable, CredentialsOpts{
CertificateId: cert_file,
PrivateKeyId: always_auth_pkcs11_uri,
ReusePin: true,
})
// Note that for the below test case, there are two matching keys.
// Both keys will validate with the certificate, and one will be chosen
// (it doesn't matter which, since both are the exact same key - it's
// just that one has the CKA_ALWAYS_AUTHENTICATE attribute set).
testTable = append(testTable, CredentialsOpts{
CertificateId: cert_file,
PrivateKeyId: base_pkcs11_uri,
ReusePin: true,
})
}
digestList := []crypto.Hash{crypto.SHA256, crypto.SHA384, crypto.SHA512}
for _, credOpts := range testTable {
signer, _, err := GetSigner(&credOpts)
if err != nil {
var logMsg string
if credOpts.CertificateId != "" || credOpts.PrivateKeyId != "" {
logMsg = fmt.Sprintf("Failed to get signer for '%s'/'%s'",
credOpts.CertificateId, credOpts.PrivateKeyId)
} else {
logMsg = fmt.Sprintf("Failed to get signer for '%s'",
credOpts.CertIdentifier.Subject)
}
t.Log(logMsg)
t.Fail()
return
}
pubKey := signer.Public()
if credOpts.CertificateId != "" && pubKey == nil {
t.Log(fmt.Sprintf("Signer didn't provide public key for '%s'/'%s'",
credOpts.CertificateId, credOpts.PrivateKeyId))
t.Fail()
return
}
for _, digest := range digestList {
signatureBytes, err := signer.Sign(rand.Reader, []byte(msg), digest)
// Try signing again to make sure that there aren't any issues
// with reopening sessions. Also, in some test cases, signing again
// makes sure that the context-specific PIN was saved.
signer.Sign(rand.Reader, []byte(msg), digest)
if err != nil {
t.Log("Failed to sign the input message")
t.Fail()
return
}
_, err = signer.Sign(rand.Reader, []byte(msg), digest)
if err != nil {
t.Log("Failed second signature on the input message")
t.Fail()
return
}
if pubKey != nil {
valid, _ := Verify([]byte(msg), pubKey, digest, signatureBytes)
if !valid {
t.Log(fmt.Sprintf("Failed to verify the signature for '%s'/'%s'",
credOpts.CertificateId, credOpts.PrivateKeyId))
t.Fail()
return
}
}
}
signer.Close()
}
}
func TestCredentialProcess(t *testing.T) {
testTable := []struct {
name string
server *httptest.Server
}{
{
name: "create-session-server-response",
server: GetMockedCreateSessionResponseServer(),
},
}
for _, tc := range testTable {
credentialsOpts := CredentialsOpts{
PrivateKeyId: "../credential-process-data/client-key.pem",
CertificateId: "../credential-process-data/client-cert.pem",
RoleArn: "arn:aws:iam::000000000000:role/ExampleS3WriteRole",
ProfileArnStr: "arn:aws:rolesanywhere:us-east-1:000000000000:profile/41cl0bae-6783-40d4-ab20-65dc5d922e45",
TrustAnchorArnStr: "arn:aws:rolesanywhere:us-east-1:000000000000:trust-anchor/41cl0bae-6783-40d4-ab20-65dc5d922e45",
Endpoint: tc.server.URL,
SessionDuration: 900,
}
t.Run(tc.name, func(t *testing.T) {
defer tc.server.Close()
signer, signatureAlgorithm, err := GetSigner(&credentialsOpts)
if err != nil {
t.Log("Failed to get signer")
t.Fail()
return
}
resp, err := GenerateCredentials(&credentialsOpts, signer, signatureAlgorithm)
if err != nil {
t.Log(err)
t.Log("Unable to call credential-process")
t.Fail()
}
if resp.AccessKeyId != "accessKeyId" {
t.Log("Incorrect access key id")
t.Fail()
}
if resp.SecretAccessKey != "secretAccessKey" {
t.Log("Incorrect secret access key")
t.Fail()
}
if resp.SessionToken != "sessionToken" {
t.Log("Incorrect session token")
t.Fail()
}
})
}
}
func TestCertStoreSignerCreationFails(t *testing.T) {
testTable := []CredentialsOpts{}
randomLargeSerial := new(big.Int)
randomLargeSerial.SetString("123456719012345678901234567890", 10)
testTable = append(testTable, CredentialsOpts{
CertIdentifier: CertIdentifier{
Subject: "invalid-subject",
},
})
testTable = append(testTable, CredentialsOpts{
CertIdentifier: CertIdentifier{
Issuer: "invalid-issuer",
},
})
testTable = append(testTable, CredentialsOpts{
CertIdentifier: CertIdentifier{
SerialNumber: randomLargeSerial,
},
})
testTable = append(testTable, CredentialsOpts{
CertIdentifier: CertIdentifier{
Subject: "CN=roles-anywhere-rsa-2048-sha25",
SerialNumber: randomLargeSerial,
},
})
for _, credOpts := range testTable {
_, _, err := GetSigner(&credOpts)
if err == nil {
t.Log("Expected failure when creating certificate store signer, but received none")
t.Fail()
}
}
}
func TestSignerCreationFails(t *testing.T) {
var cert string
testTable := []CredentialsOpts{}
ec_digests := []string{"sha1", "sha256", "sha384", "sha512"}
ec_curves := []string{"prime256v1", "secp384r1"}
for _, digest := range ec_digests {
for _, curve := range ec_curves {
cert = fmt.Sprintf("../tst/certs/ec-%s-%s.p12",
curve, digest)
testTable = append(testTable, CredentialsOpts{
CertificateId: cert,
})
}
}
rsa_digests := []string{"md5", "sha1", "sha256", "sha384", "sha512"}
rsa_key_lengths := []string{"1024", "2048", "4096"}
for _, digest := range rsa_digests {
for _, keylen := range rsa_key_lengths {
cert = fmt.Sprintf("../tst/certs/rsa-%s-%s.p12",
keylen, digest)
testTable = append(testTable, CredentialsOpts{
CertificateId: cert,
})
}
}
for _, credOpts := range testTable {
_, _, err := GetSigner(&credOpts)
// We expect a failure since the certificates in these .p12 files are
// self-signed. When creating a signer, we expect there to be an
// end-entity certificate within the container.
if err == nil {
t.Log("Expected failure when creating PKCS#12 signer, but received none")
t.Fail()
}
}
}
func TestPKCS11SignerCreationFails(t *testing.T) {
testTable := []CredentialsOpts{}
template_uri := "pkcs11:token=credential-helper-test;object=%s?pin-value=1234"
rsa_generic_uri := fmt.Sprintf(template_uri, "rsa-2048")
ec_generic_uri := fmt.Sprintf(template_uri, "ec-prime256v1")
always_auth_rsa_uri := fmt.Sprintf(template_uri, "rsa-2048-always-auth")
always_auth_ec_uri := fmt.Sprintf(template_uri, "ec-prime256v1-always-auth")
testTable = append(testTable, CredentialsOpts{
CertificateId: rsa_generic_uri,
PrivateKeyId: ec_generic_uri,
})
testTable = append(testTable, CredentialsOpts{
CertificateId: ec_generic_uri,
PrivateKeyId: rsa_generic_uri,
})
testTable = append(testTable, CredentialsOpts{
CertificateId: "../tst/certs/ec-prime256v1-sha256-cert.pem",
PrivateKeyId: rsa_generic_uri,
})
testTable = append(testTable, CredentialsOpts{
CertificateId: "../tst/certs/rsa-2048-sha256-cert.pem",
PrivateKeyId: ec_generic_uri,
})
testTable = append(testTable, CredentialsOpts{
CertificateId: rsa_generic_uri,
PrivateKeyId: always_auth_ec_uri,
ReusePin: true,
})
testTable = append(testTable, CredentialsOpts{
CertificateId: ec_generic_uri,
PrivateKeyId: always_auth_rsa_uri,
ReusePin: true,
})
testTable = append(testTable, CredentialsOpts{
CertificateId: "../tst/certs/ec-prime256v1-sha256-cert.pem",
PrivateKeyId: always_auth_rsa_uri,
ReusePin: true,
})
testTable = append(testTable, CredentialsOpts{
CertificateId: "../tst/certs/rsa-2048-sha256-cert.pem",
PrivateKeyId: always_auth_ec_uri,
ReusePin: true,
})
for _, credOpts := range testTable {
_, _, err := GetSigner(&credOpts)
if err == nil {
t.Log("Expected failure when creating PKCS#11 signer, but received none")
t.Fail()
}
}
}
func TestUpdate(t *testing.T) {
testTable := []struct {
name string
server *httptest.Server
inputFileContents string
profile string
expectedFileContents string
}{
{
name: "test-space-separated-keys",
server: GetMockedCreateSessionResponseServer(),
inputFileContents: `test
test
test
[test profile]
aws_access_key_id = test
[test]
aws_secret_access_key = test`,
profile: "test profile",
expectedFileContents: `test
test
test
[test profile]
aws_access_key_id = accessKeyId
aws_secret_access_key = secretAccessKey
aws_session_token = sessionToken
[test]
aws_secret_access_key = test`,
},
{
name: "test-profile-with-other-keys",
server: GetMockedCreateSessionResponseServer(),
inputFileContents: `test
test
test
[test profile]
aws_access_key_id = test
test_key = test
[test]
aws_secret_access_key = test`,
profile: "test profile",
expectedFileContents: `test
test
test
[test profile]
aws_access_key_id = accessKeyId
test_key = test
aws_secret_access_key = secretAccessKey
aws_session_token = sessionToken
[test]
aws_secret_access_key = test`,
},
{
name: "test-commented-profile",
server: GetMockedCreateSessionResponseServer(),
inputFileContents: `test
test
test
# [test profile]
aws_access_key_id = test
[test]
aws_secret_access_key = test`,
profile: "test profile",
expectedFileContents: `test
test
test
# [test profile]
aws_access_key_id = test
[test]
aws_secret_access_key = test
[test profile]
aws_access_key_id = accessKeyId
aws_secret_access_key = secretAccessKey
aws_session_token = sessionToken
`,
},
{
name: "test-profile-does-not-exist",
server: GetMockedCreateSessionResponseServer(),
inputFileContents: `test
test
test
[test]
aws_secret_access_key = test`,
profile: "test profile",
expectedFileContents: `test
test
test
[test]
aws_secret_access_key = test
[test profile]
aws_access_key_id = accessKeyId
aws_secret_access_key = secretAccessKey
aws_session_token = sessionToken
`,
},
{
name: "test-first-word-in-profile-matches",
server: GetMockedCreateSessionResponseServer(),
inputFileContents: `test
test
test
[test profile]
aws_access_key_id = test
[test]
aws_secret_access_key = test`,
profile: "test",
expectedFileContents: `test
test
test
[test profile]
aws_access_key_id = test
[test]
aws_access_key_id = accessKeyId
aws_secret_access_key = secretAccessKey
aws_session_token = sessionToken`,
},
{
name: "test-multiple-profiles-with-same-name",
server: GetMockedCreateSessionResponseServer(),
inputFileContents: `test
test
test
[test]
test_key = test
[test profile]
aws_access_key_id = test
[test]
aws_secret_access_key = test`,
profile: "test",
expectedFileContents: `test
test
test
[test]
test_key = test
aws_access_key_id = accessKeyId
aws_secret_access_key = secretAccessKey
aws_session_token = sessionToken
[test profile]
aws_access_key_id = test
[test]
aws_secret_access_key = test`,
},
}
for _, tc := range testTable {
credentialsOpts := CredentialsOpts{
PrivateKeyId: "../credential-process-data/client-key.pem",
CertificateId: "../credential-process-data/client-cert.pem",
RoleArn: "arn:aws:iam::000000000000:role/ExampleS3WriteRole",
ProfileArnStr: "arn:aws:rolesanywhere:us-east-1:000000000000:profile/41cl0bae-6783-40d4-ab20-65dc5d922e45",
TrustAnchorArnStr: "arn:aws:rolesanywhere:us-east-1:000000000000:trust-anchor/41cl0bae-6783-40d4-ab20-65dc5d922e45",
Endpoint: tc.server.URL,
SessionDuration: 900,
}
t.Run(tc.name, func(t *testing.T) {
SetupTests()
defer tc.server.Close()
os.Setenv(AwsSharedCredentialsFileEnvVarName, TestCredentialsFilePath)
_, err := GetCredentialsFileContents() // first create the credentials file with the appropriate permissions
if err != nil {
t.Log("unable to create credentials file for testing")
t.Fail()
}
writeOnlyCredentialsFile, err := GetWriteOnlyCredentialsFile() // then obtain a handle to the credentials file to perform write operations
if err != nil {
t.Log("unable to write to credentials file for testing")
t.Fail()
}
defer writeOnlyCredentialsFile.Close()
writeOnlyCredentialsFile.WriteString(tc.inputFileContents)
Update(credentialsOpts, tc.profile, true)
fileByteContents, _ := ioutil.ReadFile(TestCredentialsFilePath)
fileStringContents := trimLastChar(string(fileByteContents))
if fileStringContents != tc.expectedFileContents {
t.Log("unexpected file contents")
t.Fail()
}
})
}
}
func TestUpdateFilePermissions(t *testing.T) {
testTable := []struct {
name string
server *httptest.Server
profile string
expectedFileContents string
}{
{
name: "test-space-separated-keys",
server: GetMockedCreateSessionResponseServer(),
profile: "test profile",
expectedFileContents: `[test profile]
aws_access_key_id = accessKeyId
aws_secret_access_key = secretAccessKey
aws_session_token = sessionToken
`,
},
}
for _, tc := range testTable {
credentialsOpts := CredentialsOpts{
PrivateKeyId: "../credential-process-data/client-key.pem",
CertificateId: "../credential-process-data/client-cert.pem",
RoleArn: "arn:aws:iam::000000000000:role/ExampleS3WriteRole",
ProfileArnStr: "arn:aws:rolesanywhere:us-east-1:000000000000:profile/41cl0bae-6783-40d4-ab20-65dc5d922e45",
TrustAnchorArnStr: "arn:aws:rolesanywhere:us-east-1:000000000000:trust-anchor/41cl0bae-6783-40d4-ab20-65dc5d922e45",
Endpoint: tc.server.URL,
SessionDuration: 900,
}
t.Run(tc.name, func(t *testing.T) {
SetupTests()
defer tc.server.Close()
os.Setenv(AwsSharedCredentialsFileEnvVarName, TestCredentialsFilePath)
Update(credentialsOpts, tc.profile, true)
fileByteContents, _ := ioutil.ReadFile(TestCredentialsFilePath)
fileStringContents := trimLastChar(string(fileByteContents))
if fileStringContents != tc.expectedFileContents {
t.Log("unexpected file contents")
t.Fail()
}
info, _ := os.Stat(TestCredentialsFilePath)
mode := info.Mode()
if mode != ((1 << 8) | (1 << 7)) {
t.Log("unexpected file mode")
t.Fail()
}
})
}
}
func TestGenerateLongToken(t *testing.T) {
_, err := GenerateToken(150)
if err == nil {
t.Log("token generation should've failed since token size is too large")
t.Fail()
}
}
func TestGenerateToken(t *testing.T) {
token1, err := GenerateToken(100)
if err != nil {
t.Log("unexpected failure in generating token")
t.Fail()
}
token2, err := GenerateToken(100)
if err != nil {
t.Log("unexpected failure in generating token")
t.Fail()
}
if token1 == token2 {
t.Log("expected two randomly generated tokens to be different")
t.Fail()
}
}
func TestStoreValidToken(t *testing.T) {
token, err := GenerateToken(100)
if err != nil {
t.Log("unexpected failure in generating token")
t.Fail()
}
err = InsertToken(token, time.Now().Add(time.Second*time.Duration(100)))
if err != nil {
t.Log("unexpected failure when inserting token")
t.Fail()
}
httpRequest, err := http.NewRequest("GET", "http://127.0.0.1", nil)
if err != nil {
t.Log("unable to create test http request")
t.Fail()
}
httpRequest.Header.Add(EC2_METADATA_TOKEN_HEADER, token)
err = CheckValidToken(nil, httpRequest)
if err != nil {
t.Log("expected previously inserted token to be valid")
t.Fail()
}
}
func Test(t *testing.T) {
httpRequest, err := http.NewRequest("GET", "http://127.0.0.1", nil)
if err != nil {
t.Log("unable to create test http request")
t.Fail()
}
httpRequest.Header.Add("test-header", "test-header-value")
headerNames := [4]string{"Test-Header", "test-header", "TEST-HEADER", "tEST-hEadeR"}
for _, header := range headerNames {
testHeaderValue := httpRequest.Header.Get(header)
if testHeaderValue != "test-header-value" {
t.Log("header name canonicalization not working as expected")
t.Fail()
}
}
}
func SetupTests() {
os.Remove(TestCredentialsFilePath)
}
func trimLastChar(s string) string {
r, size := utf8.DecodeLastRuneInString(s)
if r == utf8.RuneError && (size == 0 || size == 1) {
size = 0
}
return s[:len(s)-size]
}
func GetMockedCreateSessionResponseServer() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated)
w.Write([]byte(`{
"credentialSet":[
{
"assumedRoleUser": {
"arn": "arn:aws:sts::000000000000:assumed-role/ExampleS3WriteRole",
"assumedRoleId": "assumedRoleId"
},
"credentials":{
"accessKeyId": "accessKeyId",
"expiration": "2022-07-27T04:36:55Z",
"secretAccessKey": "secretAccessKey",
"sessionToken": "sessionToken"
},
"packedPolicySize": 10,
"roleArn": "arn:aws:iam::000000000000:role/ExampleS3WriteRole",
"sourceIdentity": "sourceIdentity"
}
],
"subjectArn": "arn:aws:rolesanywhere:us-east-1:000000000000:subject/41cl0bae-6783-40d4-ab20-65dc5d922e45"
}`))
}))
}

View File

@@ -0,0 +1,206 @@
package aws_signing_helper
import (
"bufio"
"log"
"os"
"path/filepath"
"strings"
"time"
)
const UpdateRefreshTime = time.Minute * time.Duration(5)
const AwsSharedCredentialsFileEnvVarName = "AWS_SHARED_CREDENTIALS_FILE"
const BufferSize = 49152
// Structure to contain a temporary credential
type TemporaryCredential struct {
AccessKeyId string
SecretAccessKey string
SessionToken string
Expiration time.Time
}
// Updates credentials in the credentials file for the specified profile
func Update(credentialsOptions CredentialsOpts, profile string, once bool) {
var refreshableCred = TemporaryCredential{}
var nextRefreshTime time.Time
signer, signatureAlgorithm, err := GetSigner(&credentialsOptions)
if err != nil {
log.Println(err)
os.Exit(1)
}
defer signer.Close()
for {
credentialProcessOutput, err := GenerateCredentials(&credentialsOptions, signer, signatureAlgorithm)
if err != nil {
log.Fatal(err)
}
// Assign credential values
refreshableCred.AccessKeyId = credentialProcessOutput.AccessKeyId
refreshableCred.SecretAccessKey = credentialProcessOutput.SecretAccessKey
refreshableCred.SessionToken = credentialProcessOutput.SessionToken // nosemgrep
refreshableCred.Expiration, _ = time.Parse(time.RFC3339, credentialProcessOutput.Expiration)
if (refreshableCred == TemporaryCredential{}) {
log.Println("no credentials created")
os.Exit(1)
}
// Get credentials file contents
lines, err := GetCredentialsFileContents()
if err != nil {
log.Println("unable to get credentials file contents")
os.Exit(1)
}
// Write to credentials file
err = WriteTo(profile, lines, &refreshableCred)
if err != nil {
log.Println("unable to write to AWS credentials file")
os.Exit(1)
}
if once {
break
}
nextRefreshTime = refreshableCred.Expiration.Add(-UpdateRefreshTime)
log.Println("Credentials will be refreshed at", nextRefreshTime.String())
time.Sleep(time.Until(nextRefreshTime))
}
}
// Assume that the credentials file is located in the default path: `~/.aws/credentials`
func GetCredentialsFileContents() ([]string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
log.Println("unable to locate the home directory")
return nil, err
}
awsCredentialsPath := os.Getenv(AwsSharedCredentialsFileEnvVarName)
if awsCredentialsPath == "" {
awsCredentialsPath = filepath.Join(homeDir, ".aws", "credentials")
}
if err = os.MkdirAll(filepath.Dir(awsCredentialsPath), 0600); err != nil {
log.Println("unable to create credentials file")
return nil, err
}
readOnlyCredentialsFile, err := os.OpenFile(awsCredentialsPath, os.O_RDONLY|os.O_CREATE, 0600)
if err != nil {
log.Println("unable to get or create read-only AWS credentials file")
os.Exit(1)
}
defer readOnlyCredentialsFile.Close()
// Read in all profiles in the credentials file
var lines []string
scanner := bufio.NewScanner(readOnlyCredentialsFile)
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
return lines, nil
}
// Assume that the credentials file exists already and open it for write operations
// that will overwrite the existing contents of the file
func GetWriteOnlyCredentialsFile() (*os.File, error) {
homeDir, _ := os.UserHomeDir()
awsCredentialsPath := os.Getenv(AwsSharedCredentialsFileEnvVarName)
if awsCredentialsPath == "" {
awsCredentialsPath = filepath.Join(homeDir, ".aws", "credentials")
}
return os.OpenFile(awsCredentialsPath, os.O_WRONLY|os.O_TRUNC, 0200)
}
// Function that will get the new conents of the credentials file after a
// refresh has been done
func GetNewCredentialsFileContents(profileName string, readLines []string, cred *TemporaryCredential) []string {
var profileExist = false
var profileSection = "[" + profileName + "]"
// A variable that checks whether or not required fields are written to the destination file
newCredVisit := map[string]bool{"aws_access_key_id": false, "aws_secret_access_key": false, "aws_session_token": false}
accessKey := "aws_access_key_id = " + cred.AccessKeyId + "\n"
secretKey := "aws_secret_access_key = " + cred.SecretAccessKey + "\n"
sessionToken := "aws_session_token = " + cred.SessionToken + "\n"
var writeLines = make([]string, 0)
for readLinesIndex := 0; readLinesIndex < len(readLines); readLinesIndex++ {
if !profileExist && readLines[readLinesIndex] == profileSection {
writeLines = append(writeLines[:], profileSection+"\n")
readLinesIndex += 1
for ; readLinesIndex < len(readLines); readLinesIndex++ {
// If the last line of the credentials file is reached
// OR the next profile section is reached
if readLinesIndex == len(readLines)-1 || strings.HasPrefix(readLines[readLinesIndex], "[") {
if !newCredVisit["aws_access_key_id"] {
writeLines = append(writeLines[:], accessKey)
}
if !newCredVisit["aws_secret_access_key"] {
writeLines = append(writeLines[:], secretKey)
}
if !newCredVisit["aws_session_token"] {
writeLines = append(writeLines[:], sessionToken)
}
if readLinesIndex != len(readLines)-1 {
readLinesIndex -= 1
}
profileExist = true
break
} else if strings.HasPrefix(readLines[readLinesIndex], "aws_access_key_id") {
// replace "aws_access_key_id"
writeLines = append(writeLines[:], accessKey)
newCredVisit["aws_access_key_id"] = true
} else if strings.HasPrefix(readLines[readLinesIndex], "aws_secret_access_key") {
// replace "aws_secret_access_key"
writeLines = append(writeLines[:], secretKey)
newCredVisit["aws_secret_access_key"] = true
} else if strings.HasPrefix(readLines[readLinesIndex], "aws_session_token") {
// replace "aws_session_token"
writeLines = append(writeLines[:], sessionToken)
newCredVisit["aws_session_token"] = true
} else {
// write other keys
writeLines = append(writeLines[:], readLines[readLinesIndex]+"\n")
}
}
} else {
writeLines = append(writeLines[:], readLines[readLinesIndex]+"\n")
}
}
// If the chosen profile does not exist
if !profileExist {
writeCredential := profileSection + "\n" + accessKey + secretKey + sessionToken
writeLines = append(writeLines[:], writeCredential+"\n")
}
return writeLines
}
// Function to write existing credentials and newly-created credentials to a destination file
func WriteTo(profileName string, readLines []string, cred *TemporaryCredential) error {
destFile, err := GetWriteOnlyCredentialsFile()
if err != nil {
log.Println("unable to get write-only AWS credentials file")
os.Exit(1)
}
defer destFile.Close()
// Create buffered writer
destFileWriter := bufio.NewWriterSize(destFile, BufferSize)
for _, line := range GetNewCredentialsFileContents(profileName, readLines, cred) {
_, err := destFileWriter.WriteString(line)
if err != nil {
log.Println("unable to write to credentials file")
os.Exit(1)
}
}
// Flush the contents of the buffer
destFileWriter.Flush()
return nil
}

18
cmd/credential_process.go Normal file
View File

@@ -0,0 +1,18 @@
package cmd
import helper "git.s.int/rrise/aws-iam-anywhere-refresher/aws_signing_helper"
func Run(opts *helper.CredentialsOpts) (*helper.CredentialProcessOutput, error) {
signer, signingAlgorithm, err := helper.GetSigner(opts)
if err != nil {
return nil, err
}
defer signer.Close()
credentialProcessOutput, err := helper.GenerateCredentials(opts, signer, signingAlgorithm)
if err != nil {
return nil, err
}
return &credentialProcessOutput, nil
}

57
config/config.go Normal file
View File

@@ -0,0 +1,57 @@
package config
import "git.s.int/packages/go/utilities/Env"
const (
namespace Env.EnvironmentVariable = "NAMESPACE"
secretName Env.EnvironmentVariable = "SECRET"
roleArn Env.EnvironmentVariable = "ROLE_ARN"
profileArn Env.EnvironmentVariable = "PROFILE_ARN"
trustedAnchorArn Env.EnvironmentVariable = "TRUSTED_ANCHOR_ARN"
privateKey Env.EnvironmentVariable = "PRIVATE_KEY"
certificate Env.EnvironmentVariable = "CERTIFICATE"
sessionDuration Env.EnvironmentVariable = "SESSION_DURATION"
restartDeployments Env.EnvironmentVariable = "RESTART_DEPLOYMENTS"
)
type Config struct{}
func NewConfig() *Config {
return &Config{}
}
func (Config) Namespace() string {
return namespace.GetEnvString("")
}
func (Config) Secret() string {
return secretName.GetEnvString("aws-credentials")
}
func (Config) RoleArn() string {
return roleArn.GetEnvString("")
}
func (Config) ProfileArn() string {
return profileArn.GetEnvString("")
}
func (Config) TrustedAnchor() string {
return trustedAnchorArn.GetEnvString("")
}
func (Config) PrivateKey() string {
return privateKey.GetEnvString("")
}
func (Config) Certificate() string {
return certificate.GetEnvString("")
}
func (Config) SessionDuration() int64 {
return sessionDuration.GetEnvInt("SESSION_DURATION", 900)
}
func (Config) RestartDeployments() bool {
return restartDeployments.GetEnvBool(false)
}

51
go.mod Normal file
View File

@@ -0,0 +1,51 @@
module git.s.int/rrise/aws-iam-anywhere-refresher
go 1.22.5
require (
git.s.int/packages/go/utilities v1.2.2
github.com/aws/aws-sdk-go v1.55.5
github.com/aws/rolesanywhere-credential-helper v1.1.1
golang.org/x/crypto v0.21.0
k8s.io/api v0.30.3
k8s.io/apimachinery v0.30.3
k8s.io/client-go v0.30.3
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/emicklei/go-restful/v3 v3.11.0 // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-openapi/jsonpointer v0.19.6 // indirect
github.com/go-openapi/jsonreference v0.20.2 // indirect
github.com/go-openapi/swag v0.22.3 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/gnostic-models v0.6.8 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
golang.org/x/net v0.23.0 // indirect
golang.org/x/oauth2 v0.10.0 // indirect
golang.org/x/sys v0.22.0 // indirect
golang.org/x/term v0.22.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/time v0.3.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/klog/v2 v2.120.1 // indirect
k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340 // indirect
k8s.io/utils v0.0.0-20230726121419-3b25d923346b // indirect
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect
sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect
sigs.k8s.io/yaml v1.3.0 // indirect
)

162
go.sum Normal file
View File

@@ -0,0 +1,162 @@
git.s.int/packages/go/utilities v1.2.2 h1:IXKdrTgRc7tnDUB4sOWD/kjwgw9luUzvsaPzX+Dhm7Y=
git.s.int/packages/go/utilities v1.2.2/go.mod h1:1nIS3PzUaLiNBBkyme408XbI725PiureeTV7iBXfUI0=
github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU=
github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
github.com/aws/rolesanywhere-credential-helper v1.1.1 h1:Dmt9VElG4V4PRLr5fXOGIjK72ajf/A9As0bybjOTLH4=
github.com/aws/rolesanywhere-credential-helper v1.1.1/go.mod h1:Rbs7kVBO+dJu26o9+TAeiJLCb4myq77aUy7D5pAO9dg=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g=
github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE=
github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs=
github.com/go-openapi/jsonreference v0.20.2 h1:3sVjiK66+uXK/6oQ8xgcRKcFgQ5KXa2KvnJRumpMGbE=
github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k=
github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g=
github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I=
github.com/google/gnostic-models v0.6.8/go.mod h1:5n7qKqH0f5wFt+aWF8CW6pZLLNOfYuF5OpfBSENuI8U=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec=
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/onsi/ginkgo/v2 v2.15.0 h1:79HwNRBAZHOEwrczrgSOPy+eFTTlIGELKy5as+ClttY=
github.com/onsi/ginkgo/v2 v2.15.0/go.mod h1:HlxMHtYF57y6Dpf+mc5529KKmSq9h2FpCF+/ZkwUxKM=
github.com/onsi/gomega v1.31.0 h1:54UJxxj6cPInHS3a35wm6BK/F9nHYueZ1NVujHDrnXE=
github.com/onsi/gomega v1.31.0/go.mod h1:DW9aCi7U6Yi40wNVAvT6kzFnEVEI5n3DloYBiKiT6zk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/oauth2 v0.10.0 h1:zHCpF2Khkwy4mMB4bv0U37YtJdTGW8jI0glAApi0Kh8=
golang.org/x/oauth2 v0.10.0/go.mod h1:kTpgurOux7LqtuxjuyZa4Gj2gdezIt/jQtGnNFfypQI=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk=
golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.18.0 h1:k8NLag8AGHnn+PHbl7g43CtqZAwG60vZkLqgyZgIHgQ=
golang.org/x/tools v0.18.0/go.mod h1:GL7B4CwcLLeo59yx/9UWWuNOW1n3VZ4f5axWfML7Lcg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
k8s.io/api v0.30.3 h1:ImHwK9DCsPA9uoU3rVh4QHAHHK5dTSv1nxJUapx8hoQ=
k8s.io/api v0.30.3/go.mod h1:GPc8jlzoe5JG3pb0KJCSLX5oAFIW3/qNJITlDj8BH04=
k8s.io/apimachinery v0.30.3 h1:q1laaWCmrszyQuSQCfNB8cFgCuDAoPszKY4ucAjDwHc=
k8s.io/apimachinery v0.30.3/go.mod h1:iexa2somDaxdnj7bha06bhb43Zpa6eWH8N8dbqVjTUc=
k8s.io/client-go v0.30.3 h1:bHrJu3xQZNXIi8/MoxYtZBBWQQXwy16zqJwloXXfD3k=
k8s.io/client-go v0.30.3/go.mod h1:8d4pf8vYu665/kUbsxWAQ/JDBNWqfFeZnvFiVdmx89U=
k8s.io/klog/v2 v2.120.1 h1:QXU6cPEOIslTGvZaXvFWiP9VKyeet3sawzTOvdXb4Vw=
k8s.io/klog/v2 v2.120.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340 h1:BZqlfIlq5YbRMFko6/PM7FjZpUb45WallggurYhKGag=
k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340/go.mod h1:yD4MZYeKMBwQKVht279WycxKyM84kkAx2DPrTXaeb98=
k8s.io/utils v0.0.0-20230726121419-3b25d923346b h1:sgn3ZU783SCgtaSJjpcVVlRqd6GSnlTLKgpAAttJvpI=
k8s.io/utils v0.0.0-20230726121419-3b25d923346b/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo=
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0=
sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4=
sigs.k8s.io/structured-merge-diff/v4 v4.4.1/go.mod h1:N8hJocpFajUSSeSJ9bOZ77VzejKZaXsTtZo4/u7Io08=
sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo=
sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8=

128
main.go Normal file
View File

@@ -0,0 +1,128 @@
package main
import (
"context"
"encoding/base64"
helper "git.s.int/rrise/aws-iam-anywhere-refresher/aws_signing_helper"
"git.s.int/rrise/aws-iam-anywhere-refresher/cmd"
appConfig "git.s.int/rrise/aws-iam-anywhere-refresher/config"
v1k "k8s.io/api/core/v1"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"log"
"os"
"time"
)
func main() {
println("Starting credentials refresh")
config, err := rest.InClusterConfig()
if err != nil {
println("Are you running in a cluster?")
panic(err)
}
client, err := kubernetes.NewForConfig(config)
if err != nil {
panic(err)
}
c := appConfig.NewConfig()
privateKey, err := base64.StdEncoding.DecodeString(c.PrivateKey())
if err != nil {
log.Fatal("error:", err)
}
certificate, err := base64.StdEncoding.DecodeString(c.Certificate())
if err != nil {
log.Fatal("error:", err)
}
credentials, err := cmd.Run(&helper.CredentialsOpts{
PrivateKeyId: string(privateKey),
CertificateId: string(certificate),
CertIdentifier: helper.CertIdentifier{
SystemStoreName: "MY",
},
RoleArn: c.RoleArn(),
ProfileArnStr: c.ProfileArn(),
TrustAnchorArnStr: c.TrustedAnchor(),
SessionDuration: int(c.SessionDuration()),
})
if err != nil {
panic(err)
}
println("Got new credentials")
secret := &v1k.Secret{
ObjectMeta: v1.ObjectMeta{
Name: c.Secret(),
Labels: map[string]string{
"managed-by": "aws-iam-anywhere-refresher",
},
},
StringData: map[string]string{
"AWS_ACCESS_KEY_ID": credentials.AccessKeyId,
"AWS_SECRET_ACCESS_KEY": credentials.SecretAccessKey,
"AWS_SESSION_TOKEN": credentials.SessionToken,
},
}
_, err = client.CoreV1().Secrets(c.Namespace()).Get(context.TODO(), c.Secret(), v1.GetOptions{})
if err != nil {
println(err.Error())
println("secret doesn't exist, trying to create")
create, err := client.CoreV1().Secrets(c.Namespace()).Create(context.Background(), secret, v1.CreateOptions{})
if err != nil {
panic(err)
}
println("secret created")
println(create.CreationTimestamp.String())
} else {
update, err := client.CoreV1().Secrets(c.Namespace()).Update(context.TODO(), secret, v1.UpdateOptions{})
if err != nil {
panic(err)
}
println("secret updated")
println(update.CreationTimestamp.String())
}
if c.RestartDeployments() {
println("Restarting deployments...")
deployments, err := client.AppsV1().Deployments(c.Namespace()).List(context.TODO(), v1.ListOptions{
LabelSelector: "iam-role-type=aws-iam-anywhere",
})
if err != nil {
panic(err)
}
for _, deployment := range deployments.Items {
println("Restarting deployment", deployment.Name)
if deployment.Spec.Template.ObjectMeta.Annotations == nil {
deployment.Spec.Template.ObjectMeta.Annotations = make(map[string]string)
}
deployment.Spec.Template.ObjectMeta.Annotations["kubectl.kubernetes.io/restartedAt"] = time.Now().Format(time.RFC3339)
_, err = client.AppsV1().Deployments(c.Namespace()).Update(context.TODO(), &deployment, v1.UpdateOptions{})
if err != nil {
println(err.Error())
}
}
}
println("Done!")
os.Exit(0)
}