diff --git a/README.md b/README.md index fe38204..2560c52 100644 --- a/README.md +++ b/README.md @@ -74,12 +74,10 @@ We are unable to support the functions that work with geometries. | Function | Description | Supported| | --: | --- | ---| | `latlng_to_cell` | Convert latitude/longitude coordinate to cell ID | ✅| -| `latlng_to_cell_string` | Convert latitude/longitude coordinate to cell ID (returns VARCHAR) | ✅ | | `cell_to_lat` | Convert cell ID to latitude | ✅ | | `cell_to_lng` | Convert cell ID to longitude | ✅ | | `cell_to_latlng` | Convert cell ID to latitude/longitude | ✅ | | `get_resolution` | Get resolution number of cell ID | ✅ | -| `get_base_cell_number` | Get base cell number of cell ID | 🚧| | `str_to_int` | Convert VARCHAR cell ID to UBIGINT | ✅ | | `int_to_str` | Convert BIGINT or UBIGINT cell ID to VARCHAR | ✅ | | `is_valid_cell` | True if this is a valid cell ID | ✅ | @@ -110,7 +108,6 @@ We are unable to support the functions that work with geometries. | `get_directed_edge_destination` | Convert a directed edge ID to destination cell ID | ✅| | `cells_to_directed_edge` | Convert an origin/destination pair to directed edge ID | ✅ | | `are_neighbor_cells` | True if the two cell IDs are directly adjacent | ✅ | -| `directed_edge_to_boundary_wkt` | Convert directed edge ID to linestring WKT | ✅ | | `average_hexagon_area` | Get average area of a hexagon cell at resolution | ✅ | | `cell_area` | Get the area of a cell ID | ✅| | `average_hexagon_edge_length` | Average hexagon edge length at resolution | ✅| @@ -122,3 +119,4 @@ We are unable to support the functions that work with geometries. | `cells_to_multi_polygon_wkt` | Convert a set of cells to multipolygon WKT | 🛑 | | `polygon_wkt_to_cells` | Convert polygon WKT to a set of cells | 🛑 | | `cell_to_boundary_wkt` | Convert cell ID to cell boundary | 🛑 | +| `directed_edge_to_boundary_wkt` | Convert directed edge ID to linestring WKT | 🛑 | diff --git a/polars_h3/__init__.py b/polars_h3/__init__.py index fb3127e..4712b24 100644 --- a/polars_h3/__init__.py +++ b/polars_h3/__init__.py @@ -1,60 +1,58 @@ -from .core.traversal import ( - grid_distance, - grid_ring, - grid_disk, - grid_path_cells, +from .core.edge import ( + are_neighbor_cells, + cells_to_directed_edge, + directed_edge_to_boundary, + directed_edge_to_cells, + get_directed_edge_destination, + get_directed_edge_origin, + is_valid_directed_edge, + origin_to_directed_edges, ) from .core.indexing import ( - latlng_to_cell, - latlng_to_cell_string, cell_to_lat, - cell_to_lng, cell_to_latlng, + cell_to_lng, cell_to_local_ij, + latlng_to_cell, local_ij_to_cell, ) from .core.inspection import ( - get_resolution, - str_to_int, - int_to_str, - is_valid_cell, - is_pentagon, - is_res_class_III, - get_icosahedron_faces, - cell_to_parent, cell_to_center_child, - cell_to_children_size, - cell_to_children, cell_to_child_pos, + cell_to_children, + cell_to_children_size, + cell_to_parent, child_pos_to_cell, compact_cells, + get_icosahedron_faces, + get_resolution, + int_to_str, + is_pentagon, + is_res_class_III, + is_valid_cell, + str_to_int, uncompact_cells, ) -from .core.vertexes import ( - cell_to_vertex, - cell_to_vertexes, - vertex_to_latlng, - is_valid_vertex, -) -from .core.edge import ( - are_neighbor_cells, - cells_to_directed_edge, - is_valid_directed_edge, - get_directed_edge_origin, - get_directed_edge_destination, - directed_edge_to_cells, - origin_to_directed_edges, - directed_edge_to_boundary, -) from .core.metrics import ( - great_circle_distance, average_hexagon_area, + average_hexagon_edge_length, cell_area, edge_length, - average_hexagon_edge_length, get_num_cells, + great_circle_distance, +) +from .core.traversal import ( + grid_disk, + grid_distance, + grid_path_cells, + grid_ring, +) +from .core.vertexes import ( + cell_to_vertex, + cell_to_vertexes, + is_valid_vertex, + vertex_to_latlng, ) - __all__ = [ "grid_distance", @@ -62,7 +60,6 @@ "grid_disk", "grid_path_cells", "latlng_to_cell", - "latlng_to_cell_string", "cell_to_lat", "cell_to_lng", "cell_to_latlng", diff --git a/polars_h3/core/_types.py b/polars_h3/core/_types.py new file mode 100644 index 0000000..9af21b7 --- /dev/null +++ b/polars_h3/core/_types.py @@ -0,0 +1,40 @@ +from typing import Literal, TypeVar, Union + +from typing_extensions import override + +HexResolution = Union[ + Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], int +] + +_T = TypeVar("_T") + + +# Sentinel class used until PEP 0661 is accepted +class NotGiven: + """ + A sentinel singleton class used to distinguish omitted keyword arguments + from those passed in with the value None (which may have different behavior). + + For example: + + ```py + def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: + ... + + + get(timeout=1) # 1s timeout + get(timeout=None) # No timeout + get() # Default timeout behavior, which may not be statically known at the method definition. + ``` + """ + + def __bool__(self) -> Literal[False]: + return False + + @override + def __repr__(self) -> str: + return "NOT_GIVEN" + + +NotGivenOr = Union[_T, NotGiven] +NOT_GIVEN = NotGiven() diff --git a/polars_h3/core/edge.py b/polars_h3/core/edge.py index bf5cd58..fa9dc75 100644 --- a/polars_h3/core/edge.py +++ b/polars_h3/core/edge.py @@ -1,12 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING from pathlib import Path +from typing import TYPE_CHECKING import polars as pl from polars.plugins import register_plugin_function - if TYPE_CHECKING: from polars_h3.typing import IntoExprColumn diff --git a/polars_h3/core/indexing.py b/polars_h3/core/indexing.py index dcd1517..a9f9bde 100644 --- a/polars_h3/core/indexing.py +++ b/polars_h3/core/indexing.py @@ -15,20 +15,37 @@ LIB = Path(__file__).parent.parent -def latlng_to_cell_string( - lat: IntoExprColumn, lng: IntoExprColumn, resolution: HexResolution +def latlng_to_cell( + lat: IntoExprColumn, + lng: IntoExprColumn, + resolution: HexResolution, + return_dtype: type[pl.DataType] = pl.UInt64, ) -> pl.Expr: """ - Indexes the location at the specified resolution, providing the index of the cell containing the location. This buckets the geographic point into the H3 grid. + Indexes the location at the specified resolution, providing the index of the cell containing the location. The result dtype is determined by `return_dtype`. """ assert_valid_resolution(resolution) - return register_plugin_function( - args=[lat, lng], - plugin_path=LIB, - function_name="latlng_to_cell_string", - is_elementwise=True, - kwargs={"resolution": resolution}, - ) + + if return_dtype == pl.Utf8: + expr = register_plugin_function( + args=[lat, lng], + plugin_path=LIB, + function_name="latlng_to_cell_string", + is_elementwise=True, + kwargs={"resolution": resolution}, + ) + else: + expr = register_plugin_function( + args=[lat, lng], + plugin_path=LIB, + function_name="latlng_to_cell", + is_elementwise=True, + kwargs={"resolution": resolution}, + ) + if return_dtype != pl.UInt64: + expr = expr.cast(return_dtype) + + return expr def cell_to_lat(cell: IntoExprColumn) -> pl.Expr: diff --git a/polars_h3/core/inspection.py b/polars_h3/core/inspection.py index 347de51..0b5abf0 100644 --- a/polars_h3/core/inspection.py +++ b/polars_h3/core/inspection.py @@ -1,13 +1,12 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Union -from pathlib import Path +from typing import TYPE_CHECKING import polars as pl from polars.plugins import register_plugin_function -from .utils import _assert_valid_resolution, HexResolution +from .utils import HexResolution, assert_valid_resolution if TYPE_CHECKING: from polars_h3.typing import IntoExprColumn @@ -101,7 +100,8 @@ def get_icosahedron_faces(expr: IntoExprColumn) -> pl.Expr: def cell_to_parent( - cell: IntoExprColumn, resolution: Union[HexResolution, None] = None + cell: IntoExprColumn, + resolution: HexResolution, ) -> pl.Expr: """ Returns the parent cell at the specified resolution. If the input cell has resolution r, then parentRes = r - 1 would give the immediate parent, parentRes = r - 2 would give the grandparent, and so on. diff --git a/polars_h3/core/traversal.py b/polars_h3/core/traversal.py index 8543558..ede4fd4 100644 --- a/polars_h3/core/traversal.py +++ b/polars_h3/core/traversal.py @@ -1,12 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING from pathlib import Path +from typing import TYPE_CHECKING import polars as pl from polars.plugins import register_plugin_function - if TYPE_CHECKING: from polars_h3.typing import IntoExprColumn diff --git a/polars_h3/core/utils.py b/polars_h3/core/utils.py index f64e8d1..4e5e3a4 100644 --- a/polars_h3/core/utils.py +++ b/polars_h3/core/utils.py @@ -1,6 +1,4 @@ -from typing import Literal - -HexResolution = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] +from ._types import HexResolution def assert_valid_resolution(resolution: HexResolution) -> None: diff --git a/polars_h3/core/vertexes.py b/polars_h3/core/vertexes.py index bb654c5..f4b2c97 100644 --- a/polars_h3/core/vertexes.py +++ b/polars_h3/core/vertexes.py @@ -1,12 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING from pathlib import Path +from typing import TYPE_CHECKING import polars as pl from polars.plugins import register_plugin_function - if TYPE_CHECKING: from polars_h3.typing import IntoExprColumn diff --git a/src/engine/hierarchy.rs b/src/engine/hierarchy.rs index 146ace4..02e156c 100644 --- a/src/engine/hierarchy.rs +++ b/src/engine/hierarchy.rs @@ -1,8 +1,11 @@ -use super::utils::parse_cell_indices; use h3o::{CellIndex, Resolution}; use polars::prelude::*; use rayon::prelude::*; +use super::utils::{ + cast_list_u64_to_dtype, cast_u64_to_dtype, parse_cell_indices, resolve_target_inner_dtype, +}; + fn get_target_resolution(cell: CellIndex, target_res: Option) -> Option { match target_res { Some(res) => Resolution::try_from(res).ok(), @@ -15,16 +18,16 @@ fn get_target_resolution(cell: CellIndex, target_res: Option) -> Option) -> PolarsResult { + let original_dtype = cell_series.dtype().clone(); let cells = parse_cell_indices(cell_series)?; let parents: UInt64Chunked = cells .into_par_iter() .map(|cell| { cell.and_then(|idx| { - let target_res = if let Some(res) = parent_res { - Resolution::try_from(res).ok() - } else { - idx.resolution().pred() + let target_res = match parent_res { + Some(res) => Resolution::try_from(res).ok(), + None => idx.resolution().pred(), }; target_res.and_then(|res| idx.parent(res)) }) @@ -32,10 +35,11 @@ pub fn cell_to_parent(cell_series: &Series, parent_res: Option) -> PolarsRes }) .collect(); - Ok(parents.into_series()) + cast_u64_to_dtype(&original_dtype, None, parents) } pub fn cell_to_center_child(cell_series: &Series, child_res: Option) -> PolarsResult { + let original_dtype = cell_series.dtype().clone(); let cells = parse_cell_indices(cell_series)?; let center_children: UInt64Chunked = cells @@ -49,7 +53,23 @@ pub fn cell_to_center_child(cell_series: &Series, child_res: Option) -> Pola }) .collect(); - Ok(center_children.into_series()) + let target_dtype = match original_dtype { + DataType::UInt64 => DataType::UInt64, + DataType::Int64 => DataType::Int64, + DataType::String => DataType::String, + _ => { + return Err(PolarsError::ComputeError( + format!( + "Unsupported original dtype for cell_to_center_child: {:?}", + original_dtype + ) + .into(), + )) + }, + }; + + // Cast the UInt64Chunked result to the correct dtype + cast_u64_to_dtype(&original_dtype, Some(&target_dtype), center_children) } pub fn cell_to_children_size(cell_series: &Series, child_res: Option) -> PolarsResult { @@ -70,6 +90,7 @@ pub fn cell_to_children_size(cell_series: &Series, child_res: Option) -> Pol } pub fn cell_to_children(cell_series: &Series, child_res: Option) -> PolarsResult { + let original_dtype = cell_series.dtype().clone(); let cells = parse_cell_indices(cell_series)?; let children: ListChunked = cells @@ -84,7 +105,12 @@ pub fn cell_to_children(cell_series: &Series, child_res: Option) -> PolarsRe }) .collect(); - Ok(children.into_series()) + let children_series = children.into_series(); + + let target_dtype = resolve_target_inner_dtype(&original_dtype)?; + let casted_children = + cast_list_u64_to_dtype(&children_series, &DataType::UInt64, Some(&target_dtype))?; + Ok(casted_children) } pub fn cell_to_child_pos(child_series: &Series, parent_res: u8) -> PolarsResult { @@ -102,15 +128,16 @@ pub fn cell_to_child_pos(child_series: &Series, parent_res: u8) -> PolarsResult< Ok(positions.into_series()) } + pub fn child_pos_to_cell( parent_series: &Series, child_res: u8, pos_series: &Series, ) -> PolarsResult { + let original_dtype = parent_series.dtype().clone(); let parents = parse_cell_indices(parent_series)?; let positions = pos_series.u64()?; - // Convert positions to Vec to ensure we can do parallel iteration let pos_vec: Vec> = positions.into_iter().collect(); let children: UInt64Chunked = parents @@ -125,10 +152,17 @@ pub fn child_pos_to_cell( }) .collect(); - Ok(children.into_series()) + let target_dtype = resolve_target_inner_dtype(&original_dtype)?; + + cast_u64_to_dtype(&original_dtype, Some(&target_dtype), children) } + pub fn compact_cells(cell_series: &Series) -> PolarsResult { - if let DataType::List(_) = cell_series.dtype() { + let original_dtype = cell_series.dtype().clone(); + + // Perform the compaction logic + let out_series = if let DataType::List(_) = cell_series.dtype() { + // Input is already a List column let ca = cell_series.list()?; let cells_vec: Vec<_> = ca.into_iter().collect(); @@ -145,40 +179,58 @@ pub fn compact_cells(cell_series: &Series) -> PolarsResult { PolarsError::ComputeError(format!("Compaction error: {}", e).into()) }) .map(|compacted| { - Series::new( - PlSmallStr::from(""), - compacted.into_iter().map(u64::from).collect::>(), - ) + // Note: `compacted` is a Vec. + // Convert to `u64` and store as a Series of UInt64. + let compacted_u64: Vec = + compacted.into_iter().map(u64::from).collect(); + Series::new(PlSmallStr::from(""), compacted_u64.as_slice()) }) }) .transpose() }) .collect::>()?; - Ok(compacted.into_series()) + compacted.into_series() } else { + // Input is not a list, so we treat it as a single column of cells. let cells = parse_cell_indices(cell_series)?; let cell_vec: Vec<_> = cells.into_iter().flatten().collect(); let compacted = CellIndex::compact(cell_vec) .map_err(|e| PolarsError::ComputeError(format!("Compaction error: {}", e).into()))?; + // Wrap in a single List + let compacted_u64: Vec = compacted.into_iter().map(u64::from).collect(); let compacted_cells: ListChunked = vec![Some(Series::new( PlSmallStr::from(""), - compacted.into_iter().map(u64::from).collect::>(), + compacted_u64.as_slice(), ))] .into_iter() .collect(); - Ok(compacted_cells.into_series()) - } + compacted_cells.into_series() + }; + + // Determine the target inner dtype based on the original column + // If the original was a List, extract its inner type. Otherwise, use the original directly. + let inner_original_dtype = match &original_dtype { + DataType::List(inner) => *inner.clone(), + dt => dt.clone(), + }; + + let target_inner_dtype = resolve_target_inner_dtype(&inner_original_dtype)?; + + cast_list_u64_to_dtype(&out_series, &DataType::UInt64, Some(&target_inner_dtype)) } pub fn uncompact_cells(cell_series: &Series, res: u8) -> PolarsResult { + let original_dtype = cell_series.dtype().clone(); let target_res = Resolution::try_from(res) .map_err(|_| PolarsError::ComputeError("Invalid resolution".into()))?; - if let DataType::List(_) = cell_series.dtype() { + // Perform the uncompact logic + let out_series = if let DataType::List(_) = cell_series.dtype() { + // Input is already a List column let ca = cell_series.list()?; let cells_vec: Vec<_> = ca.into_iter().collect(); @@ -191,28 +243,46 @@ pub fn uncompact_cells(cell_series: &Series, res: u8) -> PolarsResult { let cell_vec: Vec<_> = cells.into_iter().flatten().collect(); let uncompacted = CellIndex::uncompact(cell_vec, target_res); + // Convert the CellIndex result to a UInt64 Series + let uncompacted_u64: Vec = + uncompacted.into_iter().map(u64::from).collect(); Ok(Series::new( PlSmallStr::from(""), - uncompacted.into_iter().map(u64::from).collect::>(), + uncompacted_u64.as_slice(), )) }) .transpose() }) .collect::>()?; - Ok(uncompacted.into_series()) + uncompacted.into_series() } else { + // Input is not a list, treat it as a single column of cells. let cells = parse_cell_indices(cell_series)?; let cell_vec: Vec<_> = cells.into_iter().flatten().collect(); - let uncompacted: ListChunked = vec![Some(Series::new( + + let uncompacted = CellIndex::uncompact(cell_vec, target_res); + let uncompacted_u64: Vec = uncompacted.into_iter().map(u64::from).collect(); + + // Wrap in a single List + let uncompacted_cells: ListChunked = vec![Some(Series::new( PlSmallStr::from(""), - CellIndex::uncompact(cell_vec, target_res) - .map(u64::from) - .collect::>(), + uncompacted_u64.as_slice(), ))] .into_iter() .collect(); - Ok(uncompacted.into_series()) - } + uncompacted_cells.into_series() + }; + + // Determine the target inner dtype based on the original column + let inner_original_dtype = match &original_dtype { + DataType::List(inner) => *inner.clone(), + dt => dt.clone(), + }; + + // Map original inner dtype to the target dtype + let target_inner_dtype = resolve_target_inner_dtype(&inner_original_dtype)?; + // We have a List(UInt64) right now, cast it to List(target_inner_dtype) + cast_list_u64_to_dtype(&out_series, &DataType::UInt64, Some(&target_inner_dtype)) } diff --git a/src/engine/traversal.rs b/src/engine/traversal.rs index 35ae705..b232328 100644 --- a/src/engine/traversal.rs +++ b/src/engine/traversal.rs @@ -1,8 +1,9 @@ -use super::utils::parse_cell_indices; use h3o::CellIndex; use polars::prelude::*; use rayon::prelude::*; +use super::utils::{cast_list_u64_to_dtype, parse_cell_indices, resolve_target_inner_dtype}; + pub fn grid_distance(origin_series: &Series, destination_series: &Series) -> PolarsResult { let origins = parse_cell_indices(origin_series)?; let destinations = parse_cell_indices(destination_series)?; @@ -23,6 +24,7 @@ pub fn grid_distance(origin_series: &Series, destination_series: &Series) -> Pol } pub fn grid_ring(cell_series: &Series, k: u32) -> PolarsResult { + let original_dtype = cell_series.dtype().clone(); let cells = parse_cell_indices(cell_series)?; let rings: ListChunked = cells @@ -38,10 +40,13 @@ pub fn grid_ring(cell_series: &Series, k: u32) -> PolarsResult { }) .collect(); - Ok(rings.into_series()) + let rings_series = rings.into_series(); + let target_inner_dtype = resolve_target_inner_dtype(&original_dtype)?; + cast_list_u64_to_dtype(&rings_series, &DataType::UInt64, Some(&target_inner_dtype)) } pub fn grid_disk(cell_series: &Series, k: u32) -> PolarsResult { + let original_dtype = cell_series.dtype().clone(); let cells = parse_cell_indices(cell_series)?; let disks: ListChunked = cells .into_par_iter() @@ -55,12 +60,16 @@ pub fn grid_disk(cell_series: &Series, k: u32) -> PolarsResult { }) }) .collect(); - Ok(disks.into_series()) + let disks_series = disks.into_series(); + let target_inner_dtype = resolve_target_inner_dtype(&original_dtype)?; + cast_list_u64_to_dtype(&disks_series, &DataType::UInt64, Some(&target_inner_dtype)) } + pub fn grid_path_cells( origin_series: &Series, destination_series: &Series, ) -> PolarsResult { + let original_dtype = origin_series.dtype().clone(); let origins = parse_cell_indices(origin_series)?; let destinations = parse_cell_indices(destination_series)?; @@ -85,7 +94,9 @@ pub fn grid_path_cells( }) .collect(); - Ok(paths.into_series()) + let paths_series = paths.into_series(); + let target_inner_dtype = resolve_target_inner_dtype(&original_dtype)?; + cast_list_u64_to_dtype(&paths_series, &DataType::UInt64, Some(&target_inner_dtype)) } pub fn cell_to_local_ij(cell_series: &Series, origin_series: &Series) -> PolarsResult { diff --git a/src/engine/utils.rs b/src/engine/utils.rs index 4272762..e13c71f 100644 --- a/src/engine/utils.rs +++ b/src/engine/utils.rs @@ -1,4 +1,5 @@ use h3o::CellIndex; +use polars::error::PolarsResult; use polars::prelude::*; pub fn parse_cell_indices(cell_series: &Series) -> PolarsResult>> { @@ -28,3 +29,90 @@ pub fn parse_cell_indices(cell_series: &Series) -> PolarsResult, + result: UInt64Chunked, +) -> PolarsResult { + let final_dtype = target_dtype.unwrap_or(original_dtype); + + match final_dtype { + DataType::UInt64 => Ok(result.into_series()), + DataType::Int64 => result.cast(&DataType::Int64), + DataType::String => { + let utf8: StringChunked = result + .into_iter() + .map(|opt_u| opt_u.map(|u| format!("{:x}", u))) + .collect(); + Ok(utf8.into_series()) + }, + _ => polars_bail!(ComputeError: "Unsupported dtype for H3 result"), + } +} + +pub fn cast_list_u64_to_dtype( + list_series: &Series, + original_dtype: &DataType, + target_dtype: Option<&DataType>, +) -> PolarsResult { + let ca = list_series.list()?; + let final_dtype = target_dtype.unwrap_or(original_dtype); + + let out: ListChunked = ca + .into_iter() + .map(|opt_s| { + opt_s + .map(|s| { + // If the inner list isn't UInt64, cast it to UInt64. + let s_u64 = if s.dtype() != &DataType::UInt64 { + s.cast(&DataType::UInt64)? + } else { + s + }; + + let u64_ca = s_u64.u64()?; + match final_dtype { + DataType::UInt64 => { + // Create an owned version of the UInt64 chunked array before converting. + Ok(u64_ca.to_owned().into_series()) + }, + DataType::Int64 => u64_ca.cast(&DataType::Int64), + DataType::String => { + // Convert each u64 to a hex string. + let utf8: StringChunked = u64_ca + .into_iter() + .map(|opt_u| opt_u.map(|u| format!("{:x}", u))) + .collect(); + Ok(utf8.into_series()) + }, + _ => polars_bail!(ComputeError: "Unsupported dtype for H3 List result"), + } + }) + .transpose() + }) + .collect::>()?; + + Ok(out.into_series()) +} + +pub fn resolve_target_inner_dtype(original_dtype: &DataType) -> PolarsResult { + // If the original was a List, extract its inner type. Otherwise, use the original directly. + let inner_original_dtype = match original_dtype { + DataType::List(inner) => *inner.clone(), + dt => dt.clone(), + }; + + let target_inner_dtype = match inner_original_dtype { + DataType::UInt64 => DataType::UInt64, + DataType::Int64 => DataType::Int64, + DataType::String => DataType::String, + other => { + return Err(PolarsError::ComputeError( + format!("Unsupported inner dtype: {:?}", other).into(), + )) + }, + }; + + Ok(target_inner_dtype) +} diff --git a/src/expressions.rs b/src/expressions.rs index 4bfe27b..325154d 100644 --- a/src/expressions.rs +++ b/src/expressions.rs @@ -26,6 +26,46 @@ fn latlng_list_dtype(input_fields: &[Field]) -> PolarsResult { Ok(field) } +fn map_list_dtype(dt: &DataType) -> PolarsResult { + match dt { + DataType::List(inner) => { + let mapped_inner = map_list_dtype(inner)?; + Ok(DataType::List(Box::new(mapped_inner))) + }, + DataType::UInt64 => Ok(DataType::UInt64), + DataType::Int64 => Ok(DataType::Int64), + DataType::String => Ok(DataType::String), + other => polars_bail!( + ComputeError: "Unsupported input type for dynamic list dtype function: {:?}", + other + ), + } +} + +fn dynamic_list_output_dtype(input_fields: &[Field]) -> PolarsResult { + let input_dtype = &input_fields[0].dtype; + + // map_list_dtype will handle both nested lists and base types + let mapped_dtype = map_list_dtype(input_dtype)?; + + Ok(Field::new(input_fields[0].name.clone(), mapped_dtype)) +} + +fn dynamic_scalar_output_dtype(input_fields: &[Field]) -> PolarsResult { + let input_dtype = &input_fields[0].dtype; + let output_dtype = match input_dtype { + DataType::UInt64 => DataType::UInt64, + DataType::Int64 => DataType::Int64, + DataType::String => DataType::String, + dt => { + polars_bail!(ComputeError: "Unsupported input type: {:?}", dt); + }, + }; + Ok(Field::new(input_fields[0].name.clone(), output_dtype)) +} + +// ===== Indexing ===== // + #[polars_expr(output_type=UInt64)] fn latlng_to_cell(inputs: &[Series], kwargs: LatLngToCellKwargs) -> PolarsResult { let lat_series = &inputs[0]; @@ -63,6 +103,7 @@ fn cell_to_latlng(inputs: &[Series]) -> PolarsResult { } // ===== Inspection ===== // + #[polars_expr(output_type=UInt8)] fn get_resolution(inputs: &[Series]) -> PolarsResult { let cell_series = &inputs[0]; @@ -123,13 +164,13 @@ fn list_uint64_dtype(input_fields: &[Field]) -> PolarsResult { // ===== Hierarchy ===== // -#[polars_expr(output_type=UInt64)] +#[polars_expr(output_type_func=dynamic_scalar_output_dtype)] fn cell_to_parent(inputs: &[Series], kwargs: ResolutionKwargs) -> PolarsResult { let cell_series = &inputs[0]; crate::engine::hierarchy::cell_to_parent(cell_series, kwargs.resolution) } -#[polars_expr(output_type=UInt64)] +#[polars_expr(output_type_func=dynamic_scalar_output_dtype)] fn cell_to_center_child(inputs: &[Series], kwargs: ResolutionKwargs) -> PolarsResult { let cell_series = &inputs[0]; crate::engine::hierarchy::cell_to_center_child(cell_series, kwargs.resolution) @@ -141,7 +182,7 @@ fn cell_to_children_size(inputs: &[Series], kwargs: ResolutionKwargs) -> PolarsR crate::engine::hierarchy::cell_to_children_size(cell_series, kwargs.resolution) } -#[polars_expr(output_type_func=list_uint64_dtype)] +#[polars_expr(output_type_func=dynamic_list_output_dtype)] fn cell_to_children(inputs: &[Series], kwargs: ResolutionKwargs) -> PolarsResult { let cell_series = &inputs[0]; crate::engine::hierarchy::cell_to_children(cell_series, kwargs.resolution) @@ -153,7 +194,7 @@ fn cell_to_child_pos(inputs: &[Series], kwargs: ResolutionKwargs) -> PolarsResul crate::engine::hierarchy::cell_to_child_pos(cell_series, kwargs.resolution.unwrap_or(0)) } -#[polars_expr(output_type=UInt64)] +#[polars_expr(output_type_func=dynamic_scalar_output_dtype)] fn child_pos_to_cell(inputs: &[Series], kwargs: ResolutionKwargs) -> PolarsResult { let parent_series = &inputs[0]; let pos_series = &inputs[1]; @@ -164,7 +205,7 @@ fn child_pos_to_cell(inputs: &[Series], kwargs: ResolutionKwargs) -> PolarsResul ) } -#[polars_expr(output_type_func=list_uint64_dtype)] +#[polars_expr(output_type_func=dynamic_list_output_dtype)] fn compact_cells(inputs: &[Series]) -> PolarsResult { let cell_series = &inputs[0]; crate::engine::hierarchy::compact_cells(cell_series) @@ -188,19 +229,19 @@ fn grid_distance(inputs: &[Series]) -> PolarsResult { crate::engine::traversal::grid_distance(origin_series, destination_series) } -#[polars_expr(output_type_func=list_uint64_dtype)] +#[polars_expr(output_type_func=dynamic_list_output_dtype)] fn grid_ring(inputs: &[Series], kwargs: GridKwargs) -> PolarsResult { let cell_series = &inputs[0]; crate::engine::traversal::grid_ring(cell_series, kwargs.k) } -#[polars_expr(output_type_func=list_uint64_dtype)] +#[polars_expr(output_type_func=dynamic_list_output_dtype)] fn grid_disk(inputs: &[Series], kwargs: GridKwargs) -> PolarsResult { let cell_series = &inputs[0]; crate::engine::traversal::grid_disk(cell_series, kwargs.k) } -#[polars_expr(output_type_func=list_uint64_dtype)] +#[polars_expr(output_type_func=dynamic_list_output_dtype)] fn grid_path_cells(inputs: &[Series]) -> PolarsResult { let origin_series = &inputs[0]; let destination_series = &inputs[1]; diff --git a/tests/test_edge.py b/tests/test_edge.py index 4bc85b3..c7a4d7a 100644 --- a/tests/test_edge.py +++ b/tests/test_edge.py @@ -1,47 +1,89 @@ -import pytest import polars as pl +import pytest -from typing import List, Dict, Union import polars_h3 @pytest.mark.parametrize( - "edge, schema, expected_valid", + "test_params", [ - pytest.param(["2222597fffffffff"], None, False, id="invalid_str_edge"), - pytest.param([0], {"edge": pl.UInt64}, False, id="invalid_int_edge"), - pytest.param(["115283473fffffff"], None, True, id="valid_str_edge"), pytest.param( - [1248204388774707199], {"edge": pl.UInt64}, True, id="valid_int_edge" + { + "input": "2222597fffffffff", + "schema": None, + "output": False, + }, + id="invalid_str_edge", + ), + pytest.param( + { + "input": 0, + "schema": {"edge": pl.UInt64}, + "output": False, + }, + id="invalid_int_edge", + ), + pytest.param( + { + "input": "115283473fffffff", + "schema": None, + "output": True, + }, + id="valid_str_edge", + ), + pytest.param( + { + "input": 1248204388774707199, + "schema": {"edge": pl.UInt64}, + "output": True, + }, + id="valid_int_edge", ), ], ) -def test_is_valid_directed_edge( - edge: List[Union[int, str]], - schema: Union[Dict[str, pl.DataType], None], - expected_valid: bool, -): - df = pl.DataFrame({"edge": edge}, schema=schema).with_columns( - valid=polars_h3.is_valid_directed_edge("edge") - ) - assert df["valid"][0] == expected_valid +def test_is_valid_directed_edge(test_params): + df = pl.DataFrame( + {"edge": [test_params["input"]]}, + schema=test_params["schema"], + ).with_columns(valid=polars_h3.is_valid_directed_edge("edge")) + assert df["valid"][0] == test_params["output"] @pytest.mark.parametrize( - "origin_cell, schema", + "test_params", [ - pytest.param([599686042433355775], {"h3_cell": pl.UInt64}, id="uint64_input"), - pytest.param([599686042433355775], {"h3_cell": pl.Int64}, id="int64_input"), - pytest.param(["85283473fffffff"], None, id="string_input"), + pytest.param( + { + "input": 599686042433355775, + "schema": {"h3_cell": pl.UInt64}, + "output_length": 6, + }, + id="uint64_input", + ), + pytest.param( + { + "input": 599686042433355775, + "schema": {"h3_cell": pl.Int64}, + "output_length": 6, + }, + id="int64_input", + ), + pytest.param( + { + "input": "85283473fffffff", + "schema": None, + "output_length": 6, + }, + id="string_input", + ), ], ) -def test_origin_to_directed_edges( - origin_cell: List[Union[int, str]], schema: Union[Dict[str, pl.DataType], None] -): - df = pl.DataFrame({"h3_cell": origin_cell}, schema=schema).with_columns( - edges=polars_h3.origin_to_directed_edges("h3_cell") - ) - assert len(df["edges"][0]) == 6 # Each cell should have 6 edges +def test_origin_to_directed_edges(test_params): + df = pl.DataFrame( + {"h3_cell": [test_params["input"]]}, + schema=test_params["schema"], + ).with_columns(edges=polars_h3.origin_to_directed_edges("h3_cell")) + assert len(df["edges"][0]) == test_params["output_length"] def test_directed_edge_operations(): @@ -76,52 +118,86 @@ def test_directed_edge_operations(): @pytest.mark.parametrize( - "cell1, cell2, schema, expected_neighbors", + "test_params", [ pytest.param( - [599686042433355775], - [599686030622195711], - {"cell1": pl.UInt64, "cell2": pl.UInt64}, - True, + { + "input_1": 599686042433355775, + "input_2": 599686030622195711, + "schema": {"cell1": pl.UInt64, "cell2": pl.UInt64}, + "output": True, + }, id="neighbor_uint64", ), pytest.param( - [599686042433355775], - [599686029548453887], - {"cell1": pl.UInt64, "cell2": pl.UInt64}, - False, + { + "input_1": 599686042433355775, + "input_2": 599686029548453887, + "schema": {"cell1": pl.UInt64, "cell2": pl.UInt64}, + "output": False, + }, id="not_neighbor_uint64", ), pytest.param( - ["85283473fffffff"], ["85283447fffffff"], None, True, id="neighbor_str" + { + "input_1": "85283473fffffff", + "input_2": "85283447fffffff", + "schema": None, + "output": True, + }, + id="neighbor_str", ), pytest.param( - ["85283473fffffff"], ["85283443fffffff"], None, False, id="not_neighbor_str" + { + "input_1": "85283473fffffff", + "input_2": "85283443fffffff", + "schema": None, + "output": False, + }, + id="not_neighbor_str", ), ], ) -def test_are_neighbor_cells( - cell1: List[Union[int, str]], - cell2: List[Union[int, str]], - schema: Union[Dict[str, pl.DataType], None], - expected_neighbors: bool, -): - df = pl.DataFrame({"cell1": cell1, "cell2": cell2}, schema=schema).with_columns( - neighbors=polars_h3.are_neighbor_cells("cell1", "cell2") - ) - assert df["neighbors"][0] == expected_neighbors - +def test_are_neighbor_cells(test_params): + df = pl.DataFrame( + { + "cell1": [test_params["input_1"]], + "cell2": [test_params["input_2"]], + }, + schema=test_params["schema"], + ).with_columns(neighbors=polars_h3.are_neighbor_cells("cell1", "cell2")) + assert df["neighbors"][0] == test_params["output"] -def test_cells_to_directed_edge(): - # Test with integers - df_int = pl.DataFrame( - {"origin": [599686042433355775], "destination": [599686030622195711]}, - schema={"origin": pl.UInt64, "destination": pl.UInt64}, - ).with_columns(edge=polars_h3.cells_to_directed_edge("origin", "destination")) - assert df_int["edge"][0] == 1608492358964346879 - # Test with strings - df_str = pl.DataFrame( - {"origin": ["85283473fffffff"], "destination": ["85283447fffffff"]} +@pytest.mark.parametrize( + "test_params", + [ + pytest.param( + { + "input_1": 599686042433355775, + "input_2": 599686030622195711, + "schema": {"origin": pl.UInt64, "destination": pl.UInt64}, + "output": 1608492358964346879, + }, + id="int_edge", + ), + pytest.param( + { + "input_1": "85283473fffffff", + "input_2": "85283447fffffff", + "schema": None, + "output": 1608492358964346879, + }, + id="string_edge", + ), + ], +) +def test_cells_to_directed_edge(test_params): + df = pl.DataFrame( + { + "origin": [test_params["input_1"]], + "destination": [test_params["input_2"]], + }, + schema=test_params["schema"], ).with_columns(edge=polars_h3.cells_to_directed_edge("origin", "destination")) - assert df_str["edge"][0] == 1608492358964346879 + assert df["edge"][0] == test_params["output"] diff --git a/tests/test_hierarchy.py b/tests/test_hierarchy.py index 6ebcf45..6cba3bb 100644 --- a/tests/test_hierarchy.py +++ b/tests/test_hierarchy.py @@ -2,105 +2,142 @@ FIXME: uncompact stuff """ -from typing import Dict, Union, List -import pytest import polars as pl +import pytest + import polars_h3 @pytest.mark.parametrize( - "h3_cell, schema", + "test_params", [ pytest.param( - [586265647244115967], - {"h3_cell": pl.UInt64}, + { + "input": 586265647244115967, + "output": 581764796395814911, + "schema": {"input": pl.UInt64}, + }, id="uint64_input", ), pytest.param( - [586265647244115967], - {"h3_cell": pl.Int64}, + { + "input": 586265647244115967, + "output": 581764796395814911, + "schema": {"input": pl.Int64}, + }, id="int64_input", ), pytest.param( - ["822d57fffffffff"], - None, + { + "input": "822d57fffffffff", + "output": "812d7ffffffffff", + "schema": None, + }, id="string_input", ), ], ) -def test_cell_to_parent_valid( - h3_cell: List[Union[int, str]], schema: Union[Dict[str, pl.DataType], None] -): - df = pl.DataFrame({"h3_cell": h3_cell}, schema=schema).with_columns( - parent=polars_h3.cell_to_parent("h3_cell", 1) - ) - assert df["parent"].to_list()[0] == 581764796395814911 +def test_cell_to_parent_valid(test_params): + df = pl.DataFrame( + {"input": [test_params["input"]]}, schema=test_params["schema"] + ).with_columns(parent=polars_h3.cell_to_parent("input", 1)) + assert df["parent"].to_list()[0] == test_params["output"] @pytest.mark.parametrize( - "h3_cell, schema", + "test_params", [ pytest.param( - [586265647244115967], - {"h3_cell": pl.UInt64}, + { + "input": 586265647244115967, + "output": 595272305332977663, + "schema": {"input": pl.UInt64}, + }, id="uint64_input", ), pytest.param( - [586265647244115967], - {"h3_cell": pl.Int64}, + { + "input": 586265647244115967, + "output": 595272305332977663, + "schema": {"input": pl.Int64}, + }, id="int64_input", ), pytest.param( - ["822d57fffffffff"], - None, + { + "input": "822d57fffffffff", + "output": "842d501ffffffff", + "schema": None, + }, id="string_input", ), ], ) -def test_cell_to_center_child_valid( - h3_cell: List[Union[int, str]], schema: Union[Dict[str, pl.DataType], None] -): - df = pl.DataFrame({"h3_cell": h3_cell}, schema=schema).with_columns( - child=polars_h3.cell_to_center_child("h3_cell", 4) - ) - assert df["child"].to_list()[0] == 595272305332977663 +def test_cell_to_center_child_valid(test_params): + df = pl.DataFrame( + {"input": [test_params["input"]]}, schema=test_params["schema"] + ).with_columns(child=polars_h3.cell_to_center_child("input", 4)) + assert df["child"].to_list()[0] == test_params["output"] @pytest.mark.parametrize( - "h3_cell, schema", + "test_params", [ pytest.param( - [586265647244115967], - {"h3_cell": pl.UInt64}, + { + "input": 586265647244115967, + "output": [ + 590768765835149311, + 590768834554626047, + 590768903274102783, + 590768971993579519, + 590769040713056255, + 590769109432532991, + 590769178152009727, + ], + "schema": {"input": pl.UInt64}, + }, id="uint64_input", ), pytest.param( - [586265647244115967], - {"h3_cell": pl.Int64}, + { + "input": 586265647244115967, + "output": [ + 590768765835149311, + 590768834554626047, + 590768903274102783, + 590768971993579519, + 590769040713056255, + 590769109432532991, + 590769178152009727, + ], + "schema": {"input": pl.Int64}, + }, id="int64_input", ), pytest.param( - ["822d57fffffffff"], - None, + { + "input": "822d57fffffffff", + "output": [ + "832d50fffffffff", + "832d51fffffffff", + "832d52fffffffff", + "832d53fffffffff", + "832d54fffffffff", + "832d55fffffffff", + "832d56fffffffff", + ], + "schema": None, + }, id="string_input", ), ], ) -def test_cell_to_children_valid( - h3_cell: List[Union[int, str]], schema: Union[Dict[str, pl.DataType], None] -): - df = pl.DataFrame({"h3_cell": h3_cell}, schema=schema).with_columns( - children=polars_h3.cell_to_children("h3_cell", 3) - ) - assert df["children"].to_list()[0] == [ - 590768765835149311, - 590768834554626047, - 590768903274102783, - 590768971993579519, - 590769040713056255, - 590769109432532991, - 590769178152009727, - ] +def test_cell_to_children_valid(test_params): + df = pl.DataFrame( + {"input": [test_params["input"]]}, schema=test_params["schema"] + ).with_columns(children=polars_h3.cell_to_children("input", 3)) + assert df["children"].to_list()[0] == test_params["output"] @pytest.mark.parametrize( diff --git a/tests/test_indexing.py b/tests/test_indexing.py index 8cc7af5..1be538d 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -1,78 +1,109 @@ -import pytest import polars as pl -import polars_h3 -from typing import Optional, Union, Dict, List - +import pytest -def test_latlng_to_cell_valid(): - df = pl.DataFrame({"lat": [0.0], "lng": [0.0]}).with_columns( - h3_cell=polars_h3.latlng_to_cell("lat", "lng", 1) - ) - assert df["h3_cell"][0] == 583031433791012863 +import polars_h3 -def test_latlng_to_cell_string_valid(): - df = pl.DataFrame( - {"lat": [37.7752702151959], "lng": [-122.418307270836]} - ).with_columns( - h3_cell=polars_h3.latlng_to_cell_string("lat", "lng", 9), +@pytest.mark.parametrize( + "input_lat,input_lng,resolution,return_dtype,expected", + [ + (0.0, 0.0, 1, pl.UInt64, 583031433791012863), + (37.7752702151959, -122.418307270836, 9, pl.Utf8, "8928308280fffff"), + ], + ids=["cell_int", "cell_string"], +) +def test_latlng_to_cell_valid(input_lat, input_lng, resolution, return_dtype, expected): + df = pl.DataFrame({"lat": [input_lat], "lng": [input_lng]}).with_columns( + h3_cell=polars_h3.latlng_to_cell( + "lat", "lng", resolution, return_dtype=return_dtype + ) ) - assert df["h3_cell"][0] == "8928308280fffff" + assert df["h3_cell"][0] == expected @pytest.mark.parametrize( - "resolution", + "input_lat,input_lng,resolution", [ - pytest.param(-1, id="negative_resolution"), - pytest.param(30, id="too_high_resolution"), + (0.0, 0.0, -1), + (0.0, 0.0, 30), ], + ids=["negative_resolution", "too_high_resolution"], ) -def test_latlng_to_cell_invalid_resolution(resolution: int): - df = pl.DataFrame({"lat": [0.0], "lng": [0.0]}) - +def test_latlng_to_cell_invalid_resolution(input_lat, input_lng, resolution): + df = pl.DataFrame({"lat": [input_lat], "lng": [input_lng]}) with pytest.raises(ValueError): - df.with_columns(h3_cell=polars_h3.latlng_to_cell("lat", "lng", resolution)) - + df.with_columns( + h3_cell=polars_h3.latlng_to_cell( + "lat", "lng", resolution, return_dtype=pl.UInt64 + ) + ) with pytest.raises(ValueError): df.with_columns( - h3_cell=polars_h3.latlng_to_cell_string("lat", "lng", resolution) + h3_cell=polars_h3.latlng_to_cell( + "lat", "lng", resolution, return_dtype=pl.Utf8 + ) ) @pytest.mark.parametrize( - "lat, lng", + "input_lat,input_lng", [ - pytest.param(37.7752702151959, None, id="null_longitude"), - pytest.param(None, -122.418307270836, id="null_latitude"), - pytest.param(None, None, id="both_null"), + (37.7752702151959, None), + (None, -122.418307270836), + (None, None), ], + ids=["null_longitude", "null_latitude", "both_null"], ) -def test_latlng_to_cell_null_inputs(lat: Optional[float], lng: Optional[float]): - df = pl.DataFrame({"lat": [lat], "lng": [lng]}) - +def test_latlng_to_cell_null_inputs(input_lat, input_lng): + df = pl.DataFrame({"lat": [input_lat], "lng": [input_lng]}) with pytest.raises(pl.exceptions.ComputeError): - df.with_columns(h3_cell=polars_h3.latlng_to_cell("lat", "lng", 9)) - + df.with_columns( + h3_cell=polars_h3.latlng_to_cell("lat", "lng", 9, return_dtype=pl.UInt64) + ) with pytest.raises(pl.exceptions.ComputeError): - df.with_columns(h3_cell=polars_h3.latlng_to_cell_string("lat", "lng", 9)) + df.with_columns( + h3_cell=polars_h3.latlng_to_cell("lat", "lng", 9, return_dtype=pl.Utf8) + ) @pytest.mark.parametrize( - "h3_cell, schema", + "test_params", [ pytest.param( - [599686042433355775], {"int_h3_cell": pl.UInt64}, id="uint64_input" + { + "input": 599686042433355775, + "output_lat": 37.345793375368, + "output_lng": -121.976375972551, + "schema": {"input": pl.UInt64}, + }, + id="uint64_input", + ), + pytest.param( + { + "input": 599686042433355775, + "output_lat": 37.345793375368, + "output_lng": -121.976375972551, + "schema": {"input": pl.Int64}, + }, + id="int64_input", + ), + pytest.param( + { + "input": "85283473fffffff", + "output_lat": 37.345793375368, + "output_lng": -121.976375972551, + "schema": None, + }, + id="string_input", ), - pytest.param([599686042433355775], {"int_h3_cell": pl.Int64}, id="int64_input"), - pytest.param(["85283473fffffff"], None, id="string_input"), ], ) -def test_cell_to_latlng( - h3_cell: List[Union[int, str]], schema: Union[Dict[str, pl.DataType], None] -): - df = pl.DataFrame({"int_h3_cell": h3_cell}, schema=schema).with_columns( - lat=polars_h3.cell_to_lat("int_h3_cell"), - lng=polars_h3.cell_to_lng("int_h3_cell"), +def test_cell_to_latlng(test_params): + df = pl.DataFrame( + {"input": [test_params["input"]]}, schema=test_params["schema"] + ).with_columns( + lat=polars_h3.cell_to_lat("input"), + lng=polars_h3.cell_to_lng("input"), ) - assert pytest.approx(df["lat"][0], 0.00001) == 37.345793375368 - assert pytest.approx(df["lng"][0], 0.00001) == -121.976375972551 + assert pytest.approx(df["lat"][0], 0.00001) == test_params["output_lat"] + assert pytest.approx(df["lng"][0], 0.00001) == test_params["output_lng"] diff --git a/tests/test_inspection.py b/tests/test_inspection.py index 3a48757..3afbc38 100644 --- a/tests/test_inspection.py +++ b/tests/test_inspection.py @@ -1,148 +1,267 @@ -import pytest import polars as pl +import pytest + import polars_h3 -from typing import List, Union, Dict @pytest.mark.parametrize( - "h3_input, schema, expected_resolution", + "test_params", [ pytest.param( - [586265647244115967], {"h3_cell": pl.UInt64}, 2, id="uint64_input" + { + "input": 586265647244115967, + "schema": {"h3_cell": pl.UInt64}, + "output": 2, + }, + id="uint64_input", + ), + pytest.param( + { + "input": 586265647244115967, + "schema": {"h3_cell": pl.Int64}, + "output": 2, + }, + id="int64_input", + ), + pytest.param( + { + "input": "822d57fffffffff", + "schema": None, + "output": 2, + }, + id="string_input", ), - pytest.param([586265647244115967], {"h3_cell": pl.Int64}, 2, id="int64_input"), - pytest.param(["822d57fffffffff"], None, 2, id="string_input"), ], ) -def test_get_resolution( - h3_input: List[Union[int, str]], - schema: Union[Dict[str, pl.DataType], None], - expected_resolution: int, -): - df = pl.DataFrame({"h3_cell": h3_input}, schema=schema).with_columns( - resolution=polars_h3.get_resolution("h3_cell") - ) - assert df["resolution"][0] == expected_resolution +def test_get_resolution(test_params): + df = pl.DataFrame( + {"h3_cell": [test_params["input"]]}, + schema=test_params["schema"], + ).with_columns(resolution=polars_h3.get_resolution("h3_cell")) + assert df["resolution"][0] == test_params["output"] @pytest.mark.parametrize( - "h3_input, schema, expected_valid", + "test_params", [ pytest.param( - [586265647244115967], {"h3_cell": pl.UInt64}, True, id="valid_uint64" + { + "input": 586265647244115967, + "schema": {"h3_cell": pl.UInt64}, + "output": True, + }, + id="valid_uint64", ), pytest.param( - [586265647244115967], {"h3_cell": pl.Int64}, True, id="valid_int64" + { + "input": 586265647244115967, + "schema": {"h3_cell": pl.Int64}, + "output": True, + }, + id="valid_int64", + ), + pytest.param( + { + "input": "85283473fffffff", + "schema": None, + "output": True, + }, + id="valid_string", + ), + pytest.param( + { + "input": 1234, + "schema": {"h3_cell": pl.UInt64}, + "output": False, + }, + id="invalid_uint64", + ), + pytest.param( + { + "input": 1234, + "schema": {"h3_cell": pl.Int64}, + "output": False, + }, + id="invalid_int64", + ), + pytest.param( + { + "input": "1234", + "schema": None, + "output": False, + }, + id="invalid_string", ), - pytest.param(["85283473fffffff"], None, True, id="valid_string"), - pytest.param([1234], {"h3_cell": pl.UInt64}, False, id="invalid_uint64"), - pytest.param([1234], {"h3_cell": pl.Int64}, False, id="invalid_int64"), - pytest.param(["1234"], None, False, id="invalid_string"), ], ) -def test_is_valid_cell( - h3_input: List[Union[int, str]], - schema: Union[Dict[str, pl.DataType], None], - expected_valid: bool, -): - df = pl.DataFrame({"h3_cell": h3_input}, schema=schema).with_columns( - valid=polars_h3.is_valid_cell("h3_cell") - ) - assert df["valid"][0] == expected_valid +def test_is_valid_cell(test_params): + df = pl.DataFrame( + {"h3_cell": [test_params["input"]]}, + schema=test_params["schema"], + ).with_columns(valid=polars_h3.is_valid_cell("h3_cell")) + assert df["valid"][0] == test_params["output"] @pytest.mark.parametrize( - "h3_int_input, expected_str", + "test_params", [ - pytest.param([605035864166236159], "86584e9afffffff", id="number_1"), - pytest.param([581698825698148351], "8129bffffffffff", id="number_2"), - pytest.param([626682153101213695], "8b26c1912acbfff", id="number_3"), - pytest.param([1], None, id="invalid_cell"), + pytest.param( + { + "input": 605035864166236159, + "output": "86584e9afffffff", + }, + id="number_1", + ), + pytest.param( + { + "input": 581698825698148351, + "output": "8129bffffffffff", + }, + id="number_2", + ), + pytest.param( + { + "input": 626682153101213695, + "output": "8b26c1912acbfff", + }, + id="number_3", + ), + pytest.param( + { + "input": 1, + "output": None, + }, + id="invalid_cell", + ), ], ) -def test_int_to_str_conversion(h3_int_input: List[int], expected_str: str): +def test_int_to_str_conversion(test_params): # Test UInt64 df_uint = pl.DataFrame( - {"h3_cell": h3_int_input}, schema={"h3_cell": pl.UInt64} - ).with_columns(polars_h3.int_to_str("h3_cell").alias("h3_str")) - assert df_uint["h3_str"].to_list()[0] == expected_str + {"h3_cell": [test_params["input"]]}, + schema={"h3_cell": pl.UInt64}, + ).with_columns(h3_str=polars_h3.int_to_str("h3_cell")) + assert df_uint["h3_str"][0] == test_params["output"] # Test Int64 df_int = pl.DataFrame( - {"h3_cell": h3_int_input}, schema={"h3_cell": pl.Int64} + {"h3_cell": [test_params["input"]]}, + schema={"h3_cell": pl.Int64}, ).with_columns(h3_str=polars_h3.int_to_str("h3_cell")) - assert df_int["h3_str"][0] == expected_str + assert df_int["h3_str"][0] == test_params["output"] @pytest.mark.parametrize( - "h3_str_input, expected_int", + "test_params", [ - pytest.param(["86584e9afffffff"], 605035864166236159, id="number_1"), - pytest.param(["8129bffffffffff"], 581698825698148351, id="number_2"), - pytest.param(["8b26c1912acbfff"], 626682153101213695, id="number_3"), - pytest.param(["sergey"], None, id="invalid_cell"), + pytest.param( + { + "input": "86584e9afffffff", + "output": 605035864166236159, + }, + id="number_1", + ), + pytest.param( + { + "input": "8129bffffffffff", + "output": 581698825698148351, + }, + id="number_2", + ), + pytest.param( + { + "input": "8b26c1912acbfff", + "output": 626682153101213695, + }, + id="number_3", + ), + pytest.param( + { + "input": "sergey", + "output": None, + }, + id="invalid_cell", + ), ], ) -def test_str_to_int_conversion(h3_str_input: List[str], expected_int: int): - # Test UInt64 - df_uint = pl.DataFrame({"h3_cell": h3_str_input}).with_columns( - polars_h3.str_to_int("h3_cell").alias("h3_int") +def test_str_to_int_conversion(test_params): + # Test with no schema specified + df_uint = pl.DataFrame({"h3_cell": [test_params["input"]]}).with_columns( + h3_int=polars_h3.str_to_int("h3_cell") ) - assert df_uint["h3_int"].to_list()[0] == expected_int + assert df_uint["h3_int"][0] == test_params["output"] - # Test Int64 - df_int = pl.DataFrame({"h3_cell": h3_str_input}).with_columns( + # Test with Int64 schema + df_int = pl.DataFrame({"h3_cell": [test_params["input"]]}).with_columns( h3_int=polars_h3.str_to_int("h3_cell") ) - assert df_int["h3_int"][0] == expected_int + assert df_int["h3_int"][0] == test_params["output"] -def test_is_pentagon(): +@pytest.mark.parametrize( + "test_params", + [ + pytest.param( + { + "inputs": ["821c07fffffffff", "85283473fffffff"], + "outputs": [True, False], + "schema": None, + }, + id="string_input", + ), + pytest.param( + { + "inputs": [585961082523222015, 599686042433355775], + "outputs": [True, False], + "schema": {"h3_cell": pl.UInt64}, + }, + id="int_input", + ), + ], +) +def test_is_pentagon(test_params): df = pl.DataFrame( - { - "h3_cell": [ - "821c07fffffffff", # pentagon - "85283473fffffff", # not pentagon (regular hexagon) - ] - } + {"h3_cell": test_params["inputs"]}, + schema=test_params["schema"], ).with_columns(is_pent=polars_h3.is_pentagon("h3_cell")) - assert df["is_pent"].to_list() == [True, False] - - df_int = pl.DataFrame( - { - "h3_cell": [585961082523222015, 599686042433355775], - } - ).with_columns(is_pent=polars_h3.is_pentagon("h3_cell")) - assert df_int["is_pent"].to_list() == [True, False] + assert df["is_pent"].to_list() == test_params["outputs"] -def test_is_res_class_III(): - # Resolution 1 (class III) and 2 (not class III) cells +@pytest.mark.parametrize( + "test_params", + [ + pytest.param( + { + "inputs": [ + "81623ffffffffff", # res 1 - should be class III + "822d57fffffffff", # res 2 - should not be class III + "847c35fffffffff", + ], + "outputs": [True, False, False], + "schema": None, + }, + id="string_input", + ), + pytest.param( + { + "inputs": [ + 582692784209657855, # res 1 cell - should be class III + 586265647244115967, # res 2 cell - should not be class III + 596660292734156799, + ], + "outputs": [True, False, False], + "schema": {"h3_cell": pl.UInt64}, + }, + id="int_input", + ), + ], +) +def test_is_res_class_III(test_params): df = pl.DataFrame( - { - "h3_cell": [ - "81623ffffffffff", # res 1 - should be class III - "822d57fffffffff", # res 2 - should not be class III - "847c35fffffffff", - ] - } + {"h3_cell": test_params["inputs"]}, + schema=test_params["schema"], ).with_columns(is_class_3=polars_h3.is_res_class_III("h3_cell")) - - assert df["is_class_3"].to_list() == [True, False, False] - - # Test with integer representation too - df_int = pl.DataFrame( - { - "h3_cell": [ - 582692784209657855, # res 1 cell - should be class III - 586265647244115967, # res 2 cell - should not be class III - 596660292734156799, - ] - }, - schema={"h3_cell": pl.UInt64}, - ).with_columns(is_class_3=polars_h3.is_res_class_III("h3_cell")) - - assert df_int["is_class_3"].to_list() == [True, False, False] + assert df["is_class_3"].to_list() == test_params["outputs"] def test_str_to_int_invalid(): @@ -153,57 +272,95 @@ def test_str_to_int_invalid(): @pytest.mark.parametrize( - "h3_input, schema, expected_faces", + "test_params", [ pytest.param( - [599686042433355775], {"h3_cell": pl.UInt64}, [7], id="single_face_uint64" + { + "input": 599686042433355775, + "schema": {"h3_cell": pl.UInt64}, + "output": [7], + }, + id="single_face_uint64", + ), + pytest.param( + { + "input": 599686042433355775, + "schema": {"h3_cell": pl.Int64}, + "output": [7], + }, + id="single_face_int64", ), pytest.param( - [599686042433355775], {"h3_cell": pl.Int64}, [7], id="single_face_int64" + { + "input": "85283473fffffff", + "schema": None, + "output": [7], + }, + id="single_face_string", ), - pytest.param(["85283473fffffff"], None, [7], id="single_face_string"), pytest.param( - [576988517884755967], - {"h3_cell": pl.UInt64}, - [1, 6, 11, 7, 2], + { + "input": 576988517884755967, + "schema": {"h3_cell": pl.UInt64}, + "output": [1, 6, 11, 7, 2], + }, id="multiple_faces_uint64", ), pytest.param( - [576988517884755967], - {"h3_cell": pl.Int64}, - [1, 6, 11, 7, 2], + { + "input": 576988517884755967, + "schema": {"h3_cell": pl.Int64}, + "output": [1, 6, 11, 7, 2], + }, id="multiple_faces_int64", ), pytest.param( - ["801dfffffffffff"], None, [1, 6, 11, 7, 2], id="multiple_faces_string" + { + "input": "801dfffffffffff", + "schema": None, + "output": [1, 6, 11, 7, 2], + }, + id="multiple_faces_string", ), ], ) -def test_get_icosahedron_faces( - h3_input: List[Union[int, str]], - schema: Union[Dict[str, pl.DataType], None], - expected_faces: List[int], -): - df = pl.DataFrame({"h3_cell": h3_input}, schema=schema).with_columns( - faces=polars_h3.get_icosahedron_faces("h3_cell").list.sort() - ) - assert df["faces"][0].to_list() == sorted(expected_faces) +def test_get_icosahedron_faces(test_params): + df = pl.DataFrame( + {"h3_cell": [test_params["input"]]}, + schema=test_params["schema"], + ).with_columns(faces=polars_h3.get_icosahedron_faces("h3_cell").list.sort()) + assert df["faces"][0].to_list() == sorted(test_params["output"]) @pytest.mark.parametrize( - "h3_input, schema", + "test_params", [ pytest.param( - [18446744073709551615], {"h3_cell": pl.UInt64}, id="invalid_uint64" + { + "input": 18446744073709551615, + "schema": {"h3_cell": pl.UInt64}, + }, + id="invalid_uint64", + ), + pytest.param( + { + "input": 9223372036854775807, + "schema": {"h3_cell": pl.Int64}, + }, + id="invalid_int64", + ), + pytest.param( + { + "input": "7fffffffffffffff", + "schema": None, + }, + id="invalid_string", ), - pytest.param([9223372036854775807], {"h3_cell": pl.Int64}, id="invalid_int64"), - pytest.param(["7fffffffffffffff"], None, id="invalid_string"), ], ) -def test_get_icosahedron_faces_invalid( - h3_input: List[Union[int, str]], schema: Union[Dict[str, pl.DataType], None] -): - df = pl.DataFrame({"h3_cell": h3_input}, schema=schema).with_columns( - faces=polars_h3.get_icosahedron_faces("h3_cell") - ) +def test_get_icosahedron_faces_invalid(test_params): + df = pl.DataFrame( + {"h3_cell": [test_params["input"]]}, + schema=test_params["schema"], + ).with_columns(faces=polars_h3.get_icosahedron_faces("h3_cell")) assert df["faces"][0] is None diff --git a/tests/test_metrics.py b/tests/test_metrics.py index f6f1abe..81f5fe1 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,141 +1,261 @@ -import pytest import polars as pl +import pytest + import polars_h3 -from typing import Union, Dict @pytest.mark.parametrize( - "lat1, lng1, lat2, lng2, unit, expected_distance", + "test_params", [ - pytest.param(40.7128, -74.0060, 40.7128, -74.0060, "km", 0, id="same_point_km"), pytest.param( - 40.7128, -74.0060, 42.3601, -71.0589, "km", 306.108, id="diff_points_km" + { + "input_lat1": 40.7128, + "input_lng1": -74.0060, + "input_lat2": 40.7128, + "input_lng2": -74.0060, + "unit": "km", + "output": 0, + }, + id="same_point_km", ), pytest.param( - 40.7128, -74.0060, 42.3601, -71.0589, "m", 306108, id="diff_points_m" + { + "input_lat1": 40.7128, + "input_lng1": -74.0060, + "input_lat2": 42.3601, + "input_lng2": -71.0589, + "unit": "km", + "output": 306.108, + }, + id="diff_points_km", ), pytest.param( - 40.7128, - -74.0060, - 34.0522, - -118.2437, - "km", - 3936.155, + { + "input_lat1": 40.7128, + "input_lng1": -74.0060, + "input_lat2": 42.3601, + "input_lng2": -71.0589, + "unit": "m", + "output": 306108, + }, + id="diff_points_m", + ), + pytest.param( + { + "input_lat1": 40.7128, + "input_lng1": -74.0060, + "input_lat2": 34.0522, + "input_lng2": -118.2437, + "unit": "km", + "output": 3936.155, + }, id="large_distance_km", ), pytest.param( - 40.7128, -74.0060, 34.0522, -118.2437, "m", 3936155, id="large_distance_m" + { + "input_lat1": 40.7128, + "input_lng1": -74.0060, + "input_lat2": 34.0522, + "input_lng2": -118.2437, + "unit": "m", + "output": 3936155, + }, + id="large_distance_m", ), ], ) -def test_great_circle_distance( - lat1: float, - lng1: float, - lat2: float, - lng2: float, - unit: Union[str, None], - expected_distance: Union[float, None], -): +def test_great_circle_distance(test_params): df = pl.DataFrame( { - "lat1": [lat1], - "lng1": [lng1], - "lat2": [lat2], - "lng2": [lng2], + "lat1": [test_params["input_lat1"]], + "lng1": [test_params["input_lng1"]], + "lat2": [test_params["input_lat2"]], + "lng2": [test_params["input_lng2"]], } ).with_columns( - distance=polars_h3.great_circle_distance("lat1", "lng1", "lat2", "lng2", unit) + distance=polars_h3.great_circle_distance( + "lat1", "lng1", "lat2", "lng2", test_params["unit"] + ) ) - - if expected_distance is None: + if test_params["output"] is None: assert df["distance"][0] is None else: - assert pytest.approx(df["distance"][0], rel=1e-3) == expected_distance + assert pytest.approx(df["distance"][0], rel=1e-3) == test_params["output"] @pytest.mark.parametrize( - "resolution, unit, expected_area", + "test_params", [ - pytest.param(0, "km^2", 4357449.416078383, id="res0_km2"), - pytest.param(1, "km^2", 609788.4417941332, id="res1_km2"), - pytest.param(9, "m^2", 105332.51342720671, id="res0_m2"), - pytest.param(10, "m^2", 15047.50190766435, id="res1_m2"), - # pytest.param(-1, "km^2", None, id="invalid_res"), # should be able to handle, currently has silent strange behavior + pytest.param( + { + "input": 0, + "unit": "km^2", + "output": 4357449.416078383, + }, + id="res0_km2", + ), + pytest.param( + { + "input": 1, + "unit": "km^2", + "output": 609788.4417941332, + }, + id="res1_km2", + ), + pytest.param( + { + "input": 9, + "unit": "m^2", + "output": 105332.51342720671, + }, + id="res0_m2", + ), + pytest.param( + { + "input": 10, + "unit": "m^2", + "output": 15047.50190766435, + }, + id="res1_m2", + ), + # pytest.param( + # { + # "input": -1, + # "unit": "km^2", + # "output": None, + # }, + # id="invalid_res", + # ), ], ) -def test_average_hexagon_area( - resolution: int, unit: str, expected_area: Union[float, None] -): - df = pl.DataFrame({"resolution": [resolution]}).with_columns( - polars_h3.average_hexagon_area(pl.col("resolution"), unit).alias("area") +def test_average_hexagon_area(test_params): + df = pl.DataFrame({"resolution": [test_params["input"]]}).with_columns( + polars_h3.average_hexagon_area(pl.col("resolution"), test_params["unit"]).alias( + "area" + ) ) - if expected_area is None: + if test_params["output"] is None: assert df["area"][0] is None else: - assert pytest.approx(df["area"][0], rel=1e-2) == expected_area + assert pytest.approx(df["area"][0], rel=1e-2) == test_params["output"] @pytest.mark.parametrize( - "h3_cell, schema, unit, expected_area", + "test_params", [ pytest.param( - "8928308280fffff", None, "km^2", 0.1093981886464832, id="string_km2" + { + "input": "8928308280fffff", + "schema": None, + "unit": "km^2", + "output": 0.1093981886464832, + }, + id="string_km2", ), pytest.param( - "8928308280fffff", None, "m^2", 109398.18864648319, id="string_m2" + { + "input": "8928308280fffff", + "schema": None, + "unit": "m^2", + "output": 109398.18864648319, + }, + id="string_m2", ), pytest.param( - 586265647244115967, - {"h3_cell": pl.UInt64}, - "km^2", - 85321.69572540345, + { + "input": 586265647244115967, + "schema": {"h3_cell": pl.UInt64}, + "unit": "km^2", + "output": 85321.69572540345, + }, id="uint64_km2", ), pytest.param( - 586265647244115967, - {"h3_cell": pl.Int64}, - "km^2", - 85321.69572540345, + { + "input": 586265647244115967, + "schema": {"h3_cell": pl.Int64}, + "unit": "km^2", + "output": 85321.69572540345, + }, id="int64_km2", ), - pytest.param("fffffffffffffff", None, "km^2", None, id="invalid_cell"), + pytest.param( + { + "input": "fffffffffffffff", + "schema": None, + "unit": "km^2", + "output": None, + }, + id="invalid_cell", + ), ], ) -def test_hexagon_area( - h3_cell: Union[str, int], - schema: Union[Dict[str, pl.DataType], None], - unit: str, - expected_area: Union[float, None], -): - df = pl.DataFrame({"h3_cell": [h3_cell]}, schema=schema).with_columns( - area=polars_h3.cell_area(pl.col("h3_cell"), unit) - ) - if expected_area is None: +def test_hexagon_area(test_params): + df = pl.DataFrame( + {"h3_cell": [test_params["input"]]}, + schema=test_params["schema"], + ).with_columns(area=polars_h3.cell_area(pl.col("h3_cell"), test_params["unit"])) + if test_params["output"] is None: assert df["area"][0] is None else: - assert pytest.approx(df["area"][0], rel=1e-9) == expected_area + assert pytest.approx(df["area"][0], rel=1e-9) == test_params["output"] @pytest.mark.parametrize( - "resolution, unit, expected_length", + "test_params", [ - pytest.param(0, "km", 1107.712591, id="res0_km"), - pytest.param(1, "km", 418.6760055, id="res1_km"), - pytest.param(0, "m", 1107712.591, id="res0_m"), - pytest.param(1, "m", 418676.0, id="res1_m"), - # pytest.param(-1, "km", None, id="invalid_res"), + pytest.param( + { + "input": 0, + "unit": "km", + "output": 1107.712591, + }, + id="res0_km", + ), + pytest.param( + { + "input": 1, + "unit": "km", + "output": 418.6760055, + }, + id="res1_km", + ), + pytest.param( + { + "input": 0, + "unit": "m", + "output": 1107712.591, + }, + id="res0_m", + ), + pytest.param( + { + "input": 1, + "unit": "m", + "output": 418676.0, + }, + id="res1_m", + ), + # pytest.param( + # { + # "input": -1, + # "unit": "km", + # "output": None, + # }, + # id="invalid_res", + # ), ], ) -def test_average_hexagon_edge_length( - resolution: int, unit: str, expected_length: Union[float, None] -): - df = pl.DataFrame({"resolution": [resolution]}).with_columns( - length=polars_h3.average_hexagon_edge_length(pl.col("resolution"), unit) +def test_average_hexagon_edge_length(test_params): + df = pl.DataFrame({"resolution": [test_params["input"]]}).with_columns( + length=polars_h3.average_hexagon_edge_length( + pl.col("resolution"), test_params["unit"] + ) ) - if expected_length is None: + if test_params["output"] is None: assert df["length"][0] is None else: - assert pytest.approx(df["length"][0], rel=1e-3) == expected_length + assert pytest.approx(df["length"][0], rel=1e-3) == test_params["output"] # @pytest.mark.parametrize( diff --git a/tests/test_traversal.py b/tests/test_traversal.py index fb958af..6dce805 100644 --- a/tests/test_traversal.py +++ b/tests/test_traversal.py @@ -1,66 +1,78 @@ -import pytest import polars as pl +import pytest + import polars_h3 -from typing import Union, List, Dict @pytest.mark.parametrize( - "h3_cell, schema", + "test_params", [ - ( - [622054503267303423], - None, + pytest.param( + { + "input": 622054503267303423, + "schema": None, + "output_disk_radius_0": [622054503267303423], + "output_disk_radius_1": [ + 622054502770606079, + 622054502770835455, + 622054502770900991, + 622054503267205119, + 622054503267237887, + 622054503267270655, + 622054503267303423, + ], + }, + id="int_no_schema", ), - ( - [622054503267303423], - {"h3_cell": pl.UInt64}, + pytest.param( + { + "input": 622054503267303423, + "schema": {"input": pl.UInt64}, + "output_disk_radius_0": [622054503267303423], + "output_disk_radius_1": [ + 622054502770606079, + 622054502770835455, + 622054502770900991, + 622054503267205119, + 622054503267237887, + 622054503267270655, + 622054503267303423, + ], + }, + id="uint64_with_schema", ), - ( - ["8a1fb46622dffff"], - None, + pytest.param( + { + "input": "8a1fb46622dffff", + "schema": None, + "output_disk_radius_0": ["8a1fb46622dffff"], + "output_disk_radius_1": [ + "8a1fb464492ffff", + "8a1fb4644967fff", + "8a1fb4644977fff", + "8a1fb46622c7fff", + "8a1fb46622cffff", + "8a1fb46622d7fff", + "8a1fb46622dffff", + ], + }, + id="string_input", ), ], ) -def test_grid_disk( - h3_cell: List[Union[int, str]], schema: Union[Dict[str, pl.DataType], None] -): - df = pl.DataFrame({"h3_cell": h3_cell}, schema=schema).with_columns( - polars_h3.grid_disk("h3_cell", 0).list.sort().alias("disk_radius_0"), - polars_h3.grid_disk("h3_cell", 1).list.sort().alias("disk_radius_1"), - polars_h3.grid_disk("h3_cell", 2).list.sort().alias("disk_radius_2"), +def test_grid_disk(test_params): + df = pl.DataFrame( + {"input": [test_params["input"]]}, schema=test_params["schema"] + ).with_columns( + polars_h3.grid_disk("input", 0).list.sort().alias("disk_radius_0"), + polars_h3.grid_disk("input", 1).list.sort().alias("disk_radius_1"), + polars_h3.grid_disk("input", 2).list.sort().alias("disk_radius_2"), ) - assert df["disk_radius_0"].to_list()[0] == [622054503267303423] - assert df["disk_radius_1"].to_list()[0] == [ - 622054502770606079, - 622054502770835455, - 622054502770900991, - 622054503267205119, - 622054503267237887, - 622054503267270655, - 622054503267303423, - ] - assert df["disk_radius_2"].to_list()[0] == [ - 622054502770442239, - 622054502770475007, - 622054502770573311, - 622054502770606079, - 622054502770704383, - 622054502770769919, - 622054502770835455, - 622054502770868223, - 622054502770900991, - 622054503266975743, - 622054503267205119, - 622054503267237887, - 622054503267270655, - 622054503267303423, - 622054503267336191, - 622054503267368959, - 622054503267401727, - 622054503286931455, - 622054503287062527, - ] + assert df["disk_radius_0"].to_list()[0] == test_params["output_disk_radius_0"] + assert df["disk_radius_1"].to_list()[0] == test_params["output_disk_radius_1"] + + assert len(df["disk_radius_2"].to_list()[0]) == 19 def test_grid_disk_raises_invalid_k(): @@ -71,282 +83,413 @@ def test_grid_disk_raises_invalid_k(): @pytest.mark.parametrize( - "h3_cell_1, h3_cell_2, schema, expected_path", + "test_params", + [ + pytest.param( + { + "input": "8a1fb46622dffff", + "k": 0, + "schema": None, + "output": ["8a1fb46622dffff"], + }, + id="string_k0", + ), + pytest.param( + { + "input": "8a1fb46622dffff", + "k": 1, + "schema": None, + "output": [ + "8a1fb464492ffff", + "8a1fb4644967fff", + "8a1fb4644977fff", + "8a1fb46622c7fff", + "8a1fb46622cffff", + "8a1fb46622d7fff", + ], + }, + id="string_k1", + ), + pytest.param( + { + "input": 622054503267303423, + "k": 0, + "schema": {"input": pl.UInt64}, + "output": [622054503267303423], + }, + id="uint64_k1", + ), + pytest.param( + { + "input": 622054503267303423, + "k": 1, + "schema": {"input": pl.UInt64}, + "output": [ + 622054502770606079, + 622054502770835455, + 622054502770900991, + 622054503267205119, + 622054503267237887, + 622054503267270655, + ], + }, + id="uint64_k2", + ), + ], +) +def test_grid_ring(test_params): + df = pl.DataFrame( + {"input": [test_params["input"]]}, schema=test_params["schema"] + ).with_columns( + polars_h3.grid_ring("input", test_params["k"]).list.sort().alias("ring") + ) + + assert df["ring"].to_list()[0] == sorted(test_params["output"]) + + +@pytest.mark.parametrize( + "test_params", [ pytest.param( - [605035864166236159], - [605035864166236159], - {"h3_cell_1": pl.UInt64, "h3_cell_2": pl.UInt64}, - [ - 605035864166236159, - ], + { + "input_1": 605035864166236159, + "input_2": 605035864166236159, + "schema": {"input_1": pl.UInt64, "input_2": pl.UInt64}, + "output": [605035864166236159], + }, id="single_path", ), pytest.param( - [605035864166236159], - [605034941150920703], - {"h3_cell_1": pl.UInt64, "h3_cell_2": pl.UInt64}, - [ - 605035864166236159, - 605035861750317055, - 605035861347663871, - 605035862018752511, - 605034941419356159, - 605034941150920703, - ], + { + "input_1": 605035864166236159, + "input_2": 605034941150920703, + "schema": {"input_1": pl.UInt64, "input_2": pl.UInt64}, + "output": [ + 605035864166236159, + 605035861750317055, + 605035861347663871, + 605035862018752511, + 605034941419356159, + 605034941150920703, + ], + }, id="valid_path_uint64", ), pytest.param( - [605035864166236159], - [605034941150920703], - {"h3_cell_1": pl.Int64, "h3_cell_2": pl.Int64}, - [ - 605035864166236159, - 605035861750317055, - 605035861347663871, - 605035862018752511, - 605034941419356159, - 605034941150920703, - ], + { + "input_1": 605035864166236159, + "input_2": 605034941150920703, + "schema": {"input_1": pl.Int64, "input_2": pl.Int64}, + "output": [ + 605035864166236159, + 605035861750317055, + 605035861347663871, + 605035862018752511, + 605034941419356159, + 605034941150920703, + ], + }, id="valid_path_int64", ), pytest.param( - ["86584e9afffffff"], - ["8658412c7ffffff"], - None, - [ - 605035864166236159, - 605035861750317055, - 605035861347663871, - 605035862018752511, - 605034941419356159, - 605034941150920703, - ], + { + "input_1": "86584e9afffffff", + "input_2": "8658412c7ffffff", + "schema": None, + "output": [ + "86584e9afffffff", + "86584e91fffffff", + "86584e907ffffff", + "86584e92fffffff", + "8658412d7ffffff", + "8658412c7ffffff", + ], + }, id="valid_path_string", ), pytest.param( - [605035864166236159], - [0], - {"h3_cell_1": pl.UInt64, "h3_cell_2": pl.UInt64}, - None, + { + "input_1": 605035864166236159, + "input_2": 0, + "schema": {"input_1": pl.UInt64, "input_2": pl.UInt64}, + "output": None, + }, id="invalid_path_uint64_to_zero", ), pytest.param( - [605035864166236159], - [0], - {"h3_cell_1": pl.Int64, "h3_cell_2": pl.Int64}, - None, + { + "input_1": 605035864166236159, + "input_2": 0, + "schema": {"input_1": pl.Int64, "input_2": pl.Int64}, + "output": None, + }, id="invalid_path_int64_to_zero", ), pytest.param( - ["86584e9afffffff"], - ["0"], - None, - None, + { + "input_1": "86584e9afffffff", + "input_2": "0", + "schema": None, + "output": None, + }, id="invalid_path_string_to_zero", ), pytest.param( - ["0"], - ["86584e9afffffff"], - None, - None, + { + "input_1": "0", + "input_2": "86584e9afffffff", + "schema": None, + "output": None, + }, id="invalid_path_zero_to_string", ), ], ) -def test_grid_path_cells( - h3_cell_1: List[Union[int, str]], - h3_cell_2: List[Union[int, str]], - schema: Union[Dict[str, pl.DataType], None], - expected_path: List[Union[int, str, None]], -): +def test_grid_path_cells(test_params): df = pl.DataFrame( { - "h3_cell_1": h3_cell_1, - "h3_cell_2": h3_cell_2, + "input_1": [test_params["input_1"]], + "input_2": [test_params["input_2"]], }, - schema=schema, + schema=test_params["schema"], ).with_columns( - polars_h3.grid_path_cells("h3_cell_1", "h3_cell_2").list.sort().alias("path") + polars_h3.grid_path_cells("input_1", "input_2").list.sort().alias("path") ) - sorted_expected_path = sorted(expected_path) if expected_path else None - assert df["path"].to_list()[0] == sorted_expected_path - - -def test_grid_distance(): - # string - df = pl.DataFrame( - {"h3_cell_1": ["86584e9afffffff"], "h3_cell_2": ["8658412c7ffffff"]} - ).with_columns(polars_h3.grid_distance("h3_cell_1", "h3_cell_2").alias("distance")) - assert df["distance"].to_list()[0] == 5 - - # unsigned - df = pl.DataFrame( - { - "h3_cell_1": [605035864166236159], - "h3_cell_2": [605034941150920703], - }, - schema={"h3_cell_1": pl.UInt64, "h3_cell_2": pl.UInt64}, - ).with_columns(polars_h3.grid_distance("h3_cell_1", "h3_cell_2").alias("distance")) - assert df["distance"].to_list()[0] == 5 - - # signed - df = pl.DataFrame( - { - "h3_cell_1": [605035864166236159], - "h3_cell_2": [605034941150920703], - }, - schema={"h3_cell_1": pl.Int64, "h3_cell_2": pl.Int64}, - ).with_columns(polars_h3.grid_distance("h3_cell_1", "h3_cell_2").alias("distance")) - assert df["distance"].to_list()[0] == 5 - - # signed to 0 - df = pl.DataFrame( - { - "h3_cell_1": [605035864166236159], - "h3_cell_2": [0], - }, - schema={"h3_cell_1": pl.Int64, "h3_cell_2": pl.Int64}, - ).with_columns(polars_h3.grid_distance("h3_cell_1", "h3_cell_2").alias("distance")) - assert df["distance"].to_list()[0] is None - - # unsigned to 0 - df = pl.DataFrame( - { - "h3_cell_1": [605035864166236159], - "h3_cell_2": [0], - }, - schema={"h3_cell_1": pl.UInt64, "h3_cell_2": pl.UInt64}, - ).with_columns(polars_h3.grid_distance("h3_cell_1", "h3_cell_2").alias("distance")) - assert df["distance"].to_list()[0] is None - - # utf8 - df = pl.DataFrame( - { - "h3_cell_1": ["86584e9afffffff"], - "h3_cell_2": ["0"], - }, - ).with_columns(polars_h3.grid_distance("h3_cell_1", "h3_cell_2").alias("distance")) - assert df["distance"].to_list()[0] is None + sorted_output = sorted(test_params["output"]) if test_params["output"] else None + assert df["path"].to_list()[0] == sorted_output - # different resolutions +@pytest.mark.parametrize( + "test_params", + [ + pytest.param( + { + "input_1": "86584e9afffffff", + "input_2": "8658412c7ffffff", + "schema": None, + "output": 5, + }, + id="string_valid_distance", + ), + pytest.param( + { + "input_1": 605035864166236159, + "input_2": 605034941150920703, + "schema": {"input_1": pl.UInt64, "input_2": pl.UInt64}, + "output": 5, + }, + id="uint64_valid_distance", + ), + pytest.param( + { + "input_1": 605035864166236159, + "input_2": 605034941150920703, + "schema": {"input_1": pl.Int64, "input_2": pl.Int64}, + "output": 5, + }, + id="int64_valid_distance", + ), + pytest.param( + { + "input_1": 605035864166236159, + "input_2": 0, + "schema": {"input_1": pl.Int64, "input_2": pl.Int64}, + "output": None, + }, + id="int64_to_zero", + ), + pytest.param( + { + "input_1": 605035864166236159, + "input_2": 0, + "schema": {"input_1": pl.UInt64, "input_2": pl.UInt64}, + "output": None, + }, + id="uint64_to_zero", + ), + pytest.param( + { + "input_1": "86584e9afffffff", + "input_2": "0", + "schema": None, + "output": None, + }, + id="string_to_zero", + ), + pytest.param( + { + "input_1": "872a10006ffffff", + "input_2": "862a108dfffffff", + "schema": None, + "output": None, + }, + id="different_resolutions", + ), + ], +) +def test_grid_distance(test_params): df = pl.DataFrame( { - "h3_cell_1": ["872a10006ffffff"], - "h3_cell_2": ["862a108dfffffff"], + "input_1": [test_params["input_1"]], + "input_2": [test_params["input_2"]], }, - ).with_columns(polars_h3.grid_distance("h3_cell_1", "h3_cell_2").alias("distance")) - assert df["distance"].to_list()[0] is None + schema=test_params["schema"], + ).with_columns(polars_h3.grid_distance("input_1", "input_2").alias("distance")) + assert df["distance"].to_list()[0] == test_params["output"] @pytest.mark.parametrize( - "origin, dest, schema, expected_coords", + "test_params", [ pytest.param( - [605034941285138431], - [605034941285138431], - {"origin": pl.UInt64, "dest": pl.UInt64}, - [-123, -177], + { + "origin": 605034941285138431, + "dest": 605034941285138431, + "schema": {"origin": pl.UInt64, "dest": pl.UInt64}, + "output": [-123, -177], + }, id="uint64_same_cell", ), pytest.param( - [605034941285138431], - [605034941285138431], - {"origin": pl.Int64, "dest": pl.Int64}, - [-123, -177], + { + "origin": 605034941285138431, + "dest": 605034941285138431, + "schema": {"origin": pl.Int64, "dest": pl.Int64}, + "output": [-123, -177], + }, id="int64_same_cell", ), pytest.param( - ["8658412cfffffff"], - ["8658412cfffffff"], - None, - [-123, -177], + { + "origin": "8658412cfffffff", + "dest": "8658412cfffffff", + "schema": None, + "output": [-123, -177], + }, id="string_same_cell", ), pytest.param( - [605034941285138431], - [0], - {"origin": pl.UInt64, "dest": pl.UInt64}, - None, + { + "origin": 605034941285138431, + "dest": 0, + "schema": {"origin": pl.UInt64, "dest": pl.UInt64}, + "output": None, + }, id="uint64_to_zero", ), pytest.param( - [605034941285138431], - [0], - {"origin": pl.Int64, "dest": pl.Int64}, - None, + { + "origin": 605034941285138431, + "dest": 0, + "schema": {"origin": pl.Int64, "dest": pl.Int64}, + "output": None, + }, id="int64_to_zero", ), - pytest.param(["8658412cfffffff"], ["0"], None, None, id="string_to_zero"), - pytest.param(["8658412cfffffff"], ["abc"], None, None, id="string_to_invalid"), + pytest.param( + { + "origin": "8658412cfffffff", + "dest": "0", + "schema": None, + "output": None, + }, + id="string_to_zero", + ), + pytest.param( + { + "origin": "8658412cfffffff", + "dest": "abc", + "schema": None, + "output": None, + }, + id="string_to_invalid", + ), ], ) -def test_cell_to_local_ij( - origin: List[Union[int, str]], - dest: List[Union[int, str]], - schema: Union[Dict[str, pl.DataType], None], - expected_coords: Union[List[int], None], -): +def test_cell_to_local_ij(test_params): df = pl.DataFrame( - {"origin": origin, "dest": dest}, - schema=schema, + { + "origin": [test_params["origin"]], + "dest": [test_params["dest"]], + }, + schema=test_params["schema"], ).with_columns(coords=polars_h3.cell_to_local_ij("origin", "dest")) - assert df["coords"].to_list()[0] == expected_coords + + assert df["coords"].to_list()[0] == test_params["output"] @pytest.mark.parametrize( - "origin, i, j, schema, expected_cell", + "test_params", [ pytest.param( - [605034941285138431], - -123, - -177, - {"origin": pl.UInt64}, - 605034941285138431, + { + "input": 605034941285138431, + "i": -123, + "j": -177, + "schema": {"origin": pl.UInt64}, + "output": 605034941285138431, + }, id="uint64_valid", ), pytest.param( - [605034941285138431], - -123, - -177, - {"origin": pl.Int64}, - 605034941285138431, + { + "input": 605034941285138431, + "i": -123, + "j": -177, + "schema": {"origin": pl.Int64}, + "output": 605034941285138431, + }, id="int64_valid", ), pytest.param( - ["8658412cfffffff"], -123, -177, None, 605034941285138431, id="string_valid" + { + "input": "8658412cfffffff", + "i": -123, + "j": -177, + "schema": None, + "output": 605034941285138431, + }, + id="string_valid", ), pytest.param( - [605034941285138431], - -1230000, - -177, - {"origin": pl.UInt64}, - None, + { + "input": 605034941285138431, + "i": -1230000, + "j": -177, + "schema": {"origin": pl.UInt64}, + "output": None, + }, id="uint64_invalid_coords", ), pytest.param( - [605034941285138431], - -1230000, - -177, - {"origin": pl.Int64}, - None, + { + "input": 605034941285138431, + "i": -1230000, + "j": -177, + "schema": {"origin": pl.Int64}, + "output": None, + }, id="int64_invalid_coords", ), pytest.param( - ["8658412cfffffff"], -1230000, -177, None, None, id="string_invalid_coords" + { + "input": "8658412cfffffff", + "i": -1230000, + "j": -177, + "schema": None, + "output": None, + }, + id="string_invalid_coords", ), ], ) -def test_local_ij_to_cell( - origin: List[Union[int, str]], - i: int, - j: int, - schema: Union[Dict[str, pl.DataType], None], - expected_cell: Union[int, str, None], -): - df = pl.DataFrame({"origin": origin}, schema=schema).with_columns( - cell=polars_h3.local_ij_to_cell("origin", i, j) +def test_local_ij_to_cell(test_params): + df = pl.DataFrame( + {"origin": [test_params["input"]]}, + schema=test_params["schema"], + ).with_columns( + cell=polars_h3.local_ij_to_cell("origin", test_params["i"], test_params["j"]) ) - assert df["cell"].to_list()[0] == expected_cell + + assert df["cell"].to_list()[0] == test_params["output"] diff --git a/tests/test_vertexes.py b/tests/test_vertexes.py index e3f9cfc..fdf63ae 100644 --- a/tests/test_vertexes.py +++ b/tests/test_vertexes.py @@ -1,7 +1,9 @@ -import pytest +from typing import Union + import polars as pl +import pytest + import polars_h3 -from typing import List, Union, Dict @pytest.mark.parametrize( @@ -16,8 +18,8 @@ ], ) def test_is_valid_vertex( - vertex: List[Union[int, str]], - schema: Union[Dict[str, pl.DataType], None], + vertex: list[Union[int, str]], + schema: Union[dict[str, pl.DataType], None], expected_valid: bool, ): df = pl.DataFrame({"vertex": vertex}, schema=schema).with_columns( @@ -63,9 +65,9 @@ def test_is_valid_vertex( ], ) def test_cell_to_vertexes( - h3_cell: List[Union[int, str]], - schema: Union[Dict[str, pl.DataType], None], - expected_vertexes: Union[List[int], None], + h3_cell: list[Union[int, str]], + schema: Union[dict[str, pl.DataType], None], + expected_vertexes: Union[list[int], None], ): df = pl.DataFrame({"h3_cell": h3_cell}, schema=schema).with_columns( vertexes=polars_h3.cell_to_vertexes("h3_cell")