Skip to content
This repository has been archived by the owner on May 5, 2022. It is now read-only.

add support for schema = <catalog.schema> #42

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 48 additions & 28 deletions sqlalchemy_trino/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,19 @@ def get_columns(self, connection: Connection,

def _get_columns(self, connection: Connection,
table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:
schema = schema or self._get_default_schema_name(connection)
catalog_name, schema_name = self._infer_catalog_schema(schema, connection)
query = dedent('''
SELECT
"column_name",
"data_type",
"column_default",
UPPER("is_nullable") AS "is_nullable"
FROM "information_schema"."columns"
FROM ":catalog"."information_schema"."columns"
WHERE "table_schema" = :schema
AND "table_name" = :table
ORDER BY "ordinal_position" ASC
''').strip()
res = connection.execute(sql.text(query), schema=schema, table=table_name)
res = connection.execute(sql.text(query), catalog=catalog_name, schema=schema_name, table=table_name)
columns = []
for record in res:
column = dict(
Expand Down Expand Up @@ -130,56 +130,57 @@ def get_foreign_keys(self, connection: Connection,
return []

def get_schema_names(self, connection: Connection, **kw) -> List[str]:
catalog_name, unused = self._infer_catalog_schema(None, connection)
query = dedent('''
SELECT "schema_name"
FROM "information_schema"."schemata"
FROM ":catalog"."information_schema"."schemata"
''').strip()
res = connection.execute(sql.text(query))
res = connection.execute(sql.text(query), catalog=catalog_name)
return [row.schema_name for row in res]

def get_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
schema = schema or self._get_default_schema_name(connection)
if schema is None:
catalog_name, schema_name = self._infer_catalog_schema(schema, connection)
if schema_name is None:
raise exc.NoSuchTableError('schema is required')
query = dedent('''
SELECT "table_name"
FROM "information_schema"."tables"
FROM ":catalog"."information_schema"."tables"
WHERE "table_schema" = :schema
''').strip()
res = connection.execute(sql.text(query), schema=schema)
res = connection.execute(sql.text(query), catalog=catalog_name, schema=schema_name)
return [row.table_name for row in res]

def get_temp_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
"""Trino has no support for temporary tables. Returns an empty list."""
return []

def get_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
schema = schema or self._get_default_schema_name(connection)
if schema is None:
catalog_name, schema_name = self._infer_catalog_schema(schema, connection)
if schema_name is None:
raise exc.NoSuchTableError('schema is required')
query = dedent('''
SELECT "table_name"
FROM "information_schema"."views"
FROM ":catalog"."information_schema"."views"
WHERE "table_schema" = :schema
''').strip()
res = connection.execute(sql.text(query), schema=schema)
res = connection.execute(sql.text(query), catalog=catalog_name, schema=schema_name)
return [row.table_name for row in res]

def get_temp_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
"""Trino has no support for temporary views. Returns an empty list."""
return []

def get_view_definition(self, connection: Connection, view_name: str, schema: str = None, **kw) -> str:
schema = schema or self._get_default_schema_name(connection)
if schema is None:
catalog_name, schema_name = self._infer_catalog_schema(schema, connection)
if schema_name is None:
raise exc.NoSuchTableError('schema is required')
query = dedent('''
SELECT "view_definition"
FROM "information_schema"."views"
FROM ":catalog"."information_schema"."views"
WHERE "table_schema" = :schema
AND "table_name" = :view
''').strip()
res = connection.execute(sql.text(query), schema=schema, view=view_name)
res = connection.execute(sql.text(query), catalog=catalog_name, schema=schema_name, view=view_name)
return res.scalar()

def get_indexes(self, connection: Connection,
Expand Down Expand Up @@ -211,7 +212,7 @@ def get_check_constraints(self, connection: Connection,

def get_table_comment(self, connection: Connection,
table_name: str, schema: str = None, **kw) -> Dict[str, Any]:
properties_table = self._get_full_table(f'{table_name}$properties', schema)
properties_table = self._get_full_table(connection, f'{table_name}$properties', schema)
query = f'SELECT "comment" FROM {properties_table}'
try:
res = connection.execute(sql.text(query))
Expand All @@ -227,26 +228,27 @@ def get_table_comment(self, connection: Connection,
raise

def has_schema(self, connection: Connection, schema: str) -> bool:
catalog_name, schema_name = self._infer_catalog_schema(schema, connection)
query = dedent('''
SELECT "schema_name"
FROM "information_schema"."schemata"
FROM ":catalog"."information_schema"."schemata"
WHERE "schema_name" = :schema
''').strip()
res = connection.execute(sql.text(query), schema=schema)
res = connection.execute(sql.text(query), catalog=catalog_name, schema=schema_name)
return res.first() is not None

def has_table(self, connection: Connection,
table_name: str, schema: str = None, **kw) -> bool:
schema = schema or self._get_default_schema_name(connection)
if schema is None:
catalog_name, schema_name = self._infer_catalog_schema(schema, connection)
if schema_name is None:
return False
query = dedent('''
SELECT "table_name"
FROM "information_schema"."tables"
FROM ":catalog"."information_schema"."tables"
WHERE "table_schema" = :schema
AND "table_name" = :table
''').strip()
res = connection.execute(sql.text(query), schema=schema, table=table_name)
res = connection.execute(sql.text(query), catalog=catalog_name, schema=schema_name, table=table_name)
return res.first() is not None

def has_sequence(self, connection: Connection,
Expand All @@ -265,6 +267,22 @@ def _get_default_schema_name(self, connection: Connection) -> Optional[str]:
dbapi_connection: trino_dbapi.Connection = connection.connection
return dbapi_connection.schema

def _infer_catalog_schema(self, schema: str, connection: Connection) -> str:
dbapi_connection: trino_dbapi.Connection = connection.connection
if schema is not None:
t = schema.split('.')
if len(t) == 2:
return (t[0], t[1])
elif len(t) == 1:
catalog = dbapi_connection.catalog or 'system'
return (catalog, schema)
else:
raise ValueError(f'Bad schema string: "{schema}"')
else:
catalog = dbapi_connection.catalog or 'system'
schema = dbapi_connection.schema
return (catalog, schema)

def do_execute(self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...],
context: DefaultExecutionContext = None):
cursor.execute(statement, parameters)
Expand All @@ -289,10 +307,12 @@ def get_isolation_level(self, dbapi_conn: trino_dbapi.Connection) -> str:
'SERIALIZABLE']
return level_names[dbapi_conn.isolation_level]

def _get_full_table(self, table_name: str, schema: str = None, quote: bool = True) -> str:
def _get_full_table(self, connection: Connection, table_name: str, schema: str = None, quote: bool = True) -> str:
catalog_name, schema_name = self._infer_catalog_schema(schema, connection)
table_part = self.identifier_preparer.quote_identifier(table_name) if quote else table_name
if schema:
schema_part = self.identifier_preparer.quote_identifier(schema) if quote else schema
return f'{schema_part}.{table_part}'
catalog_part = self.identifier_preparer.quote_identifier(catalog_name) if quote else catalog_name
if schema_name:
schema_part = self.identifier_preparer.quote_identifier(schema_name) if quote else schema_name
return f'{catalog_part}.{schema_part}.{table_part}'

return table_part