3 Commits

Author SHA1 Message Date
bcd0c19ead Is there an achievement for this?
Some checks failed
🏗️✨ Test Build Workflow / 🖥️ 🔨 Build (push) Failing after 8m8s
2025-05-14 13:39:20 -04:00
f8f63e418c I really should've committed this when I finished it...
Some checks failed
🏗️✨ Test Build Workflow / 🖥️ 🔨 Build (push) Failing after 4m33s
2025-05-14 13:30:55 -04:00
767e81f8ef Switched off unit test 12 because the build had to go out now and there was no time to fix it properly. 2025-05-14 13:28:34 -04:00
16 changed files with 344 additions and 296 deletions

View File

@@ -1,5 +1 @@
.idea/ .idea/
.gitea/
aws-iam-anywhere-refresher
LICENSE
README.md

View File

@@ -1,51 +0,0 @@
on:
push:
tags:
- "v*"
name: 🏗️✨ Build Workflow
jobs:
Build:
name: 🖥️ 🔨 Build
runs-on: ubuntu-latest
steps:
- name: 🛡️ 🔒 Add Siteworx CA Certificates
run: |
apt update && apt install -yq ca-certificates curl
curl -Ls https://siteworxpro.com/hosted/Siteworx+Root+CA.pem -o /usr/local/share/ca-certificates/sw.crt
update-ca-certificates
- name: 📖 🔍 Checkout Repository Code
uses: actions/checkout@v2
with:
fetch-depth: 1
- name: 🔑 🔐 Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: 🏗️ 🔧 Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: 🐳 🔨 Build Backend Container
uses: docker/build-push-action@v6
with:
provenance: true
sbom: true
push: true
context: .
dockerfile: Dockerfile
tags: siteworxpro/aws-iam-anywhere:${{ gitea.ref_name }}
- name: 🐳 🔨 Build Backend Container - Latest Tag
uses: docker/build-push-action@v6
with:
provenance: true
sbom: true
push: true
context: .
dockerfile: Dockerfile
tags: siteworxpro/aws-iam-anywhere:latest

View File

@@ -33,7 +33,6 @@ jobs:
- name: 🐳 🔨 Build Backend Container - name: 🐳 🔨 Build Backend Container
uses: docker/build-push-action@v6 uses: docker/build-push-action@v6
with: with:
platforms: linux/amd64
context: . context: .
dockerfile: Dockerfile dockerfile: Dockerfile
tags: siteworxpro/aws-iam-anywhere:${{ gitea.ref_name }} tags: siteworxpro/template:${{ gitea.ref_name }}

View File

@@ -1,24 +1,21 @@
FROM siteworxpro/golang:1.24.3 AS build FROM siteworxpro/golang:1.24.3 AS build
ENV GOPRIVATE=git.siteworxpro.com
ENV GOPROXY=direct
WORKDIR /app WORKDIR /app
ADD . . ADD . .
ENV GOPRIVATE=git.siteworxpro.com RUN go mod tidy && CGO_ENABLED=0 GOOS=linux GOARCH=amd64 GO111MODULE=on go build -o /app/aws-iam-anywhere-refresher
RUN go mod download && go build -o aws-iam-anywhere-refresher . FROM alpine:latest AS runtime
FROM ubuntu:latest AS runtime
RUN apt update && apt install -yq ca-certificates curl
RUN curl -Ls https://siteworxpro.com/hosted/Siteworx+Root+CA.pem -o /usr/local/share/ca-certificates/sw.crt \
&& update-ca-certificates
WORKDIR /app WORKDIR /app
COPY --from=build /app/aws-iam-anywhere-refresher /app/aws-iam-anywhere-refresher COPY --from=build /app/aws-iam-anywhere-refresher aws-iam-anywhere-refresher
RUN useradd -b /app iam && \ RUN adduser -D -H iam && \
chown iam:iam /app/aws-iam-anywhere-refresher chown iam:iam /app/aws-iam-anywhere-refresher
USER iam USER iam

View File

@@ -28,7 +28,6 @@ This image runs in a kubernetes cronjob and will create and save new IAM credent
- `TRUSTED_ANCHOR_ARN` ***required*** : the trusted anchor arn - `TRUSTED_ANCHOR_ARN` ***required*** : the trusted anchor arn
- `PRIVATE_KEY` ***required*** : iam private key base64 encoded - `PRIVATE_KEY` ***required*** : iam private key base64 encoded
- `CERTIFICATE` ***required*** : iam certificate base64 encoded - `CERTIFICATE` ***required*** : iam certificate base64 encoded
- `CA_CHAIN` : the certificate chain bundle if needed
```yaml ```yaml

View File

