From f7b595687ff1dce5718d125ea07c4e5191199841 Mon Sep 17 00:00:00 2001 From: Timur_Akhmadiev Date: Thu, 2 Nov 2023 17:06:22 +0400 Subject: [PATCH 1/4] fix: remove unsafe interface conversion --- authexternalbrowser.go | 6 +++++- dsn.go | 6 +++++- gcs_storage_client.go | 17 ++++++++++++++--- statement.go | 39 ++++++++++++++++++++++++++++++++++----- 4 files changed, 58 insertions(+), 10 deletions(-) diff --git a/authexternalbrowser.go b/authexternalbrowser.go index a8d966cef..ddb956a5d 100644 --- a/authexternalbrowser.go +++ b/authexternalbrowser.go @@ -219,7 +219,11 @@ func doAuthenticateByExternalBrowser( } defer l.Close() - callbackPort := l.Addr().(*net.TCPAddr).Port + addr, ok := l.Addr().(*net.TCPAddr) + if !ok { + return authenticateByExternalBrowserResult{nil, nil, fmt.Errorf("interface convertion. expected type *net.TCPAddr but got %T", l.Addr())} + } + callbackPort := addr.Port idpURL, proofKey, err := getIdpURLProofKey( ctx, sr, authenticator, application, account, callbackPort) if err != nil { diff --git a/dsn.go b/dsn.go index 1ccebba05..341e8e2b4 100644 --- a/dsn.go +++ b/dsn.go @@ -869,5 +869,9 @@ func parsePrivateKeyFromFile(path string) (*rsa.PrivateKey, error) { if err != nil { return nil, err } - return privateKey.(*rsa.PrivateKey), nil + pk, ok := privateKey.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("interface convertion. expected type *rsa.PrivateKey, but got %T", privateKey) + } + return pk, nil } diff --git a/gcs_storage_client.go b/gcs_storage_client.go index 9affaa78f..91bfee1e0 100644 --- a/gcs_storage_client.go +++ b/gcs_storage_client.go @@ -54,7 +54,10 @@ func (util *snowflakeGcsClient) getFileHeader(meta *fileMetadata, filename strin if err != nil { return nil, err } - accessToken := meta.client.(string) + accessToken, ok := meta.client.(string) + if !ok { + return nil, fmt.Errorf("interface convertion. expected type string but got %T", meta.client) + } gcsHeaders := map[string]string{ "Authorization": "Bearer " + accessToken, } @@ -145,7 +148,11 @@ func (util *snowflakeGcsClient) uploadFile( if err != nil { return err } - accessToken = meta.client.(string) + var ok bool + accessToken, ok = meta.client.(string) + if !ok { + return fmt.Errorf("interface convertion. expected type string but got %T", meta.client) + } } var contentEncoding string @@ -271,7 +278,11 @@ func (util *snowflakeGcsClient) nativeDownloadFile( if err != nil { return err } - accessToken = meta.client.(string) + var ok bool + accessToken, ok = meta.client.(string) + if !ok { + return fmt.Errorf("interface convertion. expected type string but got %T", meta.client) + } if accessToken != "" { gcsHeaders["Authorization"] = "Bearer " + accessToken } diff --git a/statement.go b/statement.go index 70d4479a7..9ebf941bd 100644 --- a/statement.go +++ b/statement.go @@ -5,6 +5,7 @@ package gosnowflake import ( "context" "database/sql/driver" + "fmt" ) // SnowflakeStmt represents the prepared statement in driver. @@ -33,28 +34,56 @@ func (stmt *snowflakeStmt) NumInput() int { func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.ExecContext") result, err := stmt.sc.ExecContext(ctx, stmt.query, args) - stmt.lastQueryID = result.(SnowflakeResult).GetQueryID() + if err != nil { + return nil, err + } + r, ok := result.(SnowflakeResult) + if !ok { + return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result) + } + stmt.lastQueryID = r.GetQueryID() return result, err } func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.QueryContext") rows, err := stmt.sc.QueryContext(ctx, stmt.query, args) - stmt.lastQueryID = rows.(SnowflakeRows).GetQueryID() - return rows, err + if err != nil { + return nil, err + } + r, ok := rows.(SnowflakeRows) + if !ok { + return nil, fmt.Errorf("interface convertion. expected type SnowflakeRows but got %T", rows) + } + stmt.lastQueryID = r.GetQueryID() + return rows, nil } func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec") result, err := stmt.sc.Exec(stmt.query, args) - stmt.lastQueryID = result.(SnowflakeResult).GetQueryID() + if err != nil { + return nil, err + } + r, ok := result.(SnowflakeResult) + if !ok { + return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result) + } + stmt.lastQueryID = r.GetQueryID() return result, err } func (stmt *snowflakeStmt) Query(args []driver.Value) (driver.Rows, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Query") rows, err := stmt.sc.Query(stmt.query, args) - stmt.lastQueryID = rows.(SnowflakeRows).GetQueryID() + if err != nil { + return nil, err + } + r, ok := rows.(SnowflakeRows) + if !ok { + return nil, fmt.Errorf("interface convertion. expected type SnowflakeRows but got %T", rows) + } + stmt.lastQueryID = r.GetQueryID() return rows, err } From 89ceef3d0dd589af6a7fed13990b93c64e44057a Mon Sep 17 00:00:00 2001 From: Timur_Akhmadiev Date: Tue, 7 Nov 2023 16:58:23 +0400 Subject: [PATCH 2/4] test: add tests for gcs storage client and dns package --- dsn_test.go | 31 +++++++++++++++++ gcs_storage_client_test.go | 70 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+) diff --git a/dsn_test.go b/dsn_test.go index 9af66c90e..e133d8e8b 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -3,6 +3,8 @@ package gosnowflake import ( + "crypto/ecdsa" + "crypto/elliptic" cr "crypto/rand" "crypto/rsa" "crypto/x509" @@ -15,6 +17,8 @@ import ( "strings" "testing" "time" + + "github.com/aws/smithy-go/rand" ) type tcParseDSN struct { @@ -1358,6 +1362,33 @@ func TestParsePrivateKeyFromFileIncorrectData(t *testing.T) { } } +func TestParsePrivateKeyFromFileNotRSAPrivateKey(t *testing.T) { + // Generate an ECDSA private key for testing + ecdsaPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("failed to generate ECDSA private key: %v", err) + } + + ecdsaPrivateKeyBytes, err := x509.MarshalECPrivateKey(ecdsaPrivateKey) + if err != nil { + t.Fatalf("failed to marshal ECDSA private key: %v", err) + } + pemBlock := &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: ecdsaPrivateKeyBytes, + } + pemData := pem.EncodeToMemory(pemBlock) + + // Write the PEM data to a temporary file + pemFile := createTmpFile("ecdsaKey.pem", pemData) + + // Attempt to parse the private key + _, err = parsePrivateKeyFromFile(pemFile) + if err == nil { + t.Error("expected an error when trying to parse an ECDSA private key as RSA") + } +} + func TestParsePrivateKeyFromFile(t *testing.T) { generatedKey, _ := rsa.GenerateKey(cr.Reader, 1024) pemKey, _ := x509.MarshalPKCS8PrivateKey(generatedKey) diff --git a/gcs_storage_client_test.go b/gcs_storage_client_test.go index 75b5f50f7..04d3ea67f 100644 --- a/gcs_storage_client_test.go +++ b/gcs_storage_client_test.go @@ -810,6 +810,40 @@ func TestGetFileHeaderEncryptionData(t *testing.T) { } } +func TestGetFileHeaderEncryptionDataInterfaceConversionError(t *testing.T) { + mockEncDataResp := "{\"EncryptionMode\":\"FullBlob\",\"WrappedContentKey\": {\"KeyId\":\"symmKey1\",\"EncryptedKey\":\"testencryptedkey12345678910==\",\"Algorithm\":\"AES_CBC_256\"},\"EncryptionAgent\": {\"Protocol\":\"1.0\",\"EncryptionAlgorithm\":\"AES_CBC_256\"},\"ContentEncryptionIV\":\"testIVkey12345678910==\",\"KeyWrappingMetadata\":{\"EncryptionLibrary\":\"Java 5.3.0\"}}" + mockMatDesc := "{\"queryid\":\"01abc874-0406-1bf0-0000-53b10668e056\",\"smkid\":\"92019681909886\",\"key\":\"128\"}" + info := execResponseStageInfo{ + Location: "gcs/teststage/users/34/", + LocationType: "GCS", + Creds: execResponseCredentials{ + GcsAccessToken: "test-token-124456577", + }, + } + meta := fileMetadata{ + client: 1, + stageInfo: &info, + mockGcsClient: &clientMock{ + DoFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{ + "X-Goog-Meta-Encryptiondata": []string{mockEncDataResp}, + "Content-Length": []string{"4256"}, + "X-Goog-Meta-Sfc-Digest": []string{"123456789abcdef"}, + "X-Goog-Meta-Matdesc": []string{mockMatDesc}, + }, + }, nil + }, + }, + } + _, err := new(snowflakeGcsClient).getFileHeader(&meta, "file.txt") + if err == nil { + t.Error("should have raised an error") + } +} + func TestUploadFileToGcsNoStatus(t *testing.T) { info := execResponseStageInfo{ Location: "gcs-blob/storage/users/456/", @@ -961,3 +995,39 @@ func TestDownloadFileWithBadRequest(t *testing.T) { renewPresignedURL, downloadMeta.resStatus) } } + +func Test_snowflakeGcsClient_uploadFile(t *testing.T) { + info := execResponseStageInfo{ + Location: "gcs/teststage/users/34/", + LocationType: "GCS", + Creds: execResponseCredentials{ + GcsAccessToken: "test-token-124456577", + }, + } + meta := fileMetadata{ + client: 1, + stageInfo: &info, + } + err := new(snowflakeGcsClient).uploadFile("somedata", &meta, nil, 1, 1) + if err == nil { + t.Error("should have raised an error") + } +} + +func Test_snowflakeGcsClient_nativeDownloadFile(t *testing.T) { + info := execResponseStageInfo{ + Location: "gcs/teststage/users/34/", + LocationType: "GCS", + Creds: execResponseCredentials{ + GcsAccessToken: "test-token-124456577", + }, + } + meta := fileMetadata{ + client: 1, + stageInfo: &info, + } + err := new(snowflakeGcsClient).nativeDownloadFile(&meta, "dummy data", 1) + if err == nil { + t.Error("should have raised an error") + } +} From e1df1b421faadbc80d4ad09e9cce78b6310b9702 Mon Sep 17 00:00:00 2001 From: Timur_Akhmadiev Date: Mon, 13 Nov 2023 10:23:09 +0400 Subject: [PATCH 3/4] test: add tests for statement and tcp listener --- authexternalbrowser.go | 19 ++-- authexternalbrowser_test.go | 13 +++ statement_test.go | 176 ++++++++++++++++++++++++++++++++++++ 3 files changed, 199 insertions(+), 9 deletions(-) diff --git a/authexternalbrowser.go b/authexternalbrowser.go index ddb956a5d..373173f5b 100644 --- a/authexternalbrowser.go +++ b/authexternalbrowser.go @@ -56,13 +56,18 @@ func buildResponse(application string) bytes.Buffer { // This opens a socket that listens on all available unicast // and any anycast IP addresses locally. By specifying "0", we are // able to bind to a free port. -func bindToPort() (net.Listener, error) { +func createLocalTCPListener() (*net.TCPListener, error) { l, err := net.Listen("tcp", "localhost:0") if err != nil { - logger.Infof("unable to bind to a port on localhost, err: %v", err) return nil, err } - return l, nil + + tcpListener, ok := l.(*net.TCPListener) + if !ok { + return nil, fmt.Errorf("failed to assert type as *net.TCPListener") + } + + return tcpListener, nil } // Opens a browser window (or new tab) with the configured IDP Url. @@ -213,17 +218,13 @@ func doAuthenticateByExternalBrowser( user string, password string, ) authenticateByExternalBrowserResult { - l, err := bindToPort() + l, err := createLocalTCPListener() if err != nil { return authenticateByExternalBrowserResult{nil, nil, err} } defer l.Close() - addr, ok := l.Addr().(*net.TCPAddr) - if !ok { - return authenticateByExternalBrowserResult{nil, nil, fmt.Errorf("interface convertion. expected type *net.TCPAddr but got %T", l.Addr())} - } - callbackPort := addr.Port + callbackPort := l.Addr().(*net.TCPAddr).Port idpURL, proofKey, err := getIdpURLProofKey( ctx, sr, authenticator, application, account, callbackPort) if err != nil { diff --git a/authexternalbrowser_test.go b/authexternalbrowser_test.go index a6650dc78..ea1a19ac9 100644 --- a/authexternalbrowser_test.go +++ b/authexternalbrowser_test.go @@ -133,3 +133,16 @@ func TestAuthenticationTimeout(t *testing.T) { t.Fatal("should have timed out") } } + +func Test_createLocalTCPListener(t *testing.T) { + listener, err := createLocalTCPListener() + if err != nil { + t.Fatalf("createLocalTCPListener() failed: %v", err) + } + if listener == nil { + t.Fatal("createLocalTCPListener() returned nil listener") + } + + // Close the listener after the test. + defer listener.Close() +} diff --git a/statement_test.go b/statement_test.go index 68f1e7a7b..5d879f055 100644 --- a/statement_test.go +++ b/statement_test.go @@ -248,6 +248,40 @@ func TestStmtExec(t *testing.T) { } } +func TestStmtExec_Error(t *testing.T) { + ctx := context.Background() + conn := openConn(t) + defer conn.Close() + + // Create a test table + if _, err := conn.ExecContext(ctx, `create or replace table test_table(col1 int, col2 int)`); err != nil { + t.Fatalf("failed to create table: %v", err) + } + + // Attempt to execute an invalid statement + if err := conn.Raw(func(x interface{}) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "insert into test_table values (?, ?)") + if err != nil { + t.Fatalf("failed to prepare statement: %v", err) + } + + // Intentionally passing a string instead of an integer to cause an error + _, err = stmt.(*snowflakeStmt).Exec([]driver.Value{"invalid_data", 2}) + if err == nil { + t.Errorf("expected an error, but got none") + } + + return nil + }); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Drop the test table + if _, err := conn.ExecContext(ctx, "drop table if exists test_table"); err != nil { + t.Fatalf("failed to drop table: %v", err) + } +} + func getStatusSuccessButInvalidJSONfunc(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, @@ -351,6 +385,76 @@ func TestStatementQueryIdForQueries(t *testing.T) { } } +func TestStatementQuery(t *testing.T) { + ctx := context.Background() + conn := openConn(t) + defer conn.Close() + + testcases := []struct { + name string + query string + f func(stmt driver.Stmt) (driver.Rows, error) + wantErr bool + }{ + { + "validQuery", + "SELECT 1", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.Query(nil) + }, + false, + }, + { + "validQueryContext", + "SELECT 1", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) + }, + false, + }, + { + "invalidQuery", + "SELECT * FROM non_existing_table", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.Query(nil) + }, + true, + }, + { + "invalidQueryContext", + "SELECT * FROM non_existing_table", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) + }, + true, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query) + if err != nil { + if tc.wantErr { + return nil // expected error + } + t.Fatal(err) + } + + _, err = tc.f(stmt) + if (err != nil) != tc.wantErr { + t.Fatalf("error = %v, wantErr %v", err, tc.wantErr) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } +} + func TestStatementQueryIdForExecs(t *testing.T) { ctx := context.Background() runDBTest(t, func(dbt *DBTest) { @@ -414,3 +518,75 @@ func TestStatementQueryIdForExecs(t *testing.T) { } }) } + +func TestStatementQueryExecs(t *testing.T) { + ctx := context.Background() + runDBTest(t, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE TestStatementQueryIdForExecs (v INTEGER)") + defer dbt.mustExec("DROP TABLE IF EXISTS TestStatementQueryIdForExecs") + + testcases := []struct { + name string + query string + f func(stmt driver.Stmt) (driver.Result, error) + wantErr bool + }{ + { + "validExec", + "INSERT INTO TestStatementQueryIdForExecs VALUES (1)", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.Exec(nil) + }, + false, + }, + { + "validExecContext", + "INSERT INTO TestStatementQueryIdForExecs VALUES (1)", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + false, + }, + { + "invalidExec", + "INSERT INTO TestStatementQueryIdForExecs VALUES (NULL)", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.Exec(nil) + }, + true, + }, + { + "invalidExecContext", + "INSERT INTO TestStatementQueryIdForExecs VALUES (NULL)", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + true, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := dbt.conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query) + if err != nil { + if tc.wantErr { + return nil // expected error + } + t.Fatal(err) + } + + _, err = tc.f(stmt) + if (err != nil) != tc.wantErr { + t.Fatalf("error = %v, wantErr %v", err, tc.wantErr) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } + }) +} From 12225f2f84408a89e7746d9c0b7291ee57d39e86 Mon Sep 17 00:00:00 2001 From: Timur_Akhmadiev Date: Mon, 13 Nov 2023 11:42:59 +0400 Subject: [PATCH 4/4] test: fix TestStatementQueryExecs test --- statement_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/statement_test.go b/statement_test.go index 5d879f055..f1b6090ee 100644 --- a/statement_test.go +++ b/statement_test.go @@ -522,8 +522,8 @@ func TestStatementQueryIdForExecs(t *testing.T) { func TestStatementQueryExecs(t *testing.T) { ctx := context.Background() runDBTest(t, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE TestStatementQueryIdForExecs (v INTEGER)") - defer dbt.mustExec("DROP TABLE IF EXISTS TestStatementQueryIdForExecs") + dbt.mustExec("CREATE TABLE TestStatementQueryExecs (v INTEGER)") + defer dbt.mustExec("DROP TABLE IF EXISTS TestStatementForExecs") testcases := []struct { name string @@ -533,7 +533,7 @@ func TestStatementQueryExecs(t *testing.T) { }{ { "validExec", - "INSERT INTO TestStatementQueryIdForExecs VALUES (1)", + "INSERT INTO TestStatementQueryExecs VALUES (1)", func(stmt driver.Stmt) (driver.Result, error) { return stmt.Exec(nil) }, @@ -541,7 +541,7 @@ func TestStatementQueryExecs(t *testing.T) { }, { "validExecContext", - "INSERT INTO TestStatementQueryIdForExecs VALUES (1)", + "INSERT INTO TestStatementQueryExecs VALUES (1)", func(stmt driver.Stmt) (driver.Result, error) { return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) }, @@ -549,7 +549,7 @@ func TestStatementQueryExecs(t *testing.T) { }, { "invalidExec", - "INSERT INTO TestStatementQueryIdForExecs VALUES (NULL)", + "INSERT INTO TestStatementQueryExecs VALUES ('invalid_data')", func(stmt driver.Stmt) (driver.Result, error) { return stmt.Exec(nil) }, @@ -557,7 +557,7 @@ func TestStatementQueryExecs(t *testing.T) { }, { "invalidExecContext", - "INSERT INTO TestStatementQueryIdForExecs VALUES (NULL)", + "INSERT INTO TestStatementQueryExecs VALUES ('invalid_data')", func(stmt driver.Stmt) (driver.Result, error) { return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) },