Skip to content

Commit

Permalink
fix get aws creds from environment (#3617)
Browse files Browse the repository at this point in the history
Signed-off-by: Fabian Martinez <46371672+famarting@users.noreply.github.com>
  • Loading branch information
famarting authored Nov 28, 2024
1 parent f48b412 commit 1e095ed
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 25 deletions.
7 changes: 0 additions & 7 deletions common/authentication/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,6 @@ type Provider interface {
Close() error
}

func isX509Auth(m map[string]string) bool {
tp, _ := m["trustProfileArn"]
ta, _ := m["trustAnchorArn"]
ar, _ := m["assumeRoleArn"]
return tp != "" && ta != "" && ar != ""
}

func NewProvider(ctx context.Context, opts Options, cfg *aws.Config) (Provider, error) {
if isX509Auth(opts.Properties) {
return newX509(ctx, opts, cfg)
Expand Down
48 changes: 32 additions & 16 deletions common/authentication/aws/static.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ type StaticAuth struct {
endpoint *string
accessKey *string
secretKey *string
sessionToken *string
sessionToken string

assumeRoleARN *string
sessionName *string
sessionName string

session *session.Session
cfg *aws.Config
Expand All @@ -50,15 +50,7 @@ type StaticAuth struct {

func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) {
auth := &StaticAuth{
logger: opts.Logger,
region: &opts.Region,
endpoint: &opts.Endpoint,
accessKey: &opts.AccessKey,
secretKey: &opts.SecretKey,
sessionToken: &opts.SessionToken,
assumeRoleARN: &opts.AssumeRoleARN,
sessionName: &opts.SessionName,

logger: opts.Logger,
cfg: func() *aws.Config {
// if nil is passed or it's just a default cfg,
// then we use the options to build the aws cfg.
Expand All @@ -70,7 +62,29 @@ func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth
clients: newClients(),
}

initialSession, err := auth.getTokenClient()
if opts.Region != "" {
auth.region = &opts.Region
}
if opts.Endpoint != "" {
auth.endpoint = &opts.Endpoint
}
if opts.AccessKey != "" {
auth.accessKey = &opts.AccessKey
}
if opts.SecretKey != "" {
auth.secretKey = &opts.SecretKey
}
if opts.SessionToken != "" {
auth.sessionToken = opts.SessionToken
}
if opts.AssumeRoleARN != "" {
auth.assumeRoleARN = &opts.AssumeRoleARN
}
if opts.SessionName != "" {
auth.sessionName = opts.SessionName
}

initialSession, err := auth.createSession()
if err != nil {
return nil, fmt.Errorf("failed to get token client: %v", err)
}
Expand Down Expand Up @@ -231,8 +245,8 @@ func (a *StaticAuth) Kafka(opts KafkaOptions) (*KafkaClients, error) {
if a.assumeRoleARN != nil {
tokenProvider.awsIamRoleArn = *a.assumeRoleARN
}
if a.sessionName != nil {
tokenProvider.awsStsSessionName = *a.sessionName
if a.sessionName != "" {
tokenProvider.awsStsSessionName = a.sessionName
}

err := a.clients.kafka.New(a.session, &tokenProvider)
Expand All @@ -243,7 +257,7 @@ func (a *StaticAuth) Kafka(opts KafkaOptions) (*KafkaClients, error) {
return a.clients.kafka, nil
}

func (a *StaticAuth) getTokenClient() (*session.Session, error) {
func (a *StaticAuth) createSession() (*session.Session, error) {
var awsConfig *aws.Config
if a.cfg == nil {
awsConfig = aws.NewConfig()
Expand All @@ -257,13 +271,15 @@ func (a *StaticAuth) getTokenClient() (*session.Session, error) {

if a.accessKey != nil && a.secretKey != nil {
// session token is an option field
awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.accessKey, *a.secretKey, *a.sessionToken))
awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.accessKey, *a.secretKey, a.sessionToken))
}

if a.endpoint != nil {
awsConfig = awsConfig.WithEndpoint(*a.endpoint)
}

// TODO support assume role for all aws components

awsSession, err := session.NewSessionWithOptions(session.Options{
Config: *awsConfig,
SharedConfigState: session.SharedConfigEnable,
Expand Down
10 changes: 8 additions & 2 deletions common/authentication/aws/static_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,22 @@ func TestGetTokenClient(t *testing.T) {
awsInstance: &StaticAuth{
accessKey: aws.String("testAccessKey"),
secretKey: aws.String("testSecretKey"),
sessionToken: aws.String("testSessionToken"),
sessionToken: "testSessionToken",
region: aws.String("us-west-2"),
endpoint: aws.String("https://test.endpoint.com"),
},
},
{
name: "creds from environment",
awsInstance: &StaticAuth{
region: aws.String("us-west-2"),
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
session, err := tt.awsInstance.getTokenClient()
session, err := tt.awsInstance.createSession()
require.NotNil(t, session)
require.NoError(t, err)
assert.Equal(t, tt.awsInstance.region, session.Config.Region)
Expand Down
7 changes: 7 additions & 0 deletions common/authentication/aws/x509.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ import (
"github.com/dapr/kit/ptr"
)

func isX509Auth(m map[string]string) bool {
tp := m["trustProfileArn"]
ta := m["trustAnchorArn"]
ar := m["assumeRoleArn"]
return tp != "" && ta != "" && ar != ""
}

type x509Options struct {
TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"`
TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"`
Expand Down

0 comments on commit 1e095ed

Please sign in to comment.