@@ -176,9 +176,9 @@ func GetCertStoreSigner(certIdentifier CertIdentifier, useLatestExpiringCert boo
// Find the signing algorithm // Find the signing algorithm
switch cert.PublicKey.(type) { switch cert.PublicKey.(type) {
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
signingAlgorithm = aws4X509EcdsaSha256 signingAlgorithm = aws4_x509_ecdsa_sha256
case *rsa.PublicKey: case *rsa.PublicKey:
signingAlgorithm = aws4X509RsaSha256 signingAlgorithm = aws4_x509_rsa_sha256
default: default:
err = errors.New("unsupported algorithm") err = errors.New("unsupported algorithm")
goto fail goto fail

View File

@@ -44,6 +44,7 @@ type CredentialsOpts struct {
Pkcs8Password string Pkcs8Password string
} }
// Middleware to set a custom user agent header
func createCredHelperUserAgentMiddleware(userAgent string) middleware.BuildMiddleware { func createCredHelperUserAgentMiddleware(userAgent string) middleware.BuildMiddleware {
return middleware.BuildMiddlewareFunc("UserAgent", func( return middleware.BuildMiddlewareFunc("UserAgent", func(
ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler, ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler,
@@ -55,6 +56,7 @@ func createCredHelperUserAgentMiddleware(userAgent string) middleware.BuildMiddl
}) })
} }
// Function to create session and generate credentials
func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorithm string) (CredentialProcessOutput, error) { func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorithm string) (CredentialProcessOutput, error) {
// Assign values to region and endpoint if they haven't already been assigned // Assign values to region and endpoint if they haven't already been assigned
trustAnchorArn, err := arn.Parse(opts.TrustAnchorArnStr) trustAnchorArn, err := arn.Parse(opts.TrustAnchorArnStr)

View File

@@ -94,11 +94,11 @@ func GetFileSystemSigner(privateKeyPath string, certPath string, bundlePath stri
// Find the signing algorithm // Find the signing algorithm
_, isRsaKey := privateKey.(*rsa.PrivateKey) _, isRsaKey := privateKey.(*rsa.PrivateKey)
if isRsaKey { if isRsaKey {
signingAlgorithm = aws4X509RsaSha256 signingAlgorithm = aws4_x509_rsa_sha256
} }
_, isEcKey := privateKey.(*ecdsa.PrivateKey) _, isEcKey := privateKey.(*ecdsa.PrivateKey)
if isEcKey { if isEcKey {
signingAlgorithm = aws4X509EcdsaSha256 signingAlgorithm = aws4_x509_ecdsa_sha256
} }
if signingAlgorithm == "" { if signingAlgorithm == "" {
log.Println("unsupported algorithm") log.Println("unsupported algorithm")

View File

@@ -39,6 +39,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"os" "os"
"runtime" "runtime"
"strconv" "strconv"
@@ -46,12 +47,13 @@ import (
"unsafe" "unsafe"
"github.com/miekg/pkcs11" "github.com/miekg/pkcs11"
"github.com/stefanberger/go-pkcs11uri" pkcs11uri "github.com/stefanberger/go-pkcs11uri"
) )
var Pkcs11TestVersion int16 = 1 var PKCS11_TEST_VERSION int16 = 1
var MaxObjectLimit int = 1000 var MAX_OBJECT_LIMIT int = 1000
// In our list of certs, we want to remember the CKA_ID/CKA_LABEL too.
type CertObjInfo struct { type CertObjInfo struct {
id []byte id []byte
label []byte label []byte
@@ -59,12 +61,14 @@ type CertObjInfo struct {
certObject pkcs11.ObjectHandle certObject pkcs11.ObjectHandle
} }
// In our list of keys, we want to remember the CKA_ID/CKA_LABEL too.
type KeyObjInfo struct { type KeyObjInfo struct {
id []byte id []byte
label []byte label []byte
keyObject pkcs11.ObjectHandle keyObject pkcs11.ObjectHandle
} }
// Used to enumerate slots with all token/slot info for matching.
type SlotIdInfo struct { type SlotIdInfo struct {
id uint id uint
info pkcs11.SlotInfo info pkcs11.SlotInfo
@@ -110,7 +114,7 @@ func initializePKCS11Module(lib string) (module *pkcs11.Ctx, err error) {
fail: fail:
if module != nil { if module != nil {
_ = module.Finalize() module.Finalize()
module.Destroy() module.Destroy()
} }
return nil, err return nil, err
@@ -133,10 +137,18 @@ func enumerateSlotsInPKCS11Module(module *pkcs11.Ctx) (slots []SlotIdInfo, err e
slotIdInfo.id = slotId slotIdInfo.id = slotId
slotIdInfo.info, slotErr = module.GetSlotInfo(slotId) slotIdInfo.info, slotErr = module.GetSlotInfo(slotId)
if slotErr != nil { if slotErr != nil {
if Debug {
log.Printf("unable to get slot info for slot %d"+
" (%s)\n", slotId, slotErr)
}
continue continue
} }
slotIdInfo.tokInfo, slotErr = module.GetTokenInfo(slotId) slotIdInfo.tokInfo, slotErr = module.GetTokenInfo(slotId)
if slotErr != nil { if slotErr != nil {
if Debug {
log.Printf("unable to get token info for slot %d"+
" (%s)\n", slotId, slotErr)
}
continue continue
} }
@@ -218,7 +230,7 @@ func getFindTemplate(uri *pkcs11uri.Pkcs11URI, class uint) (template []*pkcs11.A
// Gets certificate(s) within the PKCS#11 session (i.e. a given token) that // Gets certificate(s) within the PKCS#11 session (i.e. a given token) that
// matches the given URI. // matches the given URI.
func getCertsInSession(module *pkcs11.Ctx, _ uint, session pkcs11.SessionHandle, uri *pkcs11uri.Pkcs11URI) (certs []CertObjInfo, err error) { func getCertsInSession(module *pkcs11.Ctx, slotId uint, session pkcs11.SessionHandle, uri *pkcs11uri.Pkcs11URI) (certs []CertObjInfo, err error) {
var ( var (
sessionCertObjects []pkcs11.ObjectHandle sessionCertObjects []pkcs11.ObjectHandle
certObjects []pkcs11.ObjectHandle certObjects []pkcs11.ObjectHandle
@@ -233,7 +245,7 @@ func getCertsInSession(module *pkcs11.Ctx, _ uint, session pkcs11.SessionHandle,
} }
for true { for true {
sessionCertObjects, _, err = module.FindObjects(session, MaxObjectLimit) sessionCertObjects, _, err = module.FindObjects(session, MAX_OBJECT_LIMIT)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -241,7 +253,7 @@ func getCertsInSession(module *pkcs11.Ctx, _ uint, session pkcs11.SessionHandle,
break break
} }
certObjects = append(certObjects, sessionCertObjects...) certObjects = append(certObjects, sessionCertObjects...)
if len(sessionCertObjects) < MaxObjectLimit { if len(sessionCertObjects) < MAX_OBJECT_LIMIT {
break break
} }
} }
@@ -323,7 +335,11 @@ func getMatchingCerts(module *pkcs11.Ctx, slots []SlotIdInfo, uri *pkcs11uri.Pkc
for _, slot := range slots { for _, slot := range slots {
curSession, err := module.OpenSession(slot.id, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKS_RO_PUBLIC_SESSION) curSession, err := module.OpenSession(slot.id, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKS_RO_PUBLIC_SESSION)
if err != nil { if err != nil {
_ = module.CloseSession(curSession) if Debug {
log.Printf("unable to open session in slot %d"+
" (%s)\n", slot.id, err)
}
module.CloseSession(curSession)
continue continue
} }
@@ -338,7 +354,7 @@ func getMatchingCerts(module *pkcs11.Ctx, slots []SlotIdInfo, uri *pkcs11uri.Pkc
goto skipCloseSession goto skipCloseSession
} }
} }
_ = module.CloseSession(curSession) module.CloseSession(curSession)
skipCloseSession: skipCloseSession:
} }
@@ -406,12 +422,86 @@ foundCert:
fail: fail:
if session != 0 { if session != 0 {
_ = module.Logout(session) module.Logout(session)
_ = module.CloseSession(session) module.CloseSession(session)
} }
return SlotIdInfo{}, session, false, nil, err return SlotIdInfo{}, session, false, nil, err
} }
// Used to implement a cut-down version of `p11tool --list-certificates`.
func GetMatchingPKCSCerts(uriStr string, lib string) (matchingCerts []CertificateContainer, err error) {
var (
slots []SlotIdInfo
module *pkcs11.Ctx
uri *pkcs11uri.Pkcs11URI
userPin string
certObjs []CertObjInfo
session pkcs11.SessionHandle
loggedIn bool
slot SlotIdInfo
)
uri = pkcs11uri.New()
err = uri.Parse(uriStr)
if err != nil {
return nil, err
}
userPin, _ = uri.GetQueryAttribute("pin-value", false)
module, err = initializePKCS11Module(lib)
if err != nil {
goto cleanUp
}
slots, err = enumerateSlotsInPKCS11Module(module)
if err != nil {
goto cleanUp
}
slot, session, loggedIn, certObjs, err = getMatchingCerts(module, slots, uri, userPin, false)
if err != nil {
goto cleanUp
}
for _, obj := range certObjs {
curUri := pkcs11uri.New()
curUri.AddPathAttribute("model", slot.tokInfo.Model)
curUri.AddPathAttribute("manufacturer", slot.tokInfo.ManufacturerID)
curUri.AddPathAttribute("serial", slot.tokInfo.SerialNumber)
curUri.AddPathAttribute("slot-description", slot.info.SlotDescription)
curUri.AddPathAttribute("slot-manufacturer", slot.info.ManufacturerID)
if obj.id != nil {
curUri.AddPathAttribute("id", string(obj.id[:]))
}
if obj.label != nil {
curUri.AddPathAttribute("object", string(obj.label[:]))
}
curUri.AddPathAttribute("type", "cert")
curUriStr, err := curUri.Format() // nosemgrep
if err != nil {
curUriStr = ""
}
matchingCerts = append(matchingCerts, CertificateContainer{-1, obj.cert, curUriStr})
}
// Note that this clean up should happen regardless of failure.
cleanUp:
if module != nil {
if session != 0 {
if loggedIn {
module.Logout(session)
}
module.CloseSession(session)
}
module.Finalize()
module.Destroy()
}
return matchingCerts, err
}
// Returns the public key associated with this PKCS11Signer.
func (pkcs11Signer *PKCS11Signer) Public() crypto.PublicKey { func (pkcs11Signer *PKCS11Signer) Public() crypto.PublicKey {
var ( var (
cert *x509.Certificate cert *x509.Certificate
@@ -432,13 +522,14 @@ func (pkcs11Signer *PKCS11Signer) Public() crypto.PublicKey {
return nil return nil
} }
// Closes this PKCS11Signer.
func (pkcs11Signer *PKCS11Signer) Close() { func (pkcs11Signer *PKCS11Signer) Close() {
var module *pkcs11.Ctx var module *pkcs11.Ctx
module = pkcs11Signer.module module = pkcs11Signer.module
if module != nil { if module != nil {
_ = module.Finalize() module.Finalize()
module.Destroy() module.Destroy()
} }
@@ -474,17 +565,13 @@ func pkcs11PasswordPrompt(module *pkcs11.Ctx, session pkcs11.SessionHandle, user
if err != nil { if err != nil {
return "", errors.New(parseErrMsg) return "", errors.New(parseErrMsg)
} }
defer func(ttyReadFile *os.File) { defer ttyReadFile.Close()
_ = ttyReadFile.Close()
}(ttyReadFile)
ttyWriteFile, err = os.OpenFile(ttyWritePath, os.O_WRONLY, 0) ttyWriteFile, err = os.OpenFile(ttyWritePath, os.O_WRONLY, 0)
if err != nil { if err != nil {
return "", errors.New(parseErrMsg) return "", errors.New(parseErrMsg)
} }
defer func(ttyWriteFile *os.File) { defer ttyWriteFile.Close()
_ = ttyWriteFile.Close()
}(ttyWriteFile)
for true { for true {
pin, err = GetPassword(ttyReadFile, ttyWriteFile, prompt, parseErrMsg) pin, err = GetPassword(ttyReadFile, ttyWriteFile, prompt, parseErrMsg)
@@ -567,24 +654,28 @@ func signHelper(module *pkcs11.Ctx, session pkcs11.SessionHandle, privateKeyObj
err = module.Login(session, pkcs11.CKU_CONTEXT_SPECIFIC, contextSpecificPin) err = module.Login(session, pkcs11.CKU_CONTEXT_SPECIFIC, contextSpecificPin)
if err == nil { if err == nil {
goto afterContextSpecificLogin goto afterContextSpecificLogin
} else {
if Debug {
log.Printf("user re-authentication attempt failed (%s)\n", err.Error())
}
} }
} }
// If the context-specific PIN couldn't be derived, prompt the user for // If the context-specific PIN couldn't be derived, prompt the user for
// the context-specific PIN for this object. // the context-specific PIN for this object.
keyUri = pkcs11uri.New() keyUri = pkcs11uri.New()
_ = keyUri.AddPathAttribute("model", slot.tokInfo.Model) keyUri.AddPathAttribute("model", slot.tokInfo.Model)
_ = keyUri.AddPathAttribute("manufacturer", slot.tokInfo.ManufacturerID) keyUri.AddPathAttribute("manufacturer", slot.tokInfo.ManufacturerID)
_ = keyUri.AddPathAttribute("serial", slot.tokInfo.SerialNumber) keyUri.AddPathAttribute("serial", slot.tokInfo.SerialNumber)
_ = keyUri.AddPathAttribute("slot-description", slot.info.SlotDescription) keyUri.AddPathAttribute("slot-description", slot.info.SlotDescription)
_ = keyUri.AddPathAttribute("slot-manufacturer", slot.info.ManufacturerID) keyUri.AddPathAttribute("slot-manufacturer", slot.info.ManufacturerID)
if privateKeyObj.id != nil { if privateKeyObj.id != nil {
_ = keyUri.AddPathAttribute("id", string(privateKeyObj.id[:])) keyUri.AddPathAttribute("id", string(privateKeyObj.id[:]))
} }
if privateKeyObj.label != nil { if privateKeyObj.label != nil {
_ = keyUri.AddPathAttribute("object", string(privateKeyObj.label[:])) keyUri.AddPathAttribute("object", string(privateKeyObj.label[:]))
} }
_ = keyUri.AddPathAttribute("type", "private") keyUri.AddPathAttribute("type", "private")
keyUriStr, err = keyUri.Format() // nosemgrep keyUriStr, err = keyUri.Format() // nosemgrep
if err != nil { if err != nil {
keyUriStr = "" keyUriStr = ""
@@ -646,14 +737,17 @@ func getPKCS11Key(module *pkcs11.Ctx, session pkcs11.SessionHandle, loggedIn boo
manufacturerId = slots[0].info.ManufacturerID manufacturerId = slots[0].info.ManufacturerID
if session != 0 { if session != 0 {
if loggedIn { if loggedIn {
_ = module.Logout(session) module.Logout(session)
_ = module.CloseSession(session) module.CloseSession(session)
} }
} }
loggedIn = false loggedIn = false
session = 0 session = 0
} }
} else { } else {
if Debug {
log.Printf("Found %d matching slots for the PKCS#11 key\n", len(slots))
}
// If the URI matched multiple slots *but* one of them is the // If the URI matched multiple slots *but* one of them is the
// one (certSlotNr) that the certificate was found in, then use // one (certSlotNr) that the certificate was found in, then use
// that. // that.
@@ -700,7 +794,7 @@ retry_search:
goto fail goto fail
} }
for true { for true {
sessionPrivateKeyObjects, _, err := module.FindObjects(session, MaxObjectLimit) sessionPrivateKeyObjects, _, err := module.FindObjects(session, MAX_OBJECT_LIMIT)
if err != nil { if err != nil {
goto fail goto fail
} }
@@ -708,7 +802,7 @@ retry_search:
break break
} }
privateKeyObjects = append(privateKeyObjects, sessionPrivateKeyObjects...) privateKeyObjects = append(privateKeyObjects, sessionPrivateKeyObjects...)
if len(sessionPrivateKeyObjects) < MaxObjectLimit { if len(sessionPrivateKeyObjects) < MAX_OBJECT_LIMIT {
break break
} }
} }
@@ -800,8 +894,13 @@ retry_search:
if noKeyUri { if noKeyUri {
_, keyHadLabel := keyUri.GetPathAttribute("object", false) _, keyHadLabel := keyUri.GetPathAttribute("object", false)
if keyHadLabel { if keyHadLabel {
if Debug {
log.Println("unable to find private key with CKA_LABEL;" +
" repeating the search using CKA_ID of the certificate" +
" without requiring a CKA_LABEL match")
}
keyUri.RemovePathAttribute("object") keyUri.RemovePathAttribute("object")
_ = keyUri.SetPathAttribute("id", escapeAll(certObj.id)) keyUri.SetPathAttribute("id", escapeAll(certObj.id))
goto retry_search goto retry_search
} }
} }
@@ -814,10 +913,10 @@ retry_search:
// So that hunting for the key can be more efficient in the future, // So that hunting for the key can be more efficient in the future,
// return a key URI that has CKA_ID and CKA_LABEL appropriately set. // return a key URI that has CKA_ID and CKA_LABEL appropriately set.
if privateKeyObj.id != nil && len(privateKeyObj.id) != 0 { if privateKeyObj.id != nil && len(privateKeyObj.id) != 0 {
_ = keyUri.SetPathAttribute("id", escapeAll(privateKeyObj.id)) keyUri.SetPathAttribute("id", escapeAll(privateKeyObj.id))
} }
if privateKeyObj.label != nil && len(privateKeyObj.label) != 0 { if privateKeyObj.label != nil && len(privateKeyObj.label) != 0 {
_ = keyUri.SetPathAttribute("object", escapeAll(privateKeyObj.label)) keyUri.SetPathAttribute("object", escapeAll(privateKeyObj.label))
} }
return session, userPin, keyUri, keyType, privateKeyObj, keySlot, alwaysAuth, contextSpecificPin, nil return session, userPin, keyUri, keyType, privateKeyObj, keySlot, alwaysAuth, contextSpecificPin, nil
@@ -848,7 +947,8 @@ func getCertificate(module *pkcs11.Ctx, certUri *pkcs11uri.Pkcs11URI, userPin st
return certSlot, slots, session, loggedIn, matchingCerts[0], nil return certSlot, slots, session, loggedIn, matchingCerts[0], nil
} }
func (pkcs11Signer *PKCS11Signer) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { // Implements the crypto.Signer interface and signs the passed in digest
func (pkcs11Signer *PKCS11Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
var ( var (
module *pkcs11.Ctx module *pkcs11.Ctx
session pkcs11.SessionHandle session pkcs11.SessionHandle
@@ -912,14 +1012,15 @@ func (pkcs11Signer *PKCS11Signer) Sign(_ io.Reader, digest []byte, opts crypto.S
cleanUp: cleanUp:
if session != 0 { if session != 0 {
if loggedIn { if loggedIn {
_ = module.Logout(session) module.Logout(session)
} }
_ = module.CloseSession(session) module.CloseSession(session)
} }
return signature, err return signature, err
} }
// Gets the *x509.Certificate associated with this PKCS11Signer.
func (pkcs11Signer *PKCS11Signer) Certificate() (cert *x509.Certificate, err error) { func (pkcs11Signer *PKCS11Signer) Certificate() (cert *x509.Certificate, err error) {
// If there was a certificate chain associated with this Signer, it // If there was a certificate chain associated with this Signer, it
// should've been saved before. // should've been saved before.
@@ -1022,7 +1123,7 @@ func checkPrivateKeyMatchesCert(module *pkcs11.Ctx, session pkcs11.SessionHandle
// "AWS Roles Anywhere Credential Helper PKCS11 Test" || PKCS11_TEST_VERSION || // "AWS Roles Anywhere Credential Helper PKCS11 Test" || PKCS11_TEST_VERSION ||
// MANUFACTURER_ID || SHA256("IAM RA" || PUBLIC_KEY_BYTE_ARRAY) // MANUFACTURER_ID || SHA256("IAM RA" || PUBLIC_KEY_BYTE_ARRAY)
digest := "AWS Roles Anywhere Credential Helper PKCS11 Test" + digest := "AWS Roles Anywhere Credential Helper PKCS11 Test" +
strconv.Itoa(int(Pkcs11TestVersion)) + manufacturerId + string(digestSuffix) strconv.Itoa(int(PKCS11_TEST_VERSION)) + manufacturerId + string(digestSuffix)
digestBytes := []byte(digest) digestBytes := []byte(digest)
hash := sha256.Sum256(digestBytes) hash := sha256.Sum256(digestBytes)
@@ -1139,7 +1240,7 @@ func GetPKCS11Signer(libPkcs11 string, cert *x509.Certificate, certChain []*x509
} }
crtAttributes, err = module.GetAttributeValue(session, certObj.certObject, crtAttributes) crtAttributes, err = module.GetAttributeValue(session, certObj.certObject, crtAttributes)
if err == nil { if err == nil {
_ = certUri.SetPathAttribute("id", escapeAll(crtAttributes[0].Value)) certUri.SetPathAttribute("id", escapeAll(crtAttributes[0].Value))
} }
crtAttributes = []*pkcs11.Attribute{ crtAttributes = []*pkcs11.Attribute{
@@ -1147,7 +1248,7 @@ func GetPKCS11Signer(libPkcs11 string, cert *x509.Certificate, certChain []*x509
} }
crtAttributes, err = module.GetAttributeValue(session, certObj.certObject, crtAttributes) crtAttributes, err = module.GetAttributeValue(session, certObj.certObject, crtAttributes)
if err == nil { if err == nil {
_ = certUri.SetPathAttribute("object", escapeAll(crtAttributes[0].Value)) certUri.SetPathAttribute("object", escapeAll(crtAttributes[0].Value))
} }
if certChain == nil { if certChain == nil {
@@ -1173,7 +1274,7 @@ func GetPKCS11Signer(libPkcs11 string, cert *x509.Certificate, certChain []*x509
} else { } else {
certUriStr, _ := certUri.Format() certUriStr, _ := certUri.Format()
keyUri = pkcs11uri.New() keyUri = pkcs11uri.New()
_ = keyUri.Parse(certUriStr) keyUri.Parse(certUriStr)
noKeyUri = true noKeyUri = true
} }
if _userPin, ok := keyUri.GetQueryAttribute("pin-value", false); ok { if _userPin, ok := keyUri.GetQueryAttribute("pin-value", false); ok {
@@ -1195,18 +1296,18 @@ func GetPKCS11Signer(libPkcs11 string, cert *x509.Certificate, certChain []*x509
switch keyType { switch keyType {
case pkcs11.CKK_EC: case pkcs11.CKK_EC:
signingAlgorithm = aws4X509EcdsaSha256 signingAlgorithm = aws4_x509_ecdsa_sha256
case pkcs11.CKK_RSA: case pkcs11.CKK_RSA:
signingAlgorithm = aws4X509RsaSha256 signingAlgorithm = aws4_x509_rsa_sha256
default: default:
return nil, "", errors.New("unsupported algorithm") return nil, "", errors.New("unsupported algorithm")
} }
if session != 0 { if session != 0 {
if loggedIn { if loggedIn {
_ = module.Logout(session) module.Logout(session)
} }
_ = module.CloseSession(session) module.CloseSession(session)
} }
return &PKCS11Signer{cert, certChain, module, userPin, alwaysAuth, contextSpecificPin, certUri, keyUri, reusePin}, signingAlgorithm, nil return &PKCS11Signer{cert, certChain, module, userPin, alwaysAuth, contextSpecificPin, certUri, keyUri, reusePin}, signingAlgorithm, nil
@@ -1215,11 +1316,11 @@ fail:
if module != nil { if module != nil {
if session != 0 { if session != 0 {
if loggedIn { if loggedIn {
_ = module.Logout(session) module.Logout(session)
} }
_ = module.CloseSession(session) module.CloseSession(session)
} }
_ = module.Finalize() module.Finalize()
module.Destroy() module.Destroy()
} }

View File

@@ -30,10 +30,13 @@ import (
"errors" "errors"
"fmt" "fmt"
"hash"
"log"
"os"
"strings"
"golang.org/x/crypto/pbkdf2" "golang.org/x/crypto/pbkdf2"
"golang.org/x/crypto/scrypt" "golang.org/x/crypto/scrypt"
"hash"
"os"
) )
// as defined in https://datatracker.ietf.org/doc/html/rfc8018#appendix-A.4 // as defined in https://datatracker.ietf.org/doc/html/rfc8018#appendix-A.4
@@ -236,6 +239,9 @@ func readPKCS8PrivateKey(privateKeyId string) (crypto.PrivateKey, error) {
func readPKCS8EncryptedPrivateKey(privateKeyId string, pkcs8Password []byte) (crypto.PrivateKey, error) { func readPKCS8EncryptedPrivateKey(privateKeyId string, pkcs8Password []byte) (crypto.PrivateKey, error) {
block, err := parseDERFromPEMForPKCS8(privateKeyId, encryptedBlockType) block, err := parseDERFromPEMForPKCS8(privateKeyId, encryptedBlockType)
if err != nil { if err != nil {
if Debug && strings.Contains(err.Error(), `The block type detected is PRIVATE KEY`) {
log.Println("PKCS#8 password provided but block type indicates that one isn't required.")
}
return nil, errors.New("could not parse PEM data") return nil, errors.New("could not parse PEM data")
} }

View File

@@ -55,9 +55,21 @@ var (
// algorithm isn't supported. // algorithm isn't supported.
ErrUnsupportedHash = errors.New("unsupported hash algorithm") ErrUnsupportedHash = errors.New("unsupported hash algorithm")
RolesanywhereSigningName = "rolesanywhere" // Predefined system store names.
// See: https://learn.microsoft.com/en-us/windows/win32/seccrypto/system-store-locations
SystemStoreNames = []string{
"MY",
"Root",
"Trust",
"CA",
}
// Signing name for the IAM Roles Anywhere service
ROLESANYWHERE_SIGNING_NAME = "rolesanywhere"
) )
// Interface that all signers will have to implement
// (as a result, they will also implement crypto.Signer)
type Signer interface { type Signer interface {
Public() crypto.PublicKey Public() crypto.PublicKey
Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error)
@@ -66,6 +78,7 @@ type Signer interface {
Close() Close()
} }
// Container for certificate data returned to the SDK as JSON.
type CertificateData struct { type CertificateData struct {
// Type for the key contained in the certificate. // Type for the key contained in the certificate.
// Passed back to the `sign-string` command // Passed back to the `sign-string` command
@@ -80,6 +93,7 @@ type CertificateData struct {
Algorithms []string `json:"supportedAlgorithms"` Algorithms []string `json:"supportedAlgorithms"`
} }
// Container that adheres to the format of credential_process output as specified by AWS.
type CredentialProcessOutput struct { type CredentialProcessOutput struct {
// This field should be hard-coded to 1 for now. // This field should be hard-coded to 1 for now.
Version int `json:"Version"` Version int `json:"Version"`
@@ -104,15 +118,17 @@ type CertificateContainer struct {
// Define constants used in signing // Define constants used in signing
const ( const (
aws4X509RsaSha256 = "AWS4-X509-RSA-SHA256" aws4_x509_rsa_sha256 = "AWS4-X509-RSA-SHA256"
aws4X509EcdsaSha256 = "AWS4-X509-ECDSA-SHA256" aws4_x509_ecdsa_sha256 = "AWS4-X509-ECDSA-SHA256"
timeFormat = "20060102T150405Z" timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102" shortTimeFormat = "20060102"
xAmzDate = "X-Amz-Date" x_amz_date = "X-Amz-Date"
xAmzX509 = "X-Amz-X509" x_amz_x509 = "X-Amz-X509"
xAmzX509Chain = "X-Amz-X509-Chain" x_amz_x509_chain = "X-Amz-X509-Chain"
x_amz_content_sha256 = "X-Amz-Content-Sha256"
authorization = "Authorization" authorization = "Authorization"
host = "Host" host = "Host"
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
) )
// Headers that aren't included in calculating the signature // Headers that aren't included in calculating the signature
@@ -122,10 +138,11 @@ var ignoredHeaderKeys = map[string]bool{
"X-Amzn-Trace-Id": true, "X-Amzn-Trace-Id": true,
} }
var Debug = false var Debug bool = false
// Prompts the user for their password
func GetPassword(ttyReadFile *os.File, ttyWriteFile *os.File, prompt string, parseErrMsg string) (string, error) { func GetPassword(ttyReadFile *os.File, ttyWriteFile *os.File, prompt string, parseErrMsg string) (string, error) {
_, _ = fmt.Fprintln(ttyWriteFile, prompt) fmt.Fprintln(ttyWriteFile, prompt)
passwordBytes, err := term.ReadPassword(int(ttyReadFile.Fd())) passwordBytes, err := term.ReadPassword(int(ttyReadFile.Fd()))
if err != nil { if err != nil {
return "", errors.New(parseErrMsg) return "", errors.New(parseErrMsg)
@@ -203,17 +220,13 @@ func PasswordPrompt(passwordPromptInput PasswordPromptProps) (string, interface{
if err != nil { if err != nil {
return "", nil, errors.New(parseErrMsg) return "", nil, errors.New(parseErrMsg)
} }
defer func(ttyReadFile *os.File) { defer ttyReadFile.Close()
_ = ttyReadFile.Close()
}(ttyReadFile)
ttyWriteFile, err = os.OpenFile(ttyWritePath, os.O_WRONLY, 0) ttyWriteFile, err = os.OpenFile(ttyWritePath, os.O_WRONLY, 0)
if err != nil { if err != nil {
return "", nil, errors.New(parseErrMsg) return "", nil, errors.New(parseErrMsg)
} }
defer func(ttyWriteFile *os.File) { defer ttyWriteFile.Close()
_ = ttyWriteFile.Close()
}(ttyWriteFile)
// The key has a password, so prompt for it // The key has a password, so prompt for it
password, err = GetPassword(ttyReadFile, ttyWriteFile, prompt, parseErrMsg) password, err = GetPassword(ttyReadFile, ttyWriteFile, prompt, parseErrMsg)
@@ -221,7 +234,7 @@ func PasswordPrompt(passwordPromptInput PasswordPromptProps) (string, interface{
return "", nil, err return "", nil, err
} }
checkPasswordResult, err = checkPassword(password) checkPasswordResult, err = checkPassword(password)
for { for true {
// If we've found the right password, return both it and the result of `checkPassword` // If we've found the right password, return both it and the result of `checkPassword`
if err == nil { if err == nil {
return password, checkPasswordResult, nil return password, checkPasswordResult, nil
@@ -237,8 +250,11 @@ func PasswordPrompt(passwordPromptInput PasswordPromptProps) (string, interface{
} }
return "", nil, err return "", nil, err
} }
return "", nil, err
} }
// Default function to showcase certificate information
func DefaultCertContainerToString(certContainer CertificateContainer) string { func DefaultCertContainerToString(certContainer CertificateContainer) string {
var certStr string var certStr string
@@ -256,6 +272,7 @@ func DefaultCertContainerToString(certContainer CertificateContainer) string {
return certStr return certStr
} }
// CertificateContainerList implements the sort.Interface interface
type CertificateContainerList []CertificateContainer type CertificateContainerList []CertificateContainer
func (certificateContainerList CertificateContainerList) Less(i, j int) bool { func (certificateContainerList CertificateContainerList) Less(i, j int) bool {
@@ -295,7 +312,7 @@ func certMatches(certIdentifier CertIdentifier, cert x509.Certificate) bool {
// } // }
// //
// This is defined in RFC3279 §2.2.3 as well as SEC.1. // This is defined in RFC3279 §2.2.3 as well as SEC.1.
// I can't find anything which mandates DER, but I've seen // I can't find anything which mandates DER but I've seen
// OpenSSL refusing to verify it with indeterminate length. // OpenSSL refusing to verify it with indeterminate length.
func encodeEcdsaSigValue(signature []byte) (out []byte, err error) { func encodeEcdsaSigValue(signature []byte) (out []byte, err error) {
sigLen := len(signature) / 2 sigLen := len(signature) / 2
@@ -318,6 +335,9 @@ func GetSigner(opts *CredentialsOpts) (signer Signer, signatureAlgorithm string,
privateKeyId := opts.PrivateKeyId privateKeyId := opts.PrivateKeyId
if privateKeyId == "" { if privateKeyId == "" {
if opts.CertificateId == "" { if opts.CertificateId == "" {
if Debug {
log.Println("attempting to use CertStoreSigner")
}
return GetCertStoreSigner(opts.CertIdentifier, opts.UseLatestExpiringCertificate) return GetCertStoreSigner(opts.CertIdentifier, opts.UseLatestExpiringCertificate)
} }
privateKeyId = opts.CertificateId privateKeyId = opts.CertificateId
@@ -328,7 +348,9 @@ func GetSigner(opts *CredentialsOpts) (signer Signer, signatureAlgorithm string,
if err == nil { if err == nil {
certificate = cert certificate = cert
} else if opts.PrivateKeyId == "" { } else if opts.PrivateKeyId == "" {
if Debug {
log.Println("not a PEM certificate, so trying PKCS#12")
}
if opts.CertificateBundleId != "" { if opts.CertificateBundleId != "" {
return nil, "", errors.New("can't specify certificate chain when" + return nil, "", errors.New("can't specify certificate chain when" +
" using PKCS#12 files; certificate bundle should be provided" + " using PKCS#12 files; certificate bundle should be provided" +
@@ -353,11 +375,17 @@ func GetSigner(opts *CredentialsOpts) (signer Signer, signatureAlgorithm string,
} }
if strings.HasPrefix(privateKeyId, "pkcs11:") { if strings.HasPrefix(privateKeyId, "pkcs11:") {
if Debug {
log.Println("attempting to use PKCS11Signer")
}
if certificate != nil { if certificate != nil {
opts.CertificateId = "" opts.CertificateId = ""
} }
return GetPKCS11Signer(opts.LibPkcs11, certificate, certificateChain, opts.PrivateKeyId, opts.CertificateId, opts.ReusePin) return GetPKCS11Signer(opts.LibPkcs11, certificate, certificateChain, opts.PrivateKeyId, opts.CertificateId, opts.ReusePin)
} else if strings.HasPrefix(privateKeyId, "handle:") { } else if strings.HasPrefix(privateKeyId, "handle:") {
if Debug {
log.Println("attempting to use TPMv2Signer")
}
return GetTPMv2Signer( return GetTPMv2Signer(
GetTPMv2SignerOpts{ GetTPMv2SignerOpts{
certificate, certificate,
@@ -371,6 +399,9 @@ func GetSigner(opts *CredentialsOpts) (signer Signer, signatureAlgorithm string,
} else { } else {
tpmKey, err := parseDERFromPEM(privateKeyId, "TSS2 PRIVATE KEY") tpmKey, err := parseDERFromPEM(privateKeyId, "TSS2 PRIVATE KEY")
if err == nil { if err == nil {
if Debug {
log.Println("attempting to use TPMv2Signer")
}
return GetTPMv2Signer( return GetTPMv2Signer(
GetTPMv2SignerOpts{ GetTPMv2SignerOpts{
certificate, certificate,
@@ -393,18 +424,24 @@ func GetSigner(opts *CredentialsOpts) (signer Signer, signatureAlgorithm string,
if certificate == nil { if certificate == nil {
return nil, "", errors.New("undefined certificate value") return nil, "", errors.New("undefined certificate value")
} }
if Debug {
log.Println("attempting to use FileSystemSigner")
}
return GetFileSystemSigner(privateKeyId, opts.CertificateId, opts.CertificateBundleId, false, opts.Pkcs8Password) return GetFileSystemSigner(privateKeyId, opts.CertificateId, opts.CertificateBundleId, false, opts.Pkcs8Password)
} }
} }
// Obtain the date-time, formatted as specified by SigV4
func (signerParams *SignerParams) GetFormattedSigningDateTime() string { func (signerParams *SignerParams) GetFormattedSigningDateTime() string {
return signerParams.OverriddenDate.UTC().Format(timeFormat) return signerParams.OverriddenDate.UTC().Format(timeFormat)
} }
// Obtain the short date-time, formatted as specified by SigV4
func (signerParams *SignerParams) GetFormattedShortSigningDateTime() string { func (signerParams *SignerParams) GetFormattedShortSigningDateTime() string {
return signerParams.OverriddenDate.UTC().Format(shortTimeFormat) return signerParams.OverriddenDate.UTC().Format(shortTimeFormat)
} }
// Obtain the scope as part of the SigV4-X509 signature
func (signerParams *SignerParams) GetScope() string { func (signerParams *SignerParams) GetScope() string {
var scopeStringBuilder strings.Builder var scopeStringBuilder strings.Builder
scopeStringBuilder.WriteString(signerParams.GetFormattedShortSigningDateTime()) scopeStringBuilder.WriteString(signerParams.GetFormattedShortSigningDateTime())
@@ -449,14 +486,14 @@ func CreateRequestSignFinalizeFunction(signer crypto.Signer, signingRegion strin
} }
func signRequest(signer crypto.Signer, signingRegion string, signingAlgorithm string, certificate *x509.Certificate, certificateChain []*x509.Certificate, req *http.Request, payloadHash string) { func signRequest(signer crypto.Signer, signingRegion string, signingAlgorithm string, certificate *x509.Certificate, certificateChain []*x509.Certificate, req *http.Request, payloadHash string) {
signerParams := SignerParams{time.Now(), signingRegion, RolesanywhereSigningName, signingAlgorithm} signerParams := SignerParams{time.Now(), signingRegion, ROLESANYWHERE_SIGNING_NAME, signingAlgorithm}
// Set headers that are necessary for signing // Set headers that are necessary for signing
req.Header.Set(host, req.URL.Host) req.Header.Set(host, req.URL.Host)
req.Header.Set(xAmzDate, signerParams.GetFormattedSigningDateTime()) req.Header.Set(x_amz_date, signerParams.GetFormattedSigningDateTime())
req.Header.Set(xAmzX509, certificateToString(certificate)) req.Header.Set(x_amz_x509, certificateToString(certificate))
if certificateChain != nil { if certificateChain != nil {
req.Header.Set(xAmzX509Chain, certificateChainToString(certificateChain)) req.Header.Set(x_amz_x509_chain, certificateChainToString(certificateChain))
} }
canonicalRequest, signedHeadersString := createCanonicalRequest(req, payloadHash) canonicalRequest, signedHeadersString := createCanonicalRequest(req, payloadHash)
@@ -572,6 +609,7 @@ func createCanonicalRequest(r *http.Request, contentSha256 string) (string, stri
return hex.EncodeToString(canonicalRequestStringHashBytes[:]), signedHeadersString return hex.EncodeToString(canonicalRequestStringHashBytes[:]), signedHeadersString
} }
// Create the string to sign.
func CreateStringToSign(canonicalRequest string, signerParams SignerParams) string { func CreateStringToSign(canonicalRequest string, signerParams SignerParams) string {
var stringToSignStrBuilder strings.Builder var stringToSignStrBuilder strings.Builder
stringToSignStrBuilder.WriteString(signerParams.SigningAlgorithm) stringToSignStrBuilder.WriteString(signerParams.SigningAlgorithm)
@@ -585,7 +623,8 @@ func CreateStringToSign(canonicalRequest string, signerParams SignerParams) stri
return stringToSign return stringToSign
} }
func BuildAuthorizationHeader(_ *http.Request, signedHeadersString string, signature string, certificate *x509.Certificate, signerParams SignerParams) string { // Builds the complete authorization header
func BuildAuthorizationHeader(request *http.Request, signedHeadersString string, signature string, certificate *x509.Certificate, signerParams SignerParams) string {
signingCredentials := certificate.SerialNumber.String() + "/" + signerParams.GetScope() signingCredentials := certificate.SerialNumber.String() + "/" + signerParams.GetScope()
credential := "Credential=" + signingCredentials credential := "Credential=" + signingCredentials
signerHeaders := "SignedHeaders=" + signedHeadersString signerHeaders := "SignedHeaders=" + signedHeadersString
@@ -606,17 +645,20 @@ func BuildAuthorizationHeader(_ *http.Request, signedHeadersString string, signa
func encodeDer(der []byte) (string, error) { func encodeDer(der []byte) (string, error) {
var buf bytes.Buffer var buf bytes.Buffer
encoder := base64.NewEncoder(base64.StdEncoding, &buf) encoder := base64.NewEncoder(base64.StdEncoding, &buf)
_, _ = encoder.Write(der) encoder.Write(der)
_ = encoder.Close() encoder.Close()
return buf.String(), nil return buf.String(), nil
} }
func parseDERFromPEM(pemDataId string, blockType string) (*pem.Block, error) { func parseDERFromPEM(pemDataId string, blockType string) (*pem.Block, error) {
b := []byte(pemDataId) bytes, err := os.ReadFile(pemDataId)
if err != nil {
return nil, err
}
var block *pem.Block var block *pem.Block
for len(b) > 0 { for len(bytes) > 0 {
block, b = pem.Decode(b) block, bytes = pem.Decode(bytes)
if block == nil { if block == nil {
return nil, errors.New("unable to parse PEM data") return nil, errors.New("unable to parse PEM data")
} }
@@ -627,18 +669,26 @@ func parseDERFromPEM(pemDataId string, blockType string) (*pem.Block, error) {
return nil, errors.New("requested block type could not be found") return nil, errors.New("requested block type could not be found")
} }
// Reads certificate bundle data from a file, whose path is provided
func ReadCertificateBundleData(certificateBundleId string) ([]*x509.Certificate, error) { func ReadCertificateBundleData(certificateBundleId string) ([]*x509.Certificate, error) {
bytes, err := os.ReadFile(certificateBundleId)
if err != nil {
return nil, err
}
var derBytes []byte var derBytes []byte
var block *pem.Block var block *pem.Block
block, _ = pem.Decode([]byte(certificateBundleId)) for len(bytes) > 0 {
block, bytes = pem.Decode(bytes)
if block == nil {
break
}
if block.Type != "CERTIFICATE" { if block.Type != "CERTIFICATE" {
return nil, errors.New("invalid certificate chain") return nil, errors.New("invalid certificate chain")
} }
blockBytes := block.Bytes blockBytes := block.Bytes
derBytes = append(derBytes, blockBytes...) derBytes = append(derBytes, blockBytes...)
}
return x509.ParseCertificates(derBytes) return x509.ParseCertificates(derBytes)
} }
@@ -671,21 +721,30 @@ func readRSAPrivateKey(privateKeyId string) (*rsa.PrivateKey, error) {
return privateKey, nil return privateKey, nil
} }
// Reads and parses a PKCS#12 file (which should contain an end-entity
// certificate (optional), certificate chain (optional), and the key
// associated with the end-entity certificate). The end-entity certificate
// will be the first certificate in the returned chain. This method assumes
// that there is exactly one certificate that doesn't issue any others within
// the container and treats that as the end-entity certificate. Also, the
// order of the other certificates in the chain aren't guaranteed. It's
// also not guaranteed that those certificates form a chain with the
// end-entity certificate either.
func ReadPKCS12Data(certificateId string) (certChain []*x509.Certificate, privateKey crypto.PrivateKey, err error) { func ReadPKCS12Data(certificateId string) (certChain []*x509.Certificate, privateKey crypto.PrivateKey, err error) {
var ( var (
bts []byte bytes []byte
pemBlocks []*pem.Block pemBlocks []*pem.Block
parsedCerts []*x509.Certificate parsedCerts []*x509.Certificate
certMap map[string]*x509.Certificate certMap map[string]*x509.Certificate
endEntityFoundIndex int endEntityFoundIndex int
) )
bts, err = os.ReadFile(certificateId) bytes, err = os.ReadFile(certificateId)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
pemBlocks, err = pkcs12.ToPEM(bts, "") pemBlocks, err = pkcs12.ToPEM(bytes, "")
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
@@ -701,6 +760,11 @@ func ReadPKCS12Data(certificateId string) (certChain []*x509.Certificate, privat
privateKey = privateKeyTmp privateKey = privateKeyTmp
continue continue
} }
// If neither a certificate nor a private key could be parsed from the
// Block, ignore it and continue.
if Debug {
log.Println("unable to parse PEM block in PKCS#12 file - skipping")
}
} }
certMap = make(map[string]*x509.Certificate) certMap = make(map[string]*x509.Certificate)
@@ -720,6 +784,10 @@ func ReadPKCS12Data(certificateId string) (certChain []*x509.Certificate, privat
break break
} }
} }
if Debug {
log.Println("no end-entity certificate found in PKCS#12 file")
}
for i, cert := range parsedCerts { for i, cert := range parsedCerts {
if i != endEntityFoundIndex { if i != endEntityFoundIndex {
certChain = append(certChain, cert) certChain = append(certChain, cert)
@@ -729,6 +797,7 @@ func ReadPKCS12Data(certificateId string) (certChain []*x509.Certificate, privat
return certChain, privateKey, nil return certChain, privateKey, nil
} }
// Load the private key referenced by `privateKeyId`. If `pkcs8Password` is provided, attempt to load an encrypted PKCS#8 key.
func ReadPrivateKeyData(privateKeyId string, pkcs8Password ...string) (crypto.PrivateKey, error) { func ReadPrivateKeyData(privateKeyId string, pkcs8Password ...string) (crypto.PrivateKey, error) {
if len(pkcs8Password) > 0 && pkcs8Password[0] != "" { if len(pkcs8Password) > 0 && pkcs8Password[0] != "" {
if key, err := readPKCS8EncryptedPrivateKey(privateKeyId, []byte(pkcs8Password[0])); err == nil { if key, err := readPKCS8EncryptedPrivateKey(privateKeyId, []byte(pkcs8Password[0])); err == nil {
@@ -753,6 +822,7 @@ func ReadPrivateKeyData(privateKeyId string, pkcs8Password ...string) (crypto.Pr
return nil, errors.New("unable to parse private key") return nil, errors.New("unable to parse private key")
} }
// Reads private key data from a *pem.Block.
func ReadPrivateKeyDataFromPEMBlock(block *pem.Block) (key crypto.PrivateKey, err error) { func ReadPrivateKeyDataFromPEMBlock(block *pem.Block) (key crypto.PrivateKey, err error) {
key, err = x509.ParseECPrivateKey(block.Bytes) key, err = x509.ParseECPrivateKey(block.Bytes)
if err == nil { if err == nil {

View File

@@ -423,9 +423,9 @@ func GetTPMv2Signer(opts GetTPMv2SignerOpts) (signer Signer, signingAlgorithm st
switch public.Type { switch public.Type {
case tpm2.AlgRSA: case tpm2.AlgRSA:
signingAlgorithm = aws4X509RsaSha256 signingAlgorithm = aws4_x509_rsa_sha256
case tpm2.AlgECC: case tpm2.AlgECC:
signingAlgorithm = aws4X509EcdsaSha256 signingAlgorithm = aws4_x509_ecdsa_sha256
default: default:
return nil, "", errors.New("unsupported TPMv2 key type") return nil, "", errors.New("unsupported TPMv2 key type")
} }

View File

@@ -1,11 +1,6 @@
package config package config
import ( import "git.siteworxpro.com/packages/go/utilities/Env"
"encoding/base64"
"fmt"
"git.siteworxpro.com/packages/go/utilities/Env"
"regexp"
)
const ( const (
namespace Env.EnvironmentVariable = "NAMESPACE" namespace Env.EnvironmentVariable = "NAMESPACE"
@@ -15,10 +10,8 @@ const (
trustedAnchorArn Env.EnvironmentVariable = "TRUSTED_ANCHOR_ARN" trustedAnchorArn Env.EnvironmentVariable = "TRUSTED_ANCHOR_ARN"
privateKey Env.EnvironmentVariable = "PRIVATE_KEY" privateKey Env.EnvironmentVariable = "PRIVATE_KEY"
certificate Env.EnvironmentVariable = "CERTIFICATE" certificate Env.EnvironmentVariable = "CERTIFICATE"
bundleId Env.EnvironmentVariable = "CA_CHAIN"
sessionDuration Env.EnvironmentVariable = "SESSION_DURATION" sessionDuration Env.EnvironmentVariable = "SESSION_DURATION"
restartDeployments Env.EnvironmentVariable = "RESTART_DEPLOYMENTS" restartDeployments Env.EnvironmentVariable = "RESTART_DEPLOYMENTS"
fetchOnly Env.EnvironmentVariable = "FETCH_ONLY"
) )
type Config struct{} type Config struct{}
@@ -27,59 +20,6 @@ func NewConfig() *Config {
return &Config{} return &Config{}
} }
func (c Config) Valid() error {
// Certificate Required
if c.Certificate() == "" {
return fmt.Errorf("certificate is required")
}
// Private Key Required
if c.PrivateKey() == "" {
return fmt.Errorf("private Key is required")
}
// Role ARN Required
if c.RoleArn() == "" {
return fmt.Errorf("role ARN is required")
}
if !regexp.MustCompile(`^arn:aws:iam::[0-9]{10,13}:role/[\w\D]*$`).MatchString(c.RoleArn()) {
return fmt.Errorf("role ARN %s is invalid", c.RoleArn())
}
if c.ProfileArn() == "" {
return fmt.Errorf("profile ARN is required")
}
if !regexp.MustCompile(`^arn:aws:rolesanywhere:[\w-]*:\d{10,12}:profile/[\w\D]*$`).MatchString(c.ProfileArn()) {
return fmt.Errorf("profile ARN %s is invalid", c.ProfileArn())
}
// Trusted Anchor ARN Required
if c.TrustedAnchor() == "" {
return fmt.Errorf("trusted anchor ARN is required")
}
if !regexp.MustCompile(`^arn:aws:rolesanywhere:[\w-]*:\d{10,12}:trust-anchor/[\w\D]*$`).MatchString(c.TrustedAnchor()) {
return fmt.Errorf("trusted anchor %s ARN is invalid", c.TrustedAnchor())
}
return nil
}
func (Config) BundleId() string {
v, err := base64.StdEncoding.DecodeString(bundleId.GetEnvString(""))
if err != nil {
return ""
}
return string(v)
}
func (Config) FetchOnly() bool {
return fetchOnly.GetEnvBool(false)
}
func (Config) Namespace() string { func (Config) Namespace() string {
return namespace.GetEnvString("") return namespace.GetEnvString("")
} }
@@ -101,21 +41,11 @@ func (Config) TrustedAnchor() string {
} }
func (Config) PrivateKey() string { func (Config) PrivateKey() string {
v, err := base64.StdEncoding.DecodeString(privateKey.GetEnvString("")) return privateKey.GetEnvString("")
if err != nil {
return ""
}
return string(v)
} }
func (Config) Certificate() string { func (Config) Certificate() string {
v, err := base64.StdEncoding.DecodeString(certificate.GetEnvString("")) return certificate.GetEnvString("")
if err != nil {
return ""
}
return string(v)
} }
func (Config) SessionDuration() int64 { func (Config) SessionDuration() int64 {

22
go.mod
View File

@@ -6,7 +6,7 @@ require (
git.siteworxpro.com/packages/go/utilities v1.3.0 git.siteworxpro.com/packages/go/utilities v1.3.0
github.com/aws/aws-sdk-go v1.55.7 github.com/aws/aws-sdk-go v1.55.7
github.com/aws/aws-sdk-go-v2 v1.36.3 github.com/aws/aws-sdk-go-v2 v1.36.3
github.com/aws/aws-sdk-go-v2/config v1.29.14 github.com/aws/aws-sdk-go-v2/config v1.29.6
github.com/aws/rolesanywhere-credential-helper v1.6.0 github.com/aws/rolesanywhere-credential-helper v1.6.0
github.com/aws/smithy-go v1.22.3 github.com/aws/smithy-go v1.22.3
github.com/charmbracelet/log v0.4.2 github.com/charmbracelet/log v0.4.2
@@ -22,16 +22,16 @@ require (
) )
require ( require (
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.59 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.28 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.32 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.32 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.13 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.24.15 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.14 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.33.14 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/charmbracelet/colorprofile v0.3.1 // indirect github.com/charmbracelet/colorprofile v0.3.1 // indirect
github.com/charmbracelet/lipgloss v1.1.0 // indirect github.com/charmbracelet/lipgloss v1.1.0 // indirect

44
go.sum
View File

@@ -10,28 +10,28 @@ github.com/aws/aws-sdk-go v1.55.7 h1:UJrkFq7es5CShfBwlWAC8DA077vp8PyVbQd3lqLiztE
github.com/aws/aws-sdk-go v1.55.7/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/aws/aws-sdk-go v1.55.7/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM=
github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg=
github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= github.com/aws/aws-sdk-go-v2/config v1.29.6 h1:fqgqEKK5HaZVWLQoLiC9Q+xDlSp+1LYidp6ybGE2OGg=
github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= github.com/aws/aws-sdk-go-v2/config v1.29.6/go.mod h1:Ft+WLODzDQmCTHDvqAH1JfC2xxbZ0MxpZAcJqmE1LTQ=
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= github.com/aws/aws-sdk-go-v2/credentials v1.17.59 h1:9btwmrt//Q6JcSdgJOLI98sdr5p7tssS9yAsGe8aKP4=
github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= github.com/aws/aws-sdk-go-v2/credentials v1.17.59/go.mod h1:NM8fM6ovI3zak23UISdWidyZuI1ghNe2xjzUZAyT+08=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.28 h1:KwsodFKVQTlI5EyhRSugALzsV6mG/SGrdjlMXSZSdso=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.28/go.mod h1:EY3APf9MzygVhKuPXAc5H+MkGb8k/DOSQjWS0LgkKqI=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.32 h1:BjUcr3X3K0wZPGFg2bxOWW3VPN8rkE3/61zhP+IHviA=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.32/go.mod h1:80+OGC/bgzzFFTUmcuwD0lb4YutwQeKLFpmt6hoWapU=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.32 h1:m1GeXHVMJsRsUAqG6HjZWx9dj7F5TR+cF1bjyfYyBd4=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.32/go.mod h1:IitoQxGfaKdVLNg0hD8/DXmAqNy0H4K2H2Sf91ti8sI=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 h1:Pg9URiobXy85kgFev3og2CuOZ8JZUBENF+dcgWBaYNk=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 h1:D4oz8/CzT9bAEYtVhSBmFj2dNOtaHOtMKc2vHBwYizA=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2/go.mod h1:Za3IHqTQ+yNcRHxu1OFucBh0ACZT4j4VQFF0BqpZcLY=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.13 h1:SYVGSFQHlchIcy6e7x12bsrxClCXSP5et8cqVhL8cuw=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.13/go.mod h1:kizuDaLX37bG5WZaoxGPQR/LNFXpxp0vsUnqfkWXfNE=
github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= github.com/aws/aws-sdk-go-v2/service/sso v1.24.15 h1:/eE3DogBjYlvlbhd2ssWyeuovWunHLxfgw3s/OJa4GQ=
github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= github.com/aws/aws-sdk-go-v2/service/sso v1.24.15/go.mod h1:2PCJYpi7EKeA5SkStAmZlF6fi0uUABuhtF8ILHjGc3Y=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.14 h1:M/zwXiL2iXUrHputuXgmO94TVNmcenPHxgLXLutodKE=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.14/go.mod h1:RVwIw3y/IqxC2YEXSIkAzRDdEU1iRabDPaYjpGCbCGQ=
github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= github.com/aws/aws-sdk-go-v2/service/sts v1.33.14 h1:TzeR06UCMUq+KA3bDkujxK1GVGy+G8qQN/QVYzGLkQE=
github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= github.com/aws/aws-sdk-go-v2/service/sts v1.33.14/go.mod h1:dspXf/oYWGWo6DEvj98wpaTeqt5+DMidZD0A9BYTizc=
github.com/aws/rolesanywhere-credential-helper v1.6.0 h1:NX9Qc1jQ85XzF5Ksm5DKLdKXEUj5szdIDbGsglYCBaQ= github.com/aws/rolesanywhere-credential-helper v1.6.0 h1:NX9Qc1jQ85XzF5Ksm5DKLdKXEUj5szdIDbGsglYCBaQ=
github.com/aws/rolesanywhere-credential-helper v1.6.0/go.mod h1:h2qTbudK5O3KD5FtlIPgkmCB16oeebp9g/43pn5TEGU= github.com/aws/rolesanywhere-credential-helper v1.6.0/go.mod h1:h2qTbudK5O3KD5FtlIPgkmCB16oeebp9g/43pn5TEGU=
github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k=

43
main.go
View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"encoding/base64"
helper "gitea.siteworxpro.com/Siteworxpro/aws-iam-anywhere-refresher/aws_signing_helper" helper "gitea.siteworxpro.com/Siteworxpro/aws-iam-anywhere-refresher/aws_signing_helper"
"gitea.siteworxpro.com/Siteworxpro/aws-iam-anywhere-refresher/cmd" "gitea.siteworxpro.com/Siteworxpro/aws-iam-anywhere-refresher/cmd"
appConfig "gitea.siteworxpro.com/Siteworxpro/aws-iam-anywhere-refresher/config" appConfig "gitea.siteworxpro.com/Siteworxpro/aws-iam-anywhere-refresher/config"
@@ -17,21 +18,35 @@ func main() {
ReportTimestamp: true, ReportTimestamp: true,
TimeFormat: time.RFC3339, TimeFormat: time.RFC3339,
}) })
l.Info("Starting credentials refresh") l.Info("Starting credentials refresh")
client, err := kube_client.NewKubeClient()
if err != nil {
l.Error("Failed to create kubernetes client", "error", err)
os.Exit(1)
}
c := appConfig.NewConfig() c := appConfig.NewConfig()
err := c.Valid() privateKey, err := base64.StdEncoding.DecodeString(c.PrivateKey())
if err != nil { if err != nil {
l.Error("Invalid configuration", "error", err) l.Error("Failed to decode private key", "error", err)
os.Exit(1)
}
certificate, err := base64.StdEncoding.DecodeString(c.Certificate())
if err != nil {
l.Error("Failed to decode certificate", "error", err)
os.Exit(1) os.Exit(1)
} }
credentials, err := cmd.Run(&helper.CredentialsOpts{ credentials, err := cmd.Run(&helper.CredentialsOpts{
PrivateKeyId: c.PrivateKey(), PrivateKeyId: string(privateKey),
CertificateId: c.Certificate(), CertificateId: string(certificate),
CertificateBundleId: c.BundleId(), CertIdentifier: helper.CertIdentifier{
SystemStoreName: "MY",
},
RoleArn: c.RoleArn(), RoleArn: c.RoleArn(),
ProfileArnStr: c.ProfileArn(), ProfileArnStr: c.ProfileArn(),
TrustAnchorArnStr: c.TrustedAnchor(), TrustAnchorArnStr: c.TrustedAnchor(),
@@ -46,22 +61,6 @@ func main() {
l.Info("Credentials refreshed") l.Info("Credentials refreshed")
if c.FetchOnly() {
l.Info("Fetch only mode, skipping secret update")
l.Info("AccessKeyId", "access-key-id", credentials.AccessKeyId)
l.Info("SecretAccessKey", "secret-access-key", credentials.SecretAccessKey)
l.Info("SessionToken", "session-token", credentials.SessionToken)
os.Exit(0)
}
client, err := kube_client.NewKubeClient()
if err != nil {
l.Error("Failed to create kubernetes client", "error", err)
os.Exit(1)
}
_, err = client.GetSecret(c.Namespace(), c.Secret()) _, err = client.GetSecret(c.Namespace(), c.Secret())
if err != nil { if err != nil {
l.Error("Failed to get secret", "error", err) l.Error("Failed to get secret", "error", err)