diff --git a/authexternalbrowser.go b/authexternalbrowser.go index a8d966cef..373173f5b 100644 --- a/authexternalbrowser.go +++ b/authexternalbrowser.go @@ -56,13 +56,18 @@ func buildResponse(application string) bytes.Buffer { // This opens a socket that listens on all available unicast // and any anycast IP addresses locally. By specifying "0", we are // able to bind to a free port. -func bindToPort() (net.Listener, error) { +func createLocalTCPListener() (*net.TCPListener, error) { l, err := net.Listen("tcp", "localhost:0") if err != nil { - logger.Infof("unable to bind to a port on localhost, err: %v", err) return nil, err } - return l, nil + + tcpListener, ok := l.(*net.TCPListener) + if !ok { + return nil, fmt.Errorf("failed to assert type as *net.TCPListener") + } + + return tcpListener, nil } // Opens a browser window (or new tab) with the configured IDP Url. @@ -213,7 +218,7 @@ func doAuthenticateByExternalBrowser( user string, password string, ) authenticateByExternalBrowserResult { - l, err := bindToPort() + l, err := createLocalTCPListener() if err != nil { return authenticateByExternalBrowserResult{nil, nil, err} } diff --git a/authexternalbrowser_test.go b/authexternalbrowser_test.go index a6650dc78..ea1a19ac9 100644 --- a/authexternalbrowser_test.go +++ b/authexternalbrowser_test.go @@ -133,3 +133,16 @@ func TestAuthenticationTimeout(t *testing.T) { t.Fatal("should have timed out") } } + +func Test_createLocalTCPListener(t *testing.T) { + listener, err := createLocalTCPListener() + if err != nil { + t.Fatalf("createLocalTCPListener() failed: %v", err) + } + if listener == nil { + t.Fatal("createLocalTCPListener() returned nil listener") + } + + // Close the listener after the test. + defer listener.Close() +} diff --git a/dsn.go b/dsn.go index 1ccebba05..341e8e2b4 100644 --- a/dsn.go +++ b/dsn.go @@ -869,5 +869,9 @@ func parsePrivateKeyFromFile(path string) (*rsa.PrivateKey, error) { if err != nil { return nil, err } - return privateKey.(*rsa.PrivateKey), nil + pk, ok := privateKey.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("interface convertion. expected type *rsa.PrivateKey, but got %T", privateKey) + } + return pk, nil } diff --git a/dsn_test.go b/dsn_test.go index 9af66c90e..e133d8e8b 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -3,6 +3,8 @@ package gosnowflake import ( + "crypto/ecdsa" + "crypto/elliptic" cr "crypto/rand" "crypto/rsa" "crypto/x509" @@ -15,6 +17,8 @@ import ( "strings" "testing" "time" + + "github.com/aws/smithy-go/rand" ) type tcParseDSN struct { @@ -1358,6 +1362,33 @@ func TestParsePrivateKeyFromFileIncorrectData(t *testing.T) { } } +func TestParsePrivateKeyFromFileNotRSAPrivateKey(t *testing.T) { + // Generate an ECDSA private key for testing + ecdsaPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("failed to generate ECDSA private key: %v", err) + } + + ecdsaPrivateKeyBytes, err := x509.MarshalECPrivateKey(ecdsaPrivateKey) + if err != nil { + t.Fatalf("failed to marshal ECDSA private key: %v", err) + } + pemBlock := &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: ecdsaPrivateKeyBytes, + } + pemData := pem.EncodeToMemory(pemBlock) + + // Write the PEM data to a temporary file + pemFile := createTmpFile("ecdsaKey.pem", pemData) + + // Attempt to parse the private key + _, err = parsePrivateKeyFromFile(pemFile) + if err == nil { + t.Error("expected an error when trying to parse an ECDSA private key as RSA") + } +} + func TestParsePrivateKeyFromFile(t *testing.T) { generatedKey, _ := rsa.GenerateKey(cr.Reader, 1024) pemKey, _ := x509.MarshalPKCS8PrivateKey(generatedKey) diff --git a/gcs_storage_client.go b/gcs_storage_client.go index 9affaa78f..91bfee1e0 100644 --- a/gcs_storage_client.go +++ b/gcs_storage_client.go @@ -54,7 +54,10 @@ func (util *snowflakeGcsClient) getFileHeader(meta *fileMetadata, filename strin if err != nil { return nil, err } - accessToken := meta.client.(string) + accessToken, ok := meta.client.(string) + if !ok { + return nil, fmt.Errorf("interface convertion. expected type string but got %T", meta.client) + } gcsHeaders := map[string]string{ "Authorization": "Bearer " + accessToken, } @@ -145,7 +148,11 @@ func (util *snowflakeGcsClient) uploadFile( if err != nil { return err } - accessToken = meta.client.(string) + var ok bool + accessToken, ok = meta.client.(string) + if !ok { + return fmt.Errorf("interface convertion. expected type string but got %T", meta.client) + } } var contentEncoding string @@ -271,7 +278,11 @@ func (util *snowflakeGcsClient) nativeDownloadFile( if err != nil { return err } - accessToken = meta.client.(string) + var ok bool + accessToken, ok = meta.client.(string) + if !ok { + return fmt.Errorf("interface convertion. expected type string but got %T", meta.client) + } if accessToken != "" { gcsHeaders["Authorization"] = "Bearer " + accessToken } diff --git a/gcs_storage_client_test.go b/gcs_storage_client_test.go index 75b5f50f7..04d3ea67f 100644 --- a/gcs_storage_client_test.go +++ b/gcs_storage_client_test.go @@ -810,6 +810,40 @@ func TestGetFileHeaderEncryptionData(t *testing.T) { } } +func TestGetFileHeaderEncryptionDataInterfaceConversionError(t *testing.T) { + mockEncDataResp := "{\"EncryptionMode\":\"FullBlob\",\"WrappedContentKey\": {\"KeyId\":\"symmKey1\",\"EncryptedKey\":\"testencryptedkey12345678910==\",\"Algorithm\":\"AES_CBC_256\"},\"EncryptionAgent\": {\"Protocol\":\"1.0\",\"EncryptionAlgorithm\":\"AES_CBC_256\"},\"ContentEncryptionIV\":\"testIVkey12345678910==\",\"KeyWrappingMetadata\":{\"EncryptionLibrary\":\"Java 5.3.0\"}}" + mockMatDesc := "{\"queryid\":\"01abc874-0406-1bf0-0000-53b10668e056\",\"smkid\":\"92019681909886\",\"key\":\"128\"}" + info := execResponseStageInfo{ + Location: "gcs/teststage/users/34/", + LocationType: "GCS", + Creds: execResponseCredentials{ + GcsAccessToken: "test-token-124456577", + }, + } + meta := fileMetadata{ + client: 1, + stageInfo: &info, + mockGcsClient: &clientMock{ + DoFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{ + "X-Goog-Meta-Encryptiondata": []string{mockEncDataResp}, + "Content-Length": []string{"4256"}, + "X-Goog-Meta-Sfc-Digest": []string{"123456789abcdef"}, + "X-Goog-Meta-Matdesc": []string{mockMatDesc}, + }, + }, nil + }, + }, + } + _, err := new(snowflakeGcsClient).getFileHeader(&meta, "file.txt") + if err == nil { + t.Error("should have raised an error") + } +} + func TestUploadFileToGcsNoStatus(t *testing.T) { info := execResponseStageInfo{ Location: "gcs-blob/storage/users/456/", @@ -961,3 +995,39 @@ func TestDownloadFileWithBadRequest(t *testing.T) { renewPresignedURL, downloadMeta.resStatus) } } + +func Test_snowflakeGcsClient_uploadFile(t *testing.T) { + info := execResponseStageInfo{ + Location: "gcs/teststage/users/34/", + LocationType: "GCS", + Creds: execResponseCredentials{ + GcsAccessToken: "test-token-124456577", + }, + } + meta := fileMetadata{ + client: 1, + stageInfo: &info, + } + err := new(snowflakeGcsClient).uploadFile("somedata", &meta, nil, 1, 1) + if err == nil { + t.Error("should have raised an error") + } +} + +func Test_snowflakeGcsClient_nativeDownloadFile(t *testing.T) { + info := execResponseStageInfo{ + Location: "gcs/teststage/users/34/", + LocationType: "GCS", + Creds: execResponseCredentials{ + GcsAccessToken: "test-token-124456577", + }, + } + meta := fileMetadata{ + client: 1, + stageInfo: &info, + } + err := new(snowflakeGcsClient).nativeDownloadFile(&meta, "dummy data", 1) + if err == nil { + t.Error("should have raised an error") + } +} diff --git a/statement.go b/statement.go index 70d4479a7..9ebf941bd 100644 --- a/statement.go +++ b/statement.go @@ -5,6 +5,7 @@ package gosnowflake import ( "context" "database/sql/driver" + "fmt" ) // SnowflakeStmt represents the prepared statement in driver. @@ -33,28 +34,56 @@ func (stmt *snowflakeStmt) NumInput() int { func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.ExecContext") result, err := stmt.sc.ExecContext(ctx, stmt.query, args) - stmt.lastQueryID = result.(SnowflakeResult).GetQueryID() + if err != nil { + return nil, err + } + r, ok := result.(SnowflakeResult) + if !ok { + return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result) + } + stmt.lastQueryID = r.GetQueryID() return result, err } func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.QueryContext") rows, err := stmt.sc.QueryContext(ctx, stmt.query, args) - stmt.lastQueryID = rows.(SnowflakeRows).GetQueryID() - return rows, err + if err != nil { + return nil, err + } + r, ok := rows.(SnowflakeRows) + if !ok { + return nil, fmt.Errorf("interface convertion. expected type SnowflakeRows but got %T", rows) + } + stmt.lastQueryID = r.GetQueryID() + return rows, nil } func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec") result, err := stmt.sc.Exec(stmt.query, args) - stmt.lastQueryID = result.(SnowflakeResult).GetQueryID() + if err != nil { + return nil, err + } + r, ok := result.(SnowflakeResult) + if !ok { + return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result) + } + stmt.lastQueryID = r.GetQueryID() return result, err } func (stmt *snowflakeStmt) Query(args []driver.Value) (driver.Rows, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Query") rows, err := stmt.sc.Query(stmt.query, args) - stmt.lastQueryID = rows.(SnowflakeRows).GetQueryID() + if err != nil { + return nil, err + } + r, ok := rows.(SnowflakeRows) + if !ok { + return nil, fmt.Errorf("interface convertion. expected type SnowflakeRows but got %T", rows) + } + stmt.lastQueryID = r.GetQueryID() return rows, err } diff --git a/statement_test.go b/statement_test.go index 68f1e7a7b..f1b6090ee 100644 --- a/statement_test.go +++ b/statement_test.go @@ -248,6 +248,40 @@ func TestStmtExec(t *testing.T) { } } +func TestStmtExec_Error(t *testing.T) { + ctx := context.Background() + conn := openConn(t) + defer conn.Close() + + // Create a test table + if _, err := conn.ExecContext(ctx, `create or replace table test_table(col1 int, col2 int)`); err != nil { + t.Fatalf("failed to create table: %v", err) + } + + // Attempt to execute an invalid statement + if err := conn.Raw(func(x interface{}) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "insert into test_table values (?, ?)") + if err != nil { + t.Fatalf("failed to prepare statement: %v", err) + } + + // Intentionally passing a string instead of an integer to cause an error + _, err = stmt.(*snowflakeStmt).Exec([]driver.Value{"invalid_data", 2}) + if err == nil { + t.Errorf("expected an error, but got none") + } + + return nil + }); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Drop the test table + if _, err := conn.ExecContext(ctx, "drop table if exists test_table"); err != nil { + t.Fatalf("failed to drop table: %v", err) + } +} + func getStatusSuccessButInvalidJSONfunc(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, @@ -351,6 +385,76 @@ func TestStatementQueryIdForQueries(t *testing.T) { } } +func TestStatementQuery(t *testing.T) { + ctx := context.Background() + conn := openConn(t) + defer conn.Close() + + testcases := []struct { + name string + query string + f func(stmt driver.Stmt) (driver.Rows, error) + wantErr bool + }{ + { + "validQuery", + "SELECT 1", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.Query(nil) + }, + false, + }, + { + "validQueryContext", + "SELECT 1", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) + }, + false, + }, + { + "invalidQuery", + "SELECT * FROM non_existing_table", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.Query(nil) + }, + true, + }, + { + "invalidQueryContext", + "SELECT * FROM non_existing_table", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) + }, + true, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query) + if err != nil { + if tc.wantErr { + return nil // expected error + } + t.Fatal(err) + } + + _, err = tc.f(stmt) + if (err != nil) != tc.wantErr { + t.Fatalf("error = %v, wantErr %v", err, tc.wantErr) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } +} + func TestStatementQueryIdForExecs(t *testing.T) { ctx := context.Background() runDBTest(t, func(dbt *DBTest) { @@ -414,3 +518,75 @@ func TestStatementQueryIdForExecs(t *testing.T) { } }) } + +func TestStatementQueryExecs(t *testing.T) { + ctx := context.Background() + runDBTest(t, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE TestStatementQueryExecs (v INTEGER)") + defer dbt.mustExec("DROP TABLE IF EXISTS TestStatementForExecs") + + testcases := []struct { + name string + query string + f func(stmt driver.Stmt) (driver.Result, error) + wantErr bool + }{ + { + "validExec", + "INSERT INTO TestStatementQueryExecs VALUES (1)", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.Exec(nil) + }, + false, + }, + { + "validExecContext", + "INSERT INTO TestStatementQueryExecs VALUES (1)", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + false, + }, + { + "invalidExec", + "INSERT INTO TestStatementQueryExecs VALUES ('invalid_data')", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.Exec(nil) + }, + true, + }, + { + "invalidExecContext", + "INSERT INTO TestStatementQueryExecs VALUES ('invalid_data')", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + true, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := dbt.conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query) + if err != nil { + if tc.wantErr { + return nil // expected error + } + t.Fatal(err) + } + + _, err = tc.f(stmt) + if (err != nil) != tc.wantErr { + t.Fatalf("error = %v, wantErr %v", err, tc.wantErr) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } + }) +}