Skip to content

Commit

Permalink
add rep_time support
Browse files Browse the repository at this point in the history
  • Loading branch information
lixiliu committed Nov 14, 2024
1 parent 6def3b2 commit aaf0833
Show file tree
Hide file tree
Showing 12 changed files with 459 additions and 256 deletions.
23 changes: 0 additions & 23 deletions src/chronify/duckdb/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,26 +57,3 @@ def unpivot(
)
"""
return duckdb.sql(query)


def join(
left_rel: DuckDBPyRelation,
right_rel: DuckDBPyRelation,
on: list[str],
how: str = "inner",
) -> DuckDBPyRelation:
def get_join_statement(left_df, right_df, keys: list):
stmts = [f"{left_df.alias}.{key}={right_df.alias}.{key}" for key in keys]
return " and ".join(stmts)

def get_select_after_join_statement(left_df, right_df, keys: list):
left_cols = [f"{left_df.alias}.{x}" for x in left_df.columns]
right_cols = [x for x in right_df.columns if x not in keys]
return ", ".join(left_cols + right_cols)

join_stmt = get_join_statement(left_rel, right_rel, on)
select_stmt = get_select_after_join_statement(left_rel, right_rel, on)
query = f"SELECT {select_stmt} from {left_rel.alias} {how.upper()} JOIN {right_rel.alias} ON {join_stmt}"
breakpoint()
# return left_rel.join(right_rel, join_stmt).select(select_stmt)
return duckdb.sql(query)
7 changes: 5 additions & 2 deletions src/chronify/sqlalchemy/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def read_database(query: Selectable | str, conn: Connection, schema: TableSchema
df[config.time_column] = pd.to_datetime(df[config.time_column], utc=True)
else:
df[config.time_column] = df[config.time_column].dt.tz_localize("UTC")
elif conn.engine.name == "sqlite" and isinstance(config, DatetimeRange):
if isinstance(df[config.time_column].dtype, ObjectDType):
df[config.time_column] = pd.to_datetime(df[config.time_column], utc=False)
return df


Expand All @@ -26,8 +29,8 @@ def write_database(df: pd.DataFrame, conn: Connection, schema: TableSchema) -> N
config = schema.time_config
if config.needs_utc_conversion(conn.engine.name):
assert isinstance(config, DatetimeRange)
if isinstance(df.timestamp.dtype, DatetimeTZDtype):
if isinstance(df[config.time_column].dtype, DatetimeTZDtype):
df[config.time_column] = df[config.time_column].dt.tz_convert("UTC")
else:
df[config.time_column] = df[config.time_column].dt.tz_localize("UTC")
pl.DataFrame(df).write_database(schema.name, connection=conn, if_table_exists="append")
pl.DataFrame(df).write_database(schema.name, connection=conn, if_table_exists="replace")
25 changes: 23 additions & 2 deletions src/chronify/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,27 @@ class RepresentativePeriodFormat(StrEnum):
)


representative_weekday_column = {
RepresentativePeriodFormat.ONE_WEEK_PER_MONTH_BY_HOUR: "day_of_week",
RepresentativePeriodFormat.ONE_WEEKDAY_DAY_AND_ONE_WEEKEND_DAY_PER_MONTH_BY_HOUR: "is_weekday",
}

representative_period_columns = {
RepresentativePeriodFormat.ONE_WEEK_PER_MONTH_BY_HOUR: [
"month",
representative_weekday_column[RepresentativePeriodFormat.ONE_WEEK_PER_MONTH_BY_HOUR],
"hour",
],
RepresentativePeriodFormat.ONE_WEEKDAY_DAY_AND_ONE_WEEKEND_DAY_PER_MONTH_BY_HOUR: [
"month",
representative_weekday_column[
RepresentativePeriodFormat.ONE_WEEKDAY_DAY_AND_ONE_WEEKEND_DAY_PER_MONTH_BY_HOUR
],
"hour",
],
}


class LeapDayAdjustmentType(StrEnum):
"""Leap day adjustment enum types"""

Expand Down Expand Up @@ -149,7 +170,7 @@ def get_time_zone_offset(tz: TimeZone) -> str:
return offset


def get_standard_timezone(tz: TimeZone) -> TimeZone:
def get_standard_time(tz: TimeZone) -> TimeZone:
"""Return the equivalent standard time zone."""
match tz:
case TimeZone.UTC:
Expand All @@ -173,7 +194,7 @@ def get_standard_timezone(tz: TimeZone) -> TimeZone:
raise NotImplementedError(msg)


def get_prevailing_timezone(tz: TimeZone) -> TimeZone:
def get_prevailing_time(tz: TimeZone) -> TimeZone:
"""Return the equivalent prevailing time zone."""
match tz:
case TimeZone.UTC:
Expand Down
108 changes: 86 additions & 22 deletions src/chronify/time_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@
import pandas as pd
from pydantic import (
Field,
field_validator,
)
from typing_extensions import Annotated

from chronify.base_models import ChronifyBaseModel
from chronify.exceptions import InvalidParameter
from chronify.time import (
DatetimeFormat,
DaylightSavingFallBackType,
Expand All @@ -23,6 +21,8 @@
TimeIntervalType,
TimeType,
TimeZone,
RepresentativePeriodFormat,
representative_period_columns,
)

# from chronify.time_utils import (
Expand Down Expand Up @@ -131,8 +131,6 @@ class TimeBasedDataAdjustment(ChronifyBaseModel):
class TimeBaseModel(ChronifyBaseModel, abc.ABC):
"""Defines a base model common to all time dimensions."""

length: int

def list_timestamps(self) -> list[Any]:
"""Return a list of timestamps for a time range.
Type of the timestamps depends on the class.
Expand Down Expand Up @@ -177,6 +175,7 @@ class DatetimeRange(TimeBaseModel):
description="Start time of the range. If it includes a time zone, the timestamps in "
"the data must also include time zones."
)
length: int
resolution: timedelta
time_based_data_adjustment: TimeBasedDataAdjustment = TimeBasedDataAdjustment()
interval_type: TimeIntervalType = TimeIntervalType.PERIOD_ENDING
Expand Down Expand Up @@ -235,6 +234,8 @@ class AnnualTimeRange(TimeBaseModel):
time_column: str = Field(description="Column in the table that represents time.")
time_type: Literal[TimeType.ANNUAL] = TimeType.ANNUAL
start: int
length: int
measurement_type: MeasurementType = MeasurementType.TOTAL
# TODO: measurement_type must be TOTAL

