From 4fa9e437fa21e737254165579f88e73a46d78547 Mon Sep 17 00:00:00 2001 From: Erik Erlandson Date: Sat, 11 Dec 2021 11:38:46 -0700 Subject: [PATCH] add support for schema = --- sqlalchemy_trino/dialect.py | 76 +++++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 28 deletions(-) diff --git a/sqlalchemy_trino/dialect.py b/sqlalchemy_trino/dialect.py index eab4d37..3a664b6 100644 --- a/sqlalchemy_trino/dialect.py +++ b/sqlalchemy_trino/dialect.py @@ -90,19 +90,19 @@ def get_columns(self, connection: Connection, def _get_columns(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: - schema = schema or self._get_default_schema_name(connection) + catalog_name, schema_name = self._infer_catalog_schema(schema, connection) query = dedent(''' SELECT "column_name", "data_type", "column_default", UPPER("is_nullable") AS "is_nullable" - FROM "information_schema"."columns" + FROM ":catalog"."information_schema"."columns" WHERE "table_schema" = :schema AND "table_name" = :table ORDER BY "ordinal_position" ASC ''').strip() - res = connection.execute(sql.text(query), schema=schema, table=table_name) + res = connection.execute(sql.text(query), catalog=catalog_name, schema=schema_name, table=table_name) columns = [] for record in res: column = dict( @@ -130,23 +130,24 @@ def get_foreign_keys(self, connection: Connection, return [] def get_schema_names(self, connection: Connection, **kw) -> List[str]: + catalog_name, unused = self._infer_catalog_schema(None, connection) query = dedent(''' SELECT "schema_name" - FROM "information_schema"."schemata" + FROM ":catalog"."information_schema"."schemata" ''').strip() - res = connection.execute(sql.text(query)) + res = connection.execute(sql.text(query), catalog=catalog_name) return [row.schema_name for row in res] def get_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: - schema = schema or self._get_default_schema_name(connection) - if schema is None: + catalog_name, schema_name = self._infer_catalog_schema(schema, connection) + if schema_name is None: raise exc.NoSuchTableError('schema is required') query = dedent(''' SELECT "table_name" - FROM "information_schema"."tables" + FROM ":catalog"."information_schema"."tables" WHERE "table_schema" = :schema ''').strip() - res = connection.execute(sql.text(query), schema=schema) + res = connection.execute(sql.text(query), catalog=catalog_name, schema=schema_name) return [row.table_name for row in res] def get_temp_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: @@ -154,15 +155,15 @@ def get_temp_table_names(self, connection: Connection, schema: str = None, **kw) return [] def get_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: - schema = schema or self._get_default_schema_name(connection) - if schema is None: + catalog_name, schema_name = self._infer_catalog_schema(schema, connection) + if schema_name is None: raise exc.NoSuchTableError('schema is required') query = dedent(''' SELECT "table_name" - FROM "information_schema"."views" + FROM ":catalog"."information_schema"."views" WHERE "table_schema" = :schema ''').strip() - res = connection.execute(sql.text(query), schema=schema) + res = connection.execute(sql.text(query), catalog=catalog_name, schema=schema_name) return [row.table_name for row in res] def get_temp_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: @@ -170,16 +171,16 @@ def get_temp_view_names(self, connection: Connection, schema: str = None, **kw) return [] def get_view_definition(self, connection: Connection, view_name: str, schema: str = None, **kw) -> str: - schema = schema or self._get_default_schema_name(connection) - if schema is None: + catalog_name, schema_name = self._infer_catalog_schema(schema, connection) + if schema_name is None: raise exc.NoSuchTableError('schema is required') query = dedent(''' SELECT "view_definition" - FROM "information_schema"."views" + FROM ":catalog"."information_schema"."views" WHERE "table_schema" = :schema AND "table_name" = :view ''').strip() - res = connection.execute(sql.text(query), schema=schema, view=view_name) + res = connection.execute(sql.text(query), catalog=catalog_name, schema=schema_name, view=view_name) return res.scalar() def get_indexes(self, connection: Connection, @@ -211,7 +212,7 @@ def get_check_constraints(self, connection: Connection, def get_table_comment(self, connection: Connection, table_name: str, schema: str = None, **kw) -> Dict[str, Any]: - properties_table = self._get_full_table(f'{table_name}$properties', schema) + properties_table = self._get_full_table(connection, f'{table_name}$properties', schema) query = f'SELECT "comment" FROM {properties_table}' try: res = connection.execute(sql.text(query)) @@ -227,26 +228,27 @@ def get_table_comment(self, connection: Connection, raise def has_schema(self, connection: Connection, schema: str) -> bool: + catalog_name, schema_name = self._infer_catalog_schema(schema, connection) query = dedent(''' SELECT "schema_name" - FROM "information_schema"."schemata" + FROM ":catalog"."information_schema"."schemata" WHERE "schema_name" = :schema ''').strip() - res = connection.execute(sql.text(query), schema=schema) + res = connection.execute(sql.text(query), catalog=catalog_name, schema=schema_name) return res.first() is not None def has_table(self, connection: Connection, table_name: str, schema: str = None, **kw) -> bool: - schema = schema or self._get_default_schema_name(connection) - if schema is None: + catalog_name, schema_name = self._infer_catalog_schema(schema, connection) + if schema_name is None: return False query = dedent(''' SELECT "table_name" - FROM "information_schema"."tables" + FROM ":catalog"."information_schema"."tables" WHERE "table_schema" = :schema AND "table_name" = :table ''').strip() - res = connection.execute(sql.text(query), schema=schema, table=table_name) + res = connection.execute(sql.text(query), catalog=catalog_name, schema=schema_name, table=table_name) return res.first() is not None def has_sequence(self, connection: Connection, @@ -265,6 +267,22 @@ def _get_default_schema_name(self, connection: Connection) -> Optional[str]: dbapi_connection: trino_dbapi.Connection = connection.connection return dbapi_connection.schema + def _infer_catalog_schema(self, schema: str, connection: Connection) -> str: + dbapi_connection: trino_dbapi.Connection = connection.connection + if schema is not None: + t = schema.split('.') + if len(t) == 2: + return (t[0], t[1]) + elif len(t) == 1: + catalog = dbapi_connection.catalog or 'system' + return (catalog, schema) + else: + raise ValueError(f'Bad schema string: "{schema}"') + else: + catalog = dbapi_connection.catalog or 'system' + schema = dbapi_connection.schema + return (catalog, schema) + def do_execute(self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], context: DefaultExecutionContext = None): cursor.execute(statement, parameters) @@ -289,10 +307,12 @@ def get_isolation_level(self, dbapi_conn: trino_dbapi.Connection) -> str: 'SERIALIZABLE'] return level_names[dbapi_conn.isolation_level] - def _get_full_table(self, table_name: str, schema: str = None, quote: bool = True) -> str: + def _get_full_table(self, connection: Connection, table_name: str, schema: str = None, quote: bool = True) -> str: + catalog_name, schema_name = self._infer_catalog_schema(schema, connection) table_part = self.identifier_preparer.quote_identifier(table_name) if quote else table_name - if schema: - schema_part = self.identifier_preparer.quote_identifier(schema) if quote else schema - return f'{schema_part}.{table_part}' + catalog_part = self.identifier_preparer.quote_identifier(catalog_name) if quote else catalog_name + if schema_name: + schema_part = self.identifier_preparer.quote_identifier(schema_name) if quote else schema_name + return f'{catalog_part}.{schema_part}.{table_part}' return table_part