feat(ingest/athena): handle partition fetching errors #11966

Open · wants to merge 2 commits into master
Changes from all commits
68 changes: 46 additions & 22 deletions metadata-ingestion/src/datahub/ingestion/source/sql/athena.py
@@ -26,6 +26,7 @@
platform_name,
support_status,
)
from datahub.ingestion.api.source import StructuredLogLevel
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.s3_util import make_s3_urn
from datahub.ingestion.source.common.subtypes import DatasetContainerSubTypes
@@ -35,6 +36,7 @@
register_custom_type,
)
from datahub.ingestion.source.sql.sql_config import SQLCommonConfig, make_sqlalchemy_uri
from datahub.ingestion.source.sql.sql_report import SQLSourceReport
from datahub.ingestion.source.sql.sql_utils import (
add_table_to_schema_container,
gen_database_container,
@@ -48,6 +50,15 @@
get_schema_fields_for_sqlalchemy_column,
)

try:
from typing_extensions import override
except ImportError:
_F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any])

def override(f: _F, /) -> _F: # noqa: F811
return f
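
Note: the `try`/`except` shim above keeps the module importable when the installed `typing_extensions` does not provide `override`; the fallback decorator just returns the function unchanged. `@override` itself is purely a static-analysis marker. A minimal sketch of what a type checker enforces (hypothetical classes, not from this PR):

```python
from typing_extensions import override

class Base:
    def get_partitions(self) -> None: ...

class Child(Base):
    @override  # a type checker flags this if Base had no get_partitions
    def get_partitions(self) -> None: ...
```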


logger = logging.getLogger(__name__)

assert STRUCT, "required type modules are not available"
@@ -322,12 +333,15 @@ class AthenaSource(SQLAlchemySource):
- Profiling when enabled.
"""

- table_partition_cache: Dict[str, Dict[str, Partitionitem]] = {}
config: AthenaConfig
+ report: SQLSourceReport

def __init__(self, config, ctx):
super().__init__(config, ctx, "athena")
self.cursor: Optional[BaseCursor] = None

+ self.table_partition_cache: Dict[str, Dict[str, Partitionitem]] = {}
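
Aside: moving `table_partition_cache` from a class attribute into `__init__` also fixes a classic Python pitfall: a mutable class-level value is shared by every instance. A minimal illustration (hypothetical classes, not from the diff):

```python
class Shared:
    cache: dict = {}  # one dict object, shared by all instances

class PerInstance:
    def __init__(self) -> None:
        self.cache: dict = {}  # a fresh dict per instance

a, b = Shared(), Shared()
a.cache["k"] = 1
assert b.cache["k"] == 1  # state leaks across instances

c, d = PerInstance(), PerInstance()
c.cache["k"] = 1
assert "k" not in d.cache  # state stays isolated
```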

@classmethod
def create(cls, config_dict, ctx):
config = AthenaConfig.parse_obj(config_dict)
Expand Down Expand Up @@ -452,41 +466,50 @@ def add_table_to_schema_container(
)

# It seems the database/schema filter in the connection string does not work; this is a workaround for that
@override
def get_schema_names(self, inspector: Inspector) -> List[str]:
athena_config = typing.cast(AthenaConfig, self.config)
schemas = inspector.get_schema_names()
if athena_config.database:
return [schema for schema in schemas if schema == athena_config.database]
return schemas

@classmethod
def _casted_partition_key(cls, key: str) -> str:
    # We need to cast the partition keys to a VARCHAR, since otherwise
    # Athena may throw an error during concatenation / comparison.
    return f"CAST({key} as VARCHAR)"

# Overwrite to get partitions
+ @override
def get_partitions(
    self, inspector: Inspector, schema: str, table: str
- ) -> List[str]:
-     partitions = []
-     athena_config = typing.cast(AthenaConfig, self.config)
-     if not athena_config.extract_partitions:
-         return []
+ ) -> Optional[List[str]]:
+     if not self.config.extract_partitions:
+         return None

if not self.cursor:
-     return []
+     return None

metadata: AthenaTableMetadata = self.cursor.get_table_metadata(
table_name=table, schema_name=schema
)

- if metadata.partition_keys:
-     for key in metadata.partition_keys:
-         if key.name:
-             partitions.append(key.name)
-     if not partitions:
-         return []
+ partitions = []
+ for key in metadata.partition_keys:
+     if key.name:
+         partitions.append(key.name)
+ if not partitions:
+     return []

- # We create an artiificaial concatenated partition key to be able to query max partition easier
- part_concat = "|| '-' ||".join(partitions)
+ with self.report.report_exc(
+     message="Failed to extract partition details",
+     context=f"{schema}.{table}",
+     level=StructuredLogLevel.WARN,
+ ):
+     # We create an artificial concatenated partition key to be able to query the max partition more easily
+     part_concat = " || '-' || ".join(
+         self._casted_partition_key(key) for key in partitions
+     )
max_partition_query = f'select {",".join(partitions)} from "{schema}"."{table}$partitions" where {part_concat} = (select max({part_concat}) from "{schema}"."{table}$partitions")'
ret = self.cursor.execute(max_partition_query)
max_partition: Dict[str, str] = {}
@@ -500,9 +523,8 @@ def get_partitions(
partitions=partitions,
max_partition=max_partition,
)
-     return partitions
-     return []
+ return partitions
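
Because the `$partitions` query now runs inside `report_exc`, a failure is recorded as a structured WARN-level report entry instead of aborting ingestion, and (assuming `report_exc` suppresses the exception, as its use here implies) `get_partitions` still returns the partition key names, just without a cached max partition. For a table partitioned by `year` and `month`, the generated query reads (as asserted in the unit test below):

```sql
select year,month from "test_schema"."test_table$partitions"
where CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR) =
      (select max(CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR))
       from "test_schema"."test_table$partitions")
```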

# Overwrite to modify the creation of schema fields
def get_schema_fields_for_column(
@@ -551,7 +573,9 @@ def generate_partition_profiler_query(
if partition and partition.max_partition:
max_partition_filters = []
for key, value in partition.max_partition.items():
max_partition_filters.append(f"CAST({key} as VARCHAR) = '{value}'")
max_partition_filters.append(
f"{self._casted_partition_key(key)} = '{value}'"
)
max_partition = str(partition.max_partition)
return (
max_partition,
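
Assuming a cached max partition of `{"year": "2023", "month": "12"}`, the loop above yields filters like `CAST(year as VARCHAR) = '2023'` and `CAST(month as VARCHAR) = '12'`, which scope the profiler query to the newest partition (the code that combines the filters lies outside the visible hunk).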
45 changes: 42 additions & 3 deletions metadata-ingestion/tests/unit/test_athena_source.py
@@ -93,7 +93,8 @@ def test_athena_get_table_properties():
"CreateTime": datetime.now(),
"LastAccessTime": datetime.now(),
"PartitionKeys": [
{"Name": "testKey", "Type": "string", "Comment": "testComment"}
{"Name": "year", "Type": "string", "Comment": "testComment"},
{"Name": "month", "Type": "string", "Comment": "testComment"},
],
"Parameters": {
"comment": "testComment",
@@ -112,8 +113,18 @@
response=table_metadata
)

# Mock partition query results
mock_cursor.execute.return_value.description = [
["year"],
["month"],
]
mock_cursor.execute.return_value.__iter__.return_value = [["2023", "12"]]

ctx = PipelineContext(run_id="test")
source = AthenaSource(config=config, ctx=ctx)
source.cursor = mock_cursor

# Test table properties
description, custom_properties, location = source.get_table_properties(
inspector=mock_inspector, table=table, schema=schema
)
@@ -124,13 +135,35 @@
"last_access_time": "2020-04-14 07:00:00",
"location": "s3://testLocation",
"outputformat": "testOutputFormat",
"partition_keys": '[{"name": "testKey", "type": "string", "comment": "testComment"}]',
"partition_keys": '[{"name": "year", "type": "string", "comment": "testComment"}, {"name": "month", "type": "string", "comment": "testComment"}]',
"serde.serialization.lib": "testSerde",
"table_type": "testType",
}

assert location == make_s3_urn("s3://testLocation", "PROD")

# Test partition functionality
partitions = source.get_partitions(
inspector=mock_inspector, schema=schema, table=table
)
assert partitions == ["year", "month"]

# Verify the correct SQL query was generated for partitions
expected_query = """\
select year,month from "test_schema"."test_table$partitions" \
where CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR) = \
(select max(CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR)) \
from "test_schema"."test_table$partitions")"""
mock_cursor.execute.assert_called_once()
actual_query = mock_cursor.execute.call_args[0][0]
assert actual_query == expected_query

# Verify partition cache was populated correctly
assert source.table_partition_cache[schema][table].partitions == partitions
assert source.table_partition_cache[schema][table].max_partition == {
"year": "2023",
"month": "12",
}
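
The final assertions check the side effect that matters downstream: `get_partitions` populates `table_partition_cache`, and `generate_partition_profiler_query` later reads `max_partition` from that cache to build the profiling filter.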


def test_get_column_type_simple_types():
assert isinstance(
@@ -214,3 +247,9 @@ def test_column_type_complex_combination():
assert isinstance(
result._STRUCT_fields[2][1].item_type._STRUCT_fields[1][1], types.String
)


def test_casted_partition_key():
from datahub.ingestion.source.sql.athena import AthenaSource

assert AthenaSource._casted_partition_key("test_col") == "CAST(test_col as VARCHAR)"
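
These tests should be runnable in isolation with pytest; an illustrative invocation (exact flags may vary with the repo's tooling):

```
pytest metadata-ingestion/tests/unit/test_athena_source.py -k "table_properties or casted_partition_key"
```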