From a2e1ee0f78e5d539cccc5cdbfff90ee5e51fbb4a Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Tue, 7 Nov 2023 10:00:11 +0100 Subject: [PATCH] SNOW-961482: Replace OCSP cache with more structured version --- driver_ocsp_test.go | 2 +- driver_test.go | 1 + ocsp.go | 115 +++++++++++++++++++++++++------------------- ocsp_test.go | 24 +++------ 4 files changed, 73 insertions(+), 69 deletions(-) diff --git a/driver_ocsp_test.go b/driver_ocsp_test.go index 4b331742f..f784cd8bd 100644 --- a/driver_ocsp_test.go +++ b/driver_ocsp_test.go @@ -35,7 +35,7 @@ func deleteOCSPCacheFile() { func deleteOCSPCacheAll() { ocspResponseCacheLock.Lock() defer ocspResponseCacheLock.Unlock() - ocspResponseCache = make(map[certIDKey][]interface{}) + ocspResponseCache = make(map[certIDKey]*certCacheValue) } func cleanup() { diff --git a/driver_test.go b/driver_test.go index 1d6662ebe..5f31b14d7 100644 --- a/driver_test.go +++ b/driver_test.go @@ -128,6 +128,7 @@ func setup() (string, error) { return "", fmt.Errorf("failed to create schema. %v", err) } createDSN("UTC") + GetLogger().SetLogLevel("debug") return orgSchemaname, nil } diff --git a/ocsp.go b/ocsp.go index 8fc7cef30..a9ae9fac7 100644 --- a/ocsp.go +++ b/ocsp.go @@ -141,8 +141,13 @@ type certIDKey struct { SerialNumber string } +type certCacheValue struct { + ts float64 + certBase64 string +} + var ( - ocspResponseCache map[certIDKey][]interface{} + ocspResponseCache map[certIDKey]*certCacheValue ocspResponseCacheLock *sync.RWMutex ) @@ -270,8 +275,14 @@ func checkOCSPResponseCache(encodedCertID *certIDKey, subject, issuer *x509.Cert return &ocspStatus{code: ocspNoServer} } ocspResponseCacheLock.RLock() - gotValueFromCache := ocspResponseCache[*encodedCertID] - ocspResponseCacheLock.RUnlock() + defer ocspResponseCacheLock.RUnlock() + gotValueFromCache, ok := ocspResponseCache[*encodedCertID] + if !ok { + return &ocspStatus{ + code: ocspMissedCache, + err: fmt.Errorf("miss cache data. subject: %v", subject), + } + } status := extractOCSPCacheResponseValue(gotValueFromCache, subject, issuer) if !isValidOCSPStatus(status.code) { @@ -354,7 +365,7 @@ func checkOCSPCacheServer( req requestFunc, ocspServerHost *url.URL, totalTimeout time.Duration) ( - cacheContent *map[string][]interface{}, + cacheContent *map[string]*certCacheValue, ocspS *ocspStatus) { var respd map[string][]interface{} headers := make(map[string]string) @@ -388,7 +399,21 @@ func checkOCSPCacheServer( } } } - return &respd, &ocspStatus{ + buf := make(map[string]*certCacheValue) + for key, value := range respd { + ts, ok := value[0].(float64) + if !ok { + logger.Warnf("cannot cast %v to float64", value[0]) + continue + } + certBase64, ok := value[1].(string) + if !ok { + logger.Warnf("cannot cast %v to string", value[1]) + continue + } + buf[key] = &certCacheValue{ts, certBase64} + } + return &buf, &ocspStatus{ code: ocspSuccess, } } @@ -608,7 +633,7 @@ func getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate) if !isValidOCSPStatus(ret.code) { return ret // return invalid } - v := []interface{}{float64(time.Now().UTC().Unix()), base64.StdEncoding.EncodeToString(ocspResBytes)} + v := &certCacheValue{float64(time.Now().UTC().Unix()), base64.StdEncoding.EncodeToString(ocspResBytes)} ocspResponseCacheLock.Lock() ocspResponseCache[*encodedCertID] = v cacheUpdated = true @@ -804,7 +829,7 @@ func initOCSPCache() { if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") { return } - ocspResponseCache = make(map[certIDKey][]interface{}) + ocspResponseCache = make(map[certIDKey]*certCacheValue) ocspResponseCacheLock = &sync.RWMutex{} logger.Infof("reading OCSP Response cache file. %v\n", cacheFileName) @@ -828,74 +853,64 @@ func initOCSPCache() { } } for k, cacheValue := range buf { - status := extractOCSPCacheResponseValueWithoutSubject(cacheValue) + ts, ok := cacheValue[0].(float64) + if !ok { + logger.Warnf("cannot cast %v as float64", cacheValue[0]) + continue + } + certBase64, ok := cacheValue[1].(string) + if !ok { + logger.Warnf("cannot cast %v as string", cacheValue[1]) + continue + } + certValue := &certCacheValue{ts, certBase64} + status := extractOCSPCacheResponseValueWithoutSubject(certValue) if !isValidOCSPStatus(status.code) { continue } cacheKey := encodeCertIDKey(k) - ocspResponseCache[*cacheKey] = cacheValue + ocspResponseCache[*cacheKey] = certValue } cacheUpdated = false } -func extractOCSPCacheResponseValueWithoutSubject(cacheValue []interface{}) *ocspStatus { +func extractOCSPCacheResponseValueWithoutSubject(cacheValue *certCacheValue) *ocspStatus { return extractOCSPCacheResponseValue(cacheValue, nil, nil) } -func extractOCSPCacheResponseValue(cacheValue []interface{}, subject, issuer *x509.Certificate) *ocspStatus { +func extractOCSPCacheResponseValue(certCacheValue *certCacheValue, subject, issuer *x509.Certificate) *ocspStatus { subjectName := "Unknown" if subject != nil { subjectName = subject.Subject.CommonName } curTime := time.Now() - if len(cacheValue) != 2 { + currentTime := float64(curTime.UTC().Unix()) + if currentTime-certCacheValue.ts >= cacheExpire { return &ocspStatus{ - code: ocspMissedCache, - err: fmt.Errorf("miss cache data. subject: %v", subjectName), + code: ocspCacheExpired, + err: fmt.Errorf("cache expired. current: %v, cache: %v", + time.Unix(int64(currentTime), 0).UTC(), time.Unix(int64(certCacheValue.ts), 0).UTC()), } } - if ts, ok := cacheValue[0].(float64); ok { - currentTime := float64(curTime.UTC().Unix()) - if currentTime-ts >= cacheExpire { - return &ocspStatus{ - code: ocspCacheExpired, - err: fmt.Errorf("cache expired. current: %v, cache: %v", - time.Unix(int64(currentTime), 0).UTC(), time.Unix(int64(ts), 0).UTC()), - } - } - } else { + var err error + var r *ocsp.Response + var b []byte + b, err = base64.StdEncoding.DecodeString(certCacheValue.certBase64) + if err != nil { return &ocspStatus{ code: ocspFailedDecodeResponse, - err: errors.New("the first cache element is not float64"), + err: fmt.Errorf("failed to decode OCSP Response value in a cache. subject: %v, err: %v", subjectName, err), } } - var err error - var r *ocsp.Response - if s, ok := cacheValue[1].(string); ok { - var b []byte - b, err = base64.StdEncoding.DecodeString(s) - if err != nil { - return &ocspStatus{ - code: ocspFailedDecodeResponse, - err: fmt.Errorf("failed to decode OCSP Response value in a cache. subject: %v, err: %v", subjectName, err), - } - } - // check the revocation status here - r, err = ocsp.ParseResponse(b, issuer) - if err != nil { - logger.Warnf("the second cache element is not a valid OCSP Response. Ignored. subject: %v\n", subjectName) - return &ocspStatus{ - code: ocspFailedParseResponse, - err: fmt.Errorf("failed to parse OCSP Respose. subject: %v, err: %v", subjectName, err), - } - } - } else { + // check the revocation status here + r, err = ocsp.ParseResponse(b, issuer) + if err != nil { + logger.Warnf("the second cache element is not a valid OCSP Response. Ignored. subject: %v\n", subjectName) return &ocspStatus{ - code: ocspFailedDecodeResponse, - err: errors.New("the second cache element is not string"), + code: ocspFailedParseResponse, + err: fmt.Errorf("failed to parse OCSP Respose. subject: %v, err: %v", subjectName, err), } - } return validateOCSP(r) } @@ -939,7 +954,7 @@ func writeOCSPCacheFile() { buf := make(map[string][]interface{}) for k, v := range ocspResponseCache { cacheKeyInBase64 := decodeCertIDKey(&k) - buf[cacheKeyInBase64] = v + buf[cacheKeyInBase64] = []interface{}{v.ts, v.certBase64} } j, err := json.Marshal(buf) diff --git a/ocsp_test.go b/ocsp_test.go index c89a7a6ee..1204b0635 100644 --- a/ocsp_test.go +++ b/ocsp_test.go @@ -42,7 +42,7 @@ func TestOCSP(t *testing.T) { for _, tgt := range targetURL { _ = os.Setenv(cacheServerEnabledEnv, enabled) _ = os.Remove(cacheFileName) // clear cache file - ocspResponseCache = make(map[certIDKey][]interface{}) + ocspResponseCache = make(map[certIDKey]*certCacheValue) for _, tr := range transports { t.Run(fmt.Sprintf("%v_%v", tgt, enabled), func(t *testing.T) { c := &http.Client{ @@ -188,7 +188,7 @@ func TestUnitCheckOCSPResponseCache(t *testing.T) { } b64Key := base64.StdEncoding.EncodeToString([]byte("DUMMY_VALUE")) currentTime := float64(time.Now().UTC().Unix()) - ocspResponseCache[dummyKey0] = []interface{}{currentTime, b64Key} + ocspResponseCache[dummyKey0] = &certCacheValue{currentTime, b64Key} subject := &x509.Certificate{} issuer := &x509.Certificate{} ost := checkOCSPResponseCache(&dummyKey, subject, issuer) @@ -196,13 +196,13 @@ func TestUnitCheckOCSPResponseCache(t *testing.T) { t.Fatalf("should have failed. expected: %v, got: %v", ocspMissedCache, ost.code) } // old timestamp - ocspResponseCache[dummyKey] = []interface{}{float64(1395054952), b64Key} + ocspResponseCache[dummyKey] = &certCacheValue{float64(1395054952), b64Key} ost = checkOCSPResponseCache(&dummyKey, subject, issuer) if ost.code != ocspCacheExpired { t.Fatalf("should have failed. expected: %v, got: %v", ocspCacheExpired, ost.code) } // future timestamp - ocspResponseCache[dummyKey] = []interface{}{float64(1805054952), b64Key} + ocspResponseCache[dummyKey] = &certCacheValue{float64(1805054952), b64Key} ost = checkOCSPResponseCache(&dummyKey, subject, issuer) if ost.code != ocspFailedParseResponse { t.Fatalf("should have failed. expected: %v, got: %v", ocspFailedDecodeResponse, ost.code) @@ -216,29 +216,17 @@ func TestUnitCheckOCSPResponseCache(t *testing.T) { "koRzw/UU7zKsqiTB0ZN/rgJp+MocTdqQSGKvbZyR8d4u8eNQqi1x4Pk3yO/pftANFaJKGB+JPgKS3PQAqJaXcipNcEfqtl7y4PO6kqA" + // pragma: allowlist secret "Jb4xI/OTXIrRA5TsT4cCioE" // issuer is not a true issuer certificate - ocspResponseCache[dummyKey] = []interface{}{float64(currentTime - 1000), actualOcspResponse} + ocspResponseCache[dummyKey] = &certCacheValue{float64(currentTime - 1000), actualOcspResponse} ost = checkOCSPResponseCache(&dummyKey, subject, issuer) if ost.code != ocspFailedParseResponse { t.Fatalf("should have failed. expected: %v, got: %v", ocspFailedParseResponse, ost.code) } // invalid validity - ocspResponseCache[dummyKey] = []interface{}{float64(currentTime - 1000), actualOcspResponse} + ocspResponseCache[dummyKey] = &certCacheValue{float64(currentTime - 1000), actualOcspResponse} ost = checkOCSPResponseCache(&dummyKey, subject, nil) if ost.code != ocspInvalidValidity { t.Fatalf("should have failed. expected: %v, got: %v", ocspInvalidValidity, ost.code) } - // wrong timestamp type - ocspResponseCache[dummyKey] = []interface{}{uint32(currentTime - 1000), 123456} - ost = checkOCSPResponseCache(&dummyKey, subject, issuer) - if ost.code != ocspFailedDecodeResponse { - t.Fatalf("should have failed. expected: %v, got: %v", ocspFailedDecodeResponse, ost.code) - } - // wrong value type - ocspResponseCache[dummyKey] = []interface{}{float64(currentTime - 1000), 123456} - ost = checkOCSPResponseCache(&dummyKey, subject, issuer) - if ost.code != ocspFailedDecodeResponse { - t.Fatalf("should have failed. expected: %v, got: %v", ocspFailedDecodeResponse, ost.code) - } } func TestUnitValidateOCSP(t *testing.T) {