def iter_timestamps(self) -> Generator[int, None, None]:
Expand All @@ -248,11 +249,12 @@ def list_time_columns(self) -> list[str]:
class IndexTimeRange(TimeBaseModel):
time_type: Literal[TimeType.INDEX] = TimeType.INDEX
start: int
length: int
resolution: timedelta
time_zone: TimeZone
time_based_data_adjustment: TimeBasedDataAdjustment
interval_type: TimeIntervalType
measurement_type: MeasurementType
interval_type: TimeIntervalType = TimeIntervalType.PERIOD_ENDING
measurement_type: MeasurementType = MeasurementType.TOTAL

# TODO DT: totally wrong
# def iter_timestamps(self) -> Generator[datetime, None, None]:
Expand Down Expand Up @@ -293,25 +295,87 @@ class IndexTimeRange(TimeBaseModel):
class RepresentativePeriodTimeRange(TimeBaseModel):
"""Defines a representative time dimension."""

time_columns: list[str] = Field(description="Columns in the table that represent time.")
time_type: Literal[TimeType.REPRESENTATIVE_PERIOD] = TimeType.REPRESENTATIVE_PERIOD
measurement_type: MeasurementType
time_interval_type: TimeIntervalType
# TODO

@field_validator("time_columns")
@classmethod
def check_columns(cls, columns: list[str]) -> list[str]:
type_1_columns = {"month", "day_of_week", "hour"}
type_2_columns = {"month", "is_weekday", "hour"}
if set(columns) != type_1_columns:
if set(columns) != type_2_columns:
msg = f"Unsupported {columns} for RepresentativePeriodTimeRange, expecting either {type_1_columns} or {type_2_columns}"
raise InvalidParameter(msg)
return columns
time_format: RepresentativePeriodFormat
# time_columns: list[str] = Field(description="Columns in the table that represent time.")
measurement_type: MeasurementType = MeasurementType.TOTAL
interval_type: TimeIntervalType = TimeIntervalType.PERIOD_ENDING

