Skip to content

Commit

Permalink
SNOW-961482: Replace OCSP cache with more structured version
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Nov 7, 2023
1 parent 8445dca commit a2e1ee0
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 69 deletions.
2 changes: 1 addition & 1 deletion driver_ocsp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func deleteOCSPCacheFile() {
func deleteOCSPCacheAll() {
ocspResponseCacheLock.Lock()
defer ocspResponseCacheLock.Unlock()
ocspResponseCache = make(map[certIDKey][]interface{})
ocspResponseCache = make(map[certIDKey]*certCacheValue)
}

func cleanup() {
Expand Down
1 change: 1 addition & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ func setup() (string, error) {
return "", fmt.Errorf("failed to create schema. %v", err)
}
createDSN("UTC")
GetLogger().SetLogLevel("debug")
return orgSchemaname, nil
}

Expand Down
115 changes: 65 additions & 50 deletions ocsp.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,13 @@ type certIDKey struct {
SerialNumber string
}

type certCacheValue struct {
ts float64
certBase64 string
}

var (
ocspResponseCache map[certIDKey][]interface{}
ocspResponseCache map[certIDKey]*certCacheValue
ocspResponseCacheLock *sync.RWMutex
)

Expand Down Expand Up @@ -270,8 +275,14 @@ func checkOCSPResponseCache(encodedCertID *certIDKey, subject, issuer *x509.Cert
return &ocspStatus{code: ocspNoServer}
}
ocspResponseCacheLock.RLock()
gotValueFromCache := ocspResponseCache[*encodedCertID]
ocspResponseCacheLock.RUnlock()
defer ocspResponseCacheLock.RUnlock()
gotValueFromCache, ok := ocspResponseCache[*encodedCertID]
if !ok {
return &ocspStatus{
code: ocspMissedCache,
err: fmt.Errorf("miss cache data. subject: %v", subject),
}
}

status := extractOCSPCacheResponseValue(gotValueFromCache, subject, issuer)
if !isValidOCSPStatus(status.code) {
Expand Down Expand Up @@ -354,7 +365,7 @@ func checkOCSPCacheServer(
req requestFunc,
ocspServerHost *url.URL,
totalTimeout time.Duration) (
cacheContent *map[string][]interface{},
cacheContent *map[string]*certCacheValue,
ocspS *ocspStatus) {
var respd map[string][]interface{}
headers := make(map[string]string)
Expand Down Expand Up @@ -388,7 +399,21 @@ func checkOCSPCacheServer(
}
}
}
return &respd, &ocspStatus{
buf := make(map[string]*certCacheValue)
for key, value := range respd {
ts, ok := value[0].(float64)
if !ok {
logger.Warnf("cannot cast %v to float64", value[0])
continue
}
certBase64, ok := value[1].(string)
if !ok {
logger.Warnf("cannot cast %v to string", value[1])
continue
}
buf[key] = &certCacheValue{ts, certBase64}
}
return &buf, &ocspStatus{
code: ocspSuccess,
}
}
Expand Down Expand Up @@ -608,7 +633,7 @@ func getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate)
if !isValidOCSPStatus(ret.code) {
return ret // return invalid
}
v := []interface{}{float64(time.Now().UTC().Unix()), base64.StdEncoding.EncodeToString(ocspResBytes)}
v := &certCacheValue{float64(time.Now().UTC().Unix()), base64.StdEncoding.EncodeToString(ocspResBytes)}
ocspResponseCacheLock.Lock()
ocspResponseCache[*encodedCertID] = v
cacheUpdated = true
Expand Down Expand Up @@ -804,7 +829,7 @@ func initOCSPCache() {
if strings.EqualFold(os.Getenv(cacheServerEnabledEnv), "false") {
return
}
ocspResponseCache = make(map[certIDKey][]interface{})
ocspResponseCache = make(map[certIDKey]*certCacheValue)
ocspResponseCacheLock = &sync.RWMutex{}

logger.Infof("reading OCSP Response cache file. %v\n", cacheFileName)
Expand All @@ -828,74 +853,64 @@ func initOCSPCache() {
}
}
for k, cacheValue := range buf {
status := extractOCSPCacheResponseValueWithoutSubject(cacheValue)
ts, ok := cacheValue[0].(float64)
if !ok {
logger.Warnf("cannot cast %v as float64", cacheValue[0])
continue
}
certBase64, ok := cacheValue[1].(string)
if !ok {
logger.Warnf("cannot cast %v as string", cacheValue[1])
continue
}
certValue := &certCacheValue{ts, certBase64}
status := extractOCSPCacheResponseValueWithoutSubject(certValue)
if !isValidOCSPStatus(status.code) {
continue
}
cacheKey := encodeCertIDKey(k)
ocspResponseCache[*cacheKey] = cacheValue
ocspResponseCache[*cacheKey] = certValue

}
cacheUpdated = false
}
func extractOCSPCacheResponseValueWithoutSubject(cacheValue []interface{}) *ocspStatus {
func extractOCSPCacheResponseValueWithoutSubject(cacheValue *certCacheValue) *ocspStatus {
return extractOCSPCacheResponseValue(cacheValue, nil, nil)
}

func extractOCSPCacheResponseValue(cacheValue []interface{}, subject, issuer *x509.Certificate) *ocspStatus {
func extractOCSPCacheResponseValue(certCacheValue *certCacheValue, subject, issuer *x509.Certificate) *ocspStatus {
subjectName := "Unknown"
if subject != nil {
subjectName = subject.Subject.CommonName
}

curTime := time.Now()
if len(cacheValue) != 2 {
currentTime := float64(curTime.UTC().Unix())
if currentTime-certCacheValue.ts >= cacheExpire {
return &ocspStatus{
code: ocspMissedCache,
err: fmt.Errorf("miss cache data. subject: %v", subjectName),
code: ocspCacheExpired,
err: fmt.Errorf("cache expired. current: %v, cache: %v",
time.Unix(int64(currentTime), 0).UTC(), time.Unix(int64(certCacheValue.ts), 0).UTC()),
}
}
if ts, ok := cacheValue[0].(float64); ok {
currentTime := float64(curTime.UTC().Unix())
if currentTime-ts >= cacheExpire {
return &ocspStatus{
code: ocspCacheExpired,
err: fmt.Errorf("cache expired. current: %v, cache: %v",
time.Unix(int64(currentTime), 0).UTC(), time.Unix(int64(ts), 0).UTC()),
}
}
} else {
var err error
var r *ocsp.Response
var b []byte
b, err = base64.StdEncoding.DecodeString(certCacheValue.certBase64)
if err != nil {
return &ocspStatus{
code: ocspFailedDecodeResponse,
err: errors.New("the first cache element is not float64"),
err: fmt.Errorf("failed to decode OCSP Response value in a cache. subject: %v, err: %v", subjectName, err),
}
}
var err error
var r *ocsp.Response
if s, ok := cacheValue[1].(string); ok {
var b []byte
b, err = base64.StdEncoding.DecodeString(s)
if err != nil {
return &ocspStatus{
code: ocspFailedDecodeResponse,
err: fmt.Errorf("failed to decode OCSP Response value in a cache. subject: %v, err: %v", subjectName, err),
}
}
// check the revocation status here
r, err = ocsp.ParseResponse(b, issuer)
if err != nil {
logger.Warnf("the second cache element is not a valid OCSP Response. Ignored. subject: %v\n", subjectName)
return &ocspStatus{
code: ocspFailedParseResponse,
err: fmt.Errorf("failed to parse OCSP Respose. subject: %v, err: %v", subjectName, err),
}
}
} else {
// check the revocation status here
r, err = ocsp.ParseResponse(b, issuer)
if err != nil {
logger.Warnf("the second cache element is not a valid OCSP Response. Ignored. subject: %v\n", subjectName)
return &ocspStatus{
code: ocspFailedDecodeResponse,
err: errors.New("the second cache element is not string"),
code: ocspFailedParseResponse,
err: fmt.Errorf("failed to parse OCSP Respose. subject: %v, err: %v", subjectName, err),
}

}
return validateOCSP(r)
}
Expand Down Expand Up @@ -939,7 +954,7 @@ func writeOCSPCacheFile() {
buf := make(map[string][]interface{})
for k, v := range ocspResponseCache {
cacheKeyInBase64 := decodeCertIDKey(&k)
buf[cacheKeyInBase64] = v
buf[cacheKeyInBase64] = []interface{}{v.ts, v.certBase64}
}

j, err := json.Marshal(buf)
Expand Down
24 changes: 6 additions & 18 deletions ocsp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestOCSP(t *testing.T) {
for _, tgt := range targetURL {
_ = os.Setenv(cacheServerEnabledEnv, enabled)
_ = os.Remove(cacheFileName) // clear cache file
ocspResponseCache = make(map[certIDKey][]interface{})
ocspResponseCache = make(map[certIDKey]*certCacheValue)
for _, tr := range transports {
t.Run(fmt.Sprintf("%v_%v", tgt, enabled), func(t *testing.T) {
c := &http.Client{
Expand Down Expand Up @@ -188,21 +188,21 @@ func TestUnitCheckOCSPResponseCache(t *testing.T) {
}
b64Key := base64.StdEncoding.EncodeToString([]byte("DUMMY_VALUE"))
currentTime := float64(time.Now().UTC().Unix())
ocspResponseCache[dummyKey0] = []interface{}{currentTime, b64Key}
ocspResponseCache[dummyKey0] = &certCacheValue{currentTime, b64Key}
subject := &x509.Certificate{}
issuer := &x509.Certificate{}
ost := checkOCSPResponseCache(&dummyKey, subject, issuer)
if ost.code != ocspMissedCache {
t.Fatalf("should have failed. expected: %v, got: %v", ocspMissedCache, ost.code)
}
// old timestamp
ocspResponseCache[dummyKey] = []interface{}{float64(1395054952), b64Key}
ocspResponseCache[dummyKey] = &certCacheValue{float64(1395054952), b64Key}
ost = checkOCSPResponseCache(&dummyKey, subject, issuer)
if ost.code != ocspCacheExpired {
t.Fatalf("should have failed. expected: %v, got: %v", ocspCacheExpired, ost.code)
}
// future timestamp
ocspResponseCache[dummyKey] = []interface{}{float64(1805054952), b64Key}
ocspResponseCache[dummyKey] = &certCacheValue{float64(1805054952), b64Key}
ost = checkOCSPResponseCache(&dummyKey, subject, issuer)
if ost.code != ocspFailedParseResponse {
t.Fatalf("should have failed. expected: %v, got: %v", ocspFailedDecodeResponse, ost.code)
Expand All @@ -216,29 +216,17 @@ func TestUnitCheckOCSPResponseCache(t *testing.T) {
"koRzw/UU7zKsqiTB0ZN/rgJp+MocTdqQSGKvbZyR8d4u8eNQqi1x4Pk3yO/pftANFaJKGB+JPgKS3PQAqJaXcipNcEfqtl7y4PO6kqA" + // pragma: allowlist secret
"Jb4xI/OTXIrRA5TsT4cCioE"
// issuer is not a true issuer certificate
ocspResponseCache[dummyKey] = []interface{}{float64(currentTime - 1000), actualOcspResponse}
ocspResponseCache[dummyKey] = &certCacheValue{float64(currentTime - 1000), actualOcspResponse}
ost = checkOCSPResponseCache(&dummyKey, subject, issuer)
if ost.code != ocspFailedParseResponse {
t.Fatalf("should have failed. expected: %v, got: %v", ocspFailedParseResponse, ost.code)
}
// invalid validity
ocspResponseCache[dummyKey] = []interface{}{float64(currentTime - 1000), actualOcspResponse}
ocspResponseCache[dummyKey] = &certCacheValue{float64(currentTime - 1000), actualOcspResponse}
ost = checkOCSPResponseCache(&dummyKey, subject, nil)
if ost.code != ocspInvalidValidity {
t.Fatalf("should have failed. expected: %v, got: %v", ocspInvalidValidity, ost.code)
}
// wrong timestamp type
ocspResponseCache[dummyKey] = []interface{}{uint32(currentTime - 1000), 123456}
ost = checkOCSPResponseCache(&dummyKey, subject, issuer)
if ost.code != ocspFailedDecodeResponse {
t.Fatalf("should have failed. expected: %v, got: %v", ocspFailedDecodeResponse, ost.code)
}
// wrong value type
ocspResponseCache[dummyKey] = []interface{}{float64(currentTime - 1000), 123456}
ost = checkOCSPResponseCache(&dummyKey, subject, issuer)
if ost.code != ocspFailedDecodeResponse {
t.Fatalf("should have failed. expected: %v, got: %v", ocspFailedDecodeResponse, ost.code)
}
}

func TestUnitValidateOCSP(t *testing.T) {
Expand Down

0 comments on commit a2e1ee0

Please sign in to comment.