Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: new overwrite_files S3 write mode for atomic overwrites #3074

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion awswrangler/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ def block_waiting_available_thread(seq: Sequence[Future], max_workers: int) -> N

def check_schema_changes(columns_types: dict[str, str], table_input: dict[str, Any] | None, mode: str) -> None:
"""Check schema changes."""
if (table_input is not None) and (mode in ("append", "overwrite_partitions")):
if (table_input is not None) and (mode in ("append", "overwrite_partitions", "overwrite_files")):
catalog_cols: dict[str, str] = {x["Name"]: x["Type"] for x in table_input["StorageDescriptor"]["Columns"]}
for c, t in columns_types.items():
if c not in catalog_cols:
Expand Down
19 changes: 10 additions & 9 deletions awswrangler/catalog/_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _update_if_necessary(
if value is not None:
if key not in dic or dic[key] != value:
dic[key] = value
if mode in ("append", "overwrite_partitions"):
if mode in ("append", "overwrite_partitions", "overwrite_files"):
return "update"
return mode

Expand Down Expand Up @@ -150,9 +150,10 @@ def _create_table( # noqa: PLR0912,PLR0915

client_glue = _utils.client(service_name="glue", session=boto3_session)
skip_archive: bool = not catalog_versioning
if mode not in ("overwrite", "append", "overwrite_partitions", "update"):
if mode not in ("overwrite", "append", "overwrite_partitions", "overwrite_files", "update"):
raise exceptions.InvalidArgument(
f"{mode} is not a valid mode. It must be 'overwrite', 'append' or 'overwrite_partitions'."
f"{mode} is not a valid mode. It must be 'overwrite', "
f"'append', 'overwrite_partitions' or 'overwrite_files'."
)
args: dict[str, Any] = _catalog_id(
catalog_id=catalog_id,
Expand Down Expand Up @@ -304,7 +305,7 @@ def _create_parquet_table(
_logger.debug("catalog_table_input: %s", catalog_table_input)

table_input: dict[str, Any]
if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions")):
if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions", "overwrite_files")):
table_input = catalog_table_input

is_table_updated = _update_table_input(table_input, columns_types)
Expand Down Expand Up @@ -366,7 +367,7 @@ def _create_orc_table(
_logger.debug("catalog_table_input: %s", catalog_table_input)

table_input: dict[str, Any]
if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions")):
if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions", "overwrite_files")):
table_input = catalog_table_input

is_table_updated = _update_table_input(table_input, columns_types)
Expand Down Expand Up @@ -436,7 +437,7 @@ def _create_csv_table(
_utils.check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode)

table_input: dict[str, Any]
if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions")):
if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions", "overwrite_files")):
table_input = catalog_table_input

is_table_updated = _update_table_input(table_input, columns_types, allow_reorder=False)
Expand Down Expand Up @@ -508,7 +509,7 @@ def _create_json_table(
table_input: dict[str, Any]
if schema_evolution is False:
_utils.check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode)
if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions")):
if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions", "overwrite_files")):
table_input = catalog_table_input

is_table_updated = _update_table_input(table_input, columns_types)
Expand Down Expand Up @@ -1098,7 +1099,7 @@ def create_csv_table(
If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
schema_evolution
If True allows schema evolution (new or missing columns), otherwise an exception will be raised.
(Only considered if dataset=True and mode in ("append", "overwrite_partitions"))
(Only considered if dataset=True and mode in ("append", "overwrite_partitions", "overwrite_files"))
Related tutorial:
https://aws-sdk-pandas.readthedocs.io/en/3.11.0/tutorials/014%20-%20Schema%20Evolution.html
sep
Expand Down Expand Up @@ -1278,7 +1279,7 @@ def create_json_table(
If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
schema_evolution
If True allows schema evolution (new or missing columns), otherwise an exception will be raised.
(Only considered if dataset=True and mode in ("append", "overwrite_partitions"))
(Only considered if dataset=True and mode in ("append", "overwrite_partitions", "overwrite_files"))
Related tutorial:
https://aws-sdk-pandas.readthedocs.io/en/3.11.0/tutorials/014%20-%20Schema%20Evolution.html
serde_library
Expand Down
11 changes: 7 additions & 4 deletions awswrangler/s3/_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import logging
import uuid
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, NamedTuple
Expand Down Expand Up @@ -48,7 +47,7 @@ def _extract_dtypes_from_table_input(table_input: dict[str, Any]) -> dict[str, s
def _apply_dtype(
df: pd.DataFrame, dtype: dict[str, str], catalog_table_input: dict[str, Any] | None, mode: str
) -> pd.DataFrame:
if mode in ("append", "overwrite_partitions"):
if mode in ("append", "overwrite_partitions", "overwrite_files"):
if catalog_table_input is not None:
catalog_types: dict[str, str] | None = _extract_dtypes_from_table_input(table_input=catalog_table_input)
if catalog_types is not None:
Expand All @@ -72,6 +71,7 @@ def _validate_args(
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
execution_engine: Enum,
max_rows_by_file: int | None = None,
) -> None:
if df.empty is True:
_logger.warning("Empty DataFrame will be written.")
Expand Down Expand Up @@ -107,6 +107,11 @@ def _validate_args(
raise exceptions.InvalidArgumentValue(
"Please pass a value greater than 1 for the number of buckets for bucketing."
)
elif mode == "overwrite_files" and (max_rows_by_file or bucketing_info):
raise exceptions.InvalidArgumentValue(
"When mode is set to 'overwrite_files', the "
"`max_rows_by_file` and `bucketing_info` arguments cannot be set."
)


class _SanitizeResult(NamedTuple):
Expand Down Expand Up @@ -279,7 +284,6 @@ def write( # noqa: PLR0913
partitions_values: dict[str, list[str]] = {}
mode = "append" if mode is None else mode

filename_prefix = filename_prefix + uuid.uuid4().hex if filename_prefix else uuid.uuid4().hex
cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
s3_client = _utils.client(service_name="s3", session=boto3_session)

Expand Down Expand Up @@ -328,7 +332,6 @@ def write( # noqa: PLR0913
paths = self._write_to_s3(
df,
path=path,
filename_prefix=filename_prefix,
schema=schema,
index=index,
cpus=cpus,
Expand Down
24 changes: 23 additions & 1 deletion awswrangler/s3/_write_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
import uuid
from typing import Any, Callable

import boto3
Expand All @@ -18,6 +19,26 @@
_logger: logging.Logger = logging.getLogger(__name__)


def _load_mode_and_filename_prefix(*, mode: str | None, filename_prefix: str | None = None) -> tuple[str, str]:
if mode is None:
mode = "append"

if mode == "overwrite_files":
# In `overwrite_files` mode, we need to create deterministic
# filenames to ensure that the same files are always overwritten:
if filename_prefix is None:
filename_prefix = "data"
random_filename_suffix = ""
mode = "append"
else:
random_filename_suffix = uuid.uuid4().hex

if filename_prefix is None:
filename_prefix = ""
filename_prefix = filename_prefix + random_filename_suffix
return mode, filename_prefix


def _get_bucketing_series(df: pd.DataFrame, bucketing_info: typing.BucketingInfoTuple) -> pd.Series:
bucket_number_series = (
df[bucketing_info[0]]
Expand Down Expand Up @@ -201,7 +222,7 @@ def _to_dataset(
concurrent_partitioning: bool,
df: pd.DataFrame,
path_root: str,
filename_prefix: str,
filename_prefix: str | None,
index: bool,
use_threads: bool | int,
mode: str,
Expand All @@ -212,6 +233,7 @@ def _to_dataset(
) -> tuple[list[str], dict[str, list[str]]]:
path_root = path_root if path_root.endswith("/") else f"{path_root}/"
# Evaluate mode
mode, filename_prefix = _load_mode_and_filename_prefix(mode=mode, filename_prefix=filename_prefix)
if mode not in ["append", "overwrite", "overwrite_partitions"]:
raise exceptions.InvalidArgumentValue(
f"{mode} is a invalid mode, please use append, overwrite or overwrite_partitions."
Expand Down
5 changes: 3 additions & 2 deletions awswrangler/s3/_write_orc.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def to_orc(
partition_cols: list[str] | None = None,
bucketing_info: BucketingInfoTuple | None = None,
concurrent_partitioning: bool = False,
mode: Literal["append", "overwrite", "overwrite_partitions"] | None = None,
mode: Literal["append", "overwrite", "overwrite_partitions", "overwrite_files"] | None = None,
catalog_versioning: bool = False,
schema_evolution: bool = True,
database: str | None = None,
Expand Down Expand Up @@ -414,7 +414,7 @@ def to_orc(
If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
schema_evolution
If True allows schema evolution (new or missing columns), otherwise an exception will be raised. True by default.
(Only considered if dataset=True and mode in ("append", "overwrite_partitions"))
(Only considered if dataset=True and mode in ("append", "overwrite_partitions", "overwrite_files"))
Related tutorial:
https://aws-sdk-pandas.readthedocs.io/en/3.11.0/tutorials/014%20-%20Schema%20Evolution.html
database
Expand Down Expand Up @@ -646,6 +646,7 @@ def to_orc(
columns_comments=columns_comments,
columns_parameters=columns_parameters,
execution_engine=engine.get(),
max_rows_by_file=max_rows_by_file,
)

# Evaluating compression
Expand Down
5 changes: 3 additions & 2 deletions awswrangler/s3/_write_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def to_parquet(
partition_cols: list[str] | None = None,
bucketing_info: BucketingInfoTuple | None = None,
concurrent_partitioning: bool = False,
mode: Literal["append", "overwrite", "overwrite_partitions"] | None = None,
mode: Literal["append", "overwrite", "overwrite_partitions", "overwrite_files"] | None = None,
catalog_versioning: bool = False,
schema_evolution: bool = True,
database: str | None = None,
Expand Down Expand Up @@ -444,7 +444,7 @@ def to_parquet(
If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
schema_evolution
If True allows schema evolution (new or missing columns), otherwise an exception will be raised. True by default.
(Only considered if dataset=True and mode in ("append", "overwrite_partitions"))
(Only considered if dataset=True and mode in ("append", "overwrite_partitions", "overwrite_files"))
Related tutorial:
https://aws-sdk-pandas.readthedocs.io/en/3.11.0/tutorials/014%20-%20Schema%20Evolution.html
database
Expand Down Expand Up @@ -704,6 +704,7 @@ def to_parquet(
columns_comments=columns_comments,
columns_parameters=columns_parameters,
execution_engine=engine.get(),
max_rows_by_file=max_rows_by_file,
)

# Evaluating compression
Expand Down
11 changes: 4 additions & 7 deletions awswrangler/s3/_write_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import csv
import logging
import uuid
from typing import TYPE_CHECKING, Any, Literal, cast

import boto3
Expand Down Expand Up @@ -98,7 +97,7 @@ def to_csv( # noqa: PLR0912,PLR0915
partition_cols: list[str] | None = None,
bucketing_info: BucketingInfoTuple | None = None,
concurrent_partitioning: bool = False,
mode: Literal["append", "overwrite", "overwrite_partitions"] | None = None,
mode: Literal["append", "overwrite", "overwrite_partitions", "overwrite_files"] | None = None,
catalog_versioning: bool = False,
schema_evolution: bool = False,
dtype: dict[str, str] | None = None,
Expand Down Expand Up @@ -180,7 +179,7 @@ def to_csv( # noqa: PLR0912,PLR0915
If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
schema_evolution
If True allows schema evolution (new or missing columns), otherwise an exception will be raised.
(Only considered if dataset=True and mode in ("append", "overwrite_partitions")). False by default.
(Only considered if dataset=True and mode in ("append", "overwrite_partitions", "overwrite_files")). False by default.
Related tutorial:
https://aws-sdk-pandas.readthedocs.io/en/3.11.0/tutorials/014%20-%20Schema%20Evolution.html
database
Expand Down Expand Up @@ -469,7 +468,6 @@ def to_csv( # noqa: PLR0912,PLR0915
partitions_values: dict[str, list[str]] = {}
mode = "append" if mode is None else mode

filename_prefix = filename_prefix + uuid.uuid4().hex if filename_prefix else uuid.uuid4().hex
s3_client = _utils.client(service_name="s3", session=boto3_session)

# Sanitize table to respect Athena's standards
Expand Down Expand Up @@ -661,7 +659,7 @@ def to_json( # noqa: PLR0912,PLR0915
partition_cols: list[str] | None = None,
bucketing_info: BucketingInfoTuple | None = None,
concurrent_partitioning: bool = False,
mode: Literal["append", "overwrite", "overwrite_partitions"] | None = None,
mode: Literal["append", "overwrite", "overwrite_partitions", "overwrite_files"] | None = None,
catalog_versioning: bool = False,
schema_evolution: bool = True,
dtype: dict[str, str] | None = None,
Expand Down Expand Up @@ -726,7 +724,7 @@ def to_json( # noqa: PLR0912,PLR0915
If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
schema_evolution
If True allows schema evolution (new or missing columns), otherwise an exception will be raised.
(Only considered if dataset=True and mode in ("append", "overwrite_partitions"))
(Only considered if dataset=True and mode in ("append", "overwrite_partitions", "overwrite_files"))
Related tutorial:
https://aws-sdk-pandas.readthedocs.io/en/3.11.0/tutorials/014%20-%20Schema%20Evolution.html
database
Expand Down Expand Up @@ -919,7 +917,6 @@ def to_json( # noqa: PLR0912,PLR0915
partitions_values: dict[str, list[str]] = {}
mode = "append" if mode is None else mode

filename_prefix = filename_prefix + uuid.uuid4().hex if filename_prefix else uuid.uuid4().hex
s3_client = _utils.client(service_name="s3", session=boto3_session)

# Sanitize table to respect Athena's standards
Expand Down
Loading