# @model_validator(mode="after")
# def check_columns(self) -> "RepresentativePeriodTimeRange":
# expected = representative_period_columns[self.time_format]

# if set(self.time_columns) != set(expected):
# msg = f"Incorrect {self.time_columns=} for {self.time_format=}, {expected=}"
# raise InvalidParameter(msg)
# return self

def list_time_columns(self) -> list[str]:
return self.time_columns
match self.time_format:
case RepresentativePeriodFormat.ONE_WEEK_PER_MONTH_BY_HOUR:
return OneWeekPerMonthByHourHandler().list_time_columns()
case RepresentativePeriodFormat.ONE_WEEKDAY_DAY_AND_ONE_WEEKEND_DAY_PER_MONTH_BY_HOUR:
return OneWeekdayDayAndWeekendDayPerMonthByHourHandler().list_time_columns()

def iter_timestamps(self) -> Generator[int, None, None]:
match self.time_format:
case RepresentativePeriodFormat.ONE_WEEK_PER_MONTH_BY_HOUR:
return OneWeekPerMonthByHourHandler().iter_timestamps()
case RepresentativePeriodFormat.ONE_WEEKDAY_DAY_AND_ONE_WEEKEND_DAY_PER_MONTH_BY_HOUR:
return OneWeekdayDayAndWeekendDayPerMonthByHourHandler().iter_timestamps()

def list_timestamps_from_dataframe(self, df: pd.DataFrame) -> list[Any]:
return df[self.list_time_columns()].drop_duplicates().apply(tuple, axis=1).to_list()


class RepresentativeTimeFormatHandlerBase(abc.ABC):
"""Provides implementations for different representative time formats."""

@staticmethod
@abc.abstractmethod
def list_time_columns() -> list[str]:
"""Return the columns in the table that represent time."""

@staticmethod
@abc.abstractmethod
def iter_timestamps() -> Generator[Any, None, None]:
"""Return an iterator over all time indexes in the table.
Type of the time is dependent on the class.
"""


class OneWeekPerMonthByHourHandler(RepresentativeTimeFormatHandlerBase):
"""Handler for format with hourly data that includes one week per month."""

@staticmethod
def list_time_columns() -> list[str]:
return representative_period_columns[RepresentativePeriodFormat.ONE_WEEK_PER_MONTH_BY_HOUR]

@staticmethod
def iter_timestamps() -> Generator[Any, None, None]:
for month in range(1, 13):
for dow in range(7):
for hour in range(24):
yield (month, dow, hour)


class OneWeekdayDayAndWeekendDayPerMonthByHourHandler(RepresentativeTimeFormatHandlerBase):
"""Handler for format with hourly data that includes one weekday day and one weekend day
per month.
"""

@staticmethod
def list_time_columns() -> list[str]:
return representative_period_columns[
RepresentativePeriodFormat.ONE_WEEKDAY_DAY_AND_ONE_WEEKEND_DAY_PER_MONTH_BY_HOUR
]

@staticmethod
def iter_timestamps() -> Generator[Any, None, None]:
for month in range(1, 13):
for is_weekday in sorted([False, True]):
for hour in range(24):
yield (month, is_weekday, hour)


TimeConfig = Annotated[
Expand Down
5 changes: 3 additions & 2 deletions src/chronify/time_series_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def _run_timestamp_checks_on_tmp_table(self, table_name: str) -> None:
result3 = self._conn.execute(text(query3)).fetchone()
assert result3 is not None
actual_count = result3[0]
if actual_count != self._schema.time_config.length:
msg = f"Time arrays must have length={self._schema.time_config.length}. Actual = {actual_count}"
expected_count = len(schema.time_config.list_timestamps())
if actual_count != expected_count:
msg = f"Time arrays must have length={expected_count}. Actual = {actual_count}"
raise InvalidTable(msg)
Loading

0 comments on commit aaf0833

Please sign in to comment.