Skip to content

Commit

Permalink
Merge pull request kubernetes-sigs#53 from mattlandis/add-sts-error-m…
Browse files Browse the repository at this point in the history
…etric

Add errors types to Verify to differentiate between token and STS errors
  • Loading branch information
nckturner authored Feb 28, 2018
2 parents 4497817 + 0375c9b commit b53aa08
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 23 deletions.
7 changes: 6 additions & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ const (
metricNS = "heptio_authenticator_aws"
metricMalformed = "malformed_request"
metricInvalid = "invalid_token"
metricSTSError = "sts_error"
metricUnknown = "uknown_user"
metricSuccess = "success"
)
Expand Down Expand Up @@ -212,7 +213,11 @@ func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request)
// if the token is invalid, reject with a 403
identity, err := h.verifier.Verify(tokenReview.Spec.Token)
if err != nil {
h.metrics.latency.WithLabelValues(metricInvalid).Observe(duration(start))
if _, ok := err.(token.STSError); ok {
h.metrics.latency.WithLabelValues(metricSTSError).Observe(duration(start))
} else {
h.metrics.latency.WithLabelValues(metricInvalid).Observe(duration(start))
}
log.WithError(err).Warn("access denied")
w.WriteHeader(http.StatusForbidden)
w.Write(tokenReviewDenyJSON)
Expand Down
29 changes: 27 additions & 2 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func cleanup(m metrics) {
// Count of expected metrics
type validateOpts struct {
// The expected number of latency entries for each label.
malformed, invalidToken, unknownUser, success uint64
malformed, invalidToken, unknownUser, success, stsError uint64
}

func checkHistogramSampleCount(t *testing.T, name string, actual, expected uint64) {
Expand All @@ -89,7 +89,7 @@ func validateMetrics(t *testing.T, opts validateOpts) {
}
for _, m := range metrics {
if strings.HasPrefix(m.GetName(), "heptio_authenticator_aws_authenticate_latency_seconds") {
var actualSuccess, actualMalformed, actualInvalid, actualUnknown uint64
var actualSuccess, actualMalformed, actualInvalid, actualUnknown, actualSTSError uint64
for _, metric := range m.GetMetric() {
if len(metric.Label) != 1 {
t.Fatalf("Expected 1 label for metric. Got %+v", metric.Label)
Expand All @@ -107,6 +107,8 @@ func validateMetrics(t *testing.T, opts validateOpts) {
actualInvalid = metric.GetHistogram().GetSampleCount()
case metricUnknown:
actualUnknown = metric.GetHistogram().GetSampleCount()
case metricSTSError:
actualSTSError = metric.GetHistogram().GetSampleCount()
default:
t.Errorf("Unknown result for latency label: %s", *label.Value)

Expand All @@ -116,6 +118,7 @@ func validateMetrics(t *testing.T, opts validateOpts) {
checkHistogramSampleCount(t, metricMalformed, actualMalformed, opts.malformed)
checkHistogramSampleCount(t, metricInvalid, actualInvalid, opts.invalidToken)
checkHistogramSampleCount(t, metricUnknown, actualUnknown, opts.unknownUser)
checkHistogramSampleCount(t, metricSTSError, actualSTSError, opts.stsError)
}
}
}
Expand Down Expand Up @@ -192,6 +195,28 @@ func TestAuthenticateVerifierError(t *testing.T) {
validateMetrics(t, validateOpts{invalidToken: 1})
}

func TestAuthenticateVerifierSTSError(t *testing.T) {
resp := httptest.NewRecorder()

data, err := json.Marshal(authenticationv1beta1.TokenReview{
Spec: authenticationv1beta1.TokenReviewSpec{
Token: "token",
},
})
if err != nil {
t.Fatalf("Could not marshal in put data: %v", err)
}
req := httptest.NewRequest("POST", "http://k8s.io/authenticate", bytes.NewReader(data))
h := setup(&testVerifier{err: token.NewSTSError("There was an error")})
defer cleanup(h.metrics)
h.authenticateEndpoint(resp, req)
if resp.Code != http.StatusForbidden {
t.Errorf("Expected status code %d, was %d", http.StatusForbidden, resp.Code)
}
verifyBodyContains(t, resp, string(tokenReviewDenyJSON))
validateMetrics(t, validateOpts{stsError: 1})
}

func TestAuthenticateVerifierNotMapped(t *testing.T) {
resp := httptest.NewRecorder()

Expand Down
66 changes: 46 additions & 20 deletions pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,32 @@ const (
clusterIDHeader = "x-k8s-aws-id"
)

// FormatError is returned when there is a problem with token that is
// an encoded sts request. This can include the url, data, action or anything
// else that prevents the sts call from being made.
type FormatError struct {
message string
}

func (e FormatError) Error() string {
return "input token was not properly formatted: " + e.message
}

// STSError is returned when there was either an error calling STS or a problem
// processing the data returned from STS.
type STSError struct {
message string
}

func (e STSError) Error() string {
return "sts getCallerIdentity failed: " + e.message
}

// NewSTSError creates a error of type STS.
func NewSTSError(m string) STSError {
return STSError{message: m}
}

var parameterWhitelist = map[string]bool{
"action": true,
"version": true,
Expand Down Expand Up @@ -177,59 +203,59 @@ func NewVerifier(clusterID string) Verifier {
// token. On failure, returns nil and a non-nil error.
func (v tokenVerifier) Verify(token string) (*Identity, error) {
if len(token) > maxTokenLenBytes {
return nil, fmt.Errorf("token is too large")
return nil, FormatError{"token is too large"}
}

if !strings.HasPrefix(token, v1Prefix) {
return nil, fmt.Errorf("token is missing expected %q prefix", v1Prefix)
return nil, FormatError{fmt.Sprintf("token is missing expected %q prefix", v1Prefix)}
}

// TODO: this may need to be a constant-time base64 decoding
tokenBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimPrefix(token, v1Prefix))
if err != nil {
return nil, err
return nil, FormatError{err.Error()}
}

parsedURL, err := url.Parse(string(tokenBytes))
if err != nil {
return nil, err
return nil, FormatError{err.Error()}
}

if parsedURL.Scheme != "https" {
return nil, fmt.Errorf("unexpected scheme %q in pre-signed URL", parsedURL.Scheme)
return nil, FormatError{fmt.Sprintf("unexpected scheme %q in pre-signed URL", parsedURL.Scheme)}
}

if parsedURL.Host != "sts.amazonaws.com" {
return nil, fmt.Errorf("unexpected hostname in pre-signed URL")
return nil, FormatError{"unexpected hostname in pre-signed URL"}
}

if parsedURL.Path != "/" {
return nil, fmt.Errorf("unexpected path in pre-signed URL")
return nil, FormatError{"unexpected path in pre-signed URL"}
}

queryParamsLower := make(url.Values)
queryParams := parsedURL.Query()
for key, values := range queryParams {
if !parameterWhitelist[strings.ToLower(key)] {
return nil, fmt.Errorf("non-whitelisted query parameter %q", key)
return nil, FormatError{fmt.Sprintf("non-whitelisted query parameter %q", key)}
}
if len(values) != 1 {
return nil, fmt.Errorf("query parameter with multiple values not supported")
return nil, FormatError{"query parameter with multiple values not supported"}
}
queryParamsLower.Set(strings.ToLower(key), values[0])
}

if queryParamsLower.Get("action") != "GetCallerIdentity" {
return nil, fmt.Errorf("unexpected action parameter in pre-signed URL")
return nil, FormatError{"unexpected action parameter in pre-signed URL"}
}

if !hasSignedClusterIDHeader(&queryParamsLower) {
return nil, fmt.Errorf("client did not sign the %s header in the pre-signed URL", clusterIDHeader)
return nil, FormatError{fmt.Sprintf("client did not sign the %s header in the pre-signed URL", clusterIDHeader)}
}

expires, err := strconv.Atoi(queryParamsLower.Get("x-amz-expires"))
if err != nil || expires < 0 || expires > 60 {
return nil, fmt.Errorf("invalid X-Amz-Expires parameter in pre-signed URL")
return nil, FormatError{"invalid X-Amz-Expires parameter in pre-signed URL"}
}

req, err := http.NewRequest("GET", parsedURL.String(), nil)
Expand All @@ -240,25 +266,25 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {
if err != nil {
// special case to avoid printing the full URL if possible
if urlErr, ok := err.(*url.Error); ok {
return nil, fmt.Errorf("error during GET: %v", urlErr.Err)
return nil, NewSTSError(fmt.Sprintf("error during GET: %v", urlErr.Err))
}
return nil, fmt.Errorf("error during GET: %v", err)
return nil, NewSTSError(fmt.Sprintf("error during GET: %v", err))
}
defer response.Body.Close()

if response.StatusCode != 200 {
return nil, fmt.Errorf("error from AWS (expected 200, got %d)", response.StatusCode)
return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d)", response.StatusCode))
}

responseBody, err := ioutil.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("error reading HTTP result: %v", err)
return nil, NewSTSError(fmt.Sprintf("error reading HTTP result: %v", err))
}

var callerIdentity getCallerIdentityWrapper
err = json.Unmarshal(responseBody, &callerIdentity)
if err != nil {
return nil, err
return nil, NewSTSError(err.Error())
}

// parse the response into an Identity
Expand All @@ -268,7 +294,7 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {
}
id.CanonicalARN, err = canonicalizeARN(id.ARN)
if err != nil {
return nil, err
return nil, NewSTSError(err.Error())
}

// The user ID is either UserID:SessionName (for assumed roles) or just
Expand All @@ -280,9 +306,9 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {
} else if len(userIDParts) == 1 {
id.UserID = userIDParts[0]
} else {
return nil, fmt.Errorf(
return nil, STSError{fmt.Sprintf(
"malformed UserID %q",
callerIdentity.GetCallerIdentityResponse.GetCallerIdentityResult.UserID)
callerIdentity.GetCallerIdentityResponse.GetCallerIdentityResult.UserID)}
}

return id, nil
Expand Down
15 changes: 15 additions & 0 deletions pkg/token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,25 @@ import (
)

func validationErrorTest(t *testing.T, token string, expectedErr string) {
t.Helper()
_, err := tokenVerifier{}.Verify(token)
errorContains(t, err, expectedErr)
}

func errorContains(t *testing.T, err error, expectedErr string) {
t.Helper()
if err == nil || !strings.Contains(err.Error(), expectedErr) {
t.Errorf("err should have contained '%s' was '%s'", expectedErr, err)
}
}

func assertSTSError(t *testing.T, err error) {
t.Helper()
if _, ok := err.(STSError); !ok {
t.Errorf("Expected err %v to be an STSError but was not", err)
}
}

const validURL = "https://sts.amazonaws.com/?action=GetCallerIdentity&x-amz-signedheaders=x-k8s-aws-id&x-amz-expires=60"

var validToken = toToken(validURL)
Expand Down Expand Up @@ -98,11 +107,13 @@ func TestVerifyTokenPreSTSValidations(t *testing.T) {
func TestVerifyHTTPError(t *testing.T) {
_, err := newVerifier(0, "", errors.New("an error")).Verify(validToken)
errorContains(t, err, "error during GET: an error")
assertSTSError(t, err)
}

func TestVerifyHTTP403(t *testing.T) {
_, err := newVerifier(403, " ", nil).Verify(validToken)
errorContains(t, err, "error from AWS (expected 200, got")
assertSTSError(t, err)
}

func TestVerifyBodyReadError(t *testing.T) {
Expand All @@ -119,21 +130,25 @@ func TestVerifyBodyReadError(t *testing.T) {
}
_, err := verifier.Verify(validToken)
errorContains(t, err, "error reading HTTP result")
assertSTSError(t, err)
}

func TestVerifyUnmarshalJSONError(t *testing.T) {
_, err := newVerifier(200, "xxxx", nil).Verify(validToken)
errorContains(t, err, "invalid character")
assertSTSError(t, err)
}

func TestVerifyInvalidCanonicalARNError(t *testing.T) {
_, err := newVerifier(200, jsonResponse("arn", "1000", "userid"), nil).Verify(validToken)
errorContains(t, err, "malformed ARN")
assertSTSError(t, err)
}

func TestVerifyInvalidUserIDError(t *testing.T) {
_, err := newVerifier(200, jsonResponse("arn:aws:iam::123456789012:user/Alice", "123456789012", "not:vailid:userid"), nil).Verify(validToken)
errorContains(t, err, "malformed UserID")
assertSTSError(t, err)
}

func TestVerifyNoSession(t *testing.T) {
Expand Down

0 comments on commit b53aa08

Please sign in to comment.