applies Black linter
Signed-off-by: Timm638 <[email protected]>
Timm638 committed Dec 17, 2024
1 parent b960d1d commit 1019bfa
Showing 4 changed files with 239 additions and 111 deletions.
23 changes: 18 additions & 5 deletions src/sdk/python/rtdip_sdk/pipelines/_pipeline_utils/spark.py
@@ -118,7 +118,10 @@ def get_dbutils(
# def onQueryTerminated(self, event):
# logging.info("Query terminated: {} {}".format(event.id, event.name))

def is_dataframe_partially_conformed_in_schema(dataframe: DataFrame, schema: StructType, throw_error: bool = True) -> bool:

def is_dataframe_partially_conformed_in_schema(
dataframe: DataFrame, schema: StructType, throw_error: bool = True
) -> bool:
"""
Returns True if all columns in the dataframe are contained in the schema with the appropriate type
"""
@@ -132,18 +135,26 @@ def is_dataframe_partially_conformed_in_schema(dataframe: DataFrame, schema: Str
if not throw_error:
return False
else:
raise ValueError("Column {0} is of Type {1}, expected Type {2}".format(column, column.dataType, schema_field.dataType))
raise ValueError(
"Column {0} is of Type {1}, expected Type {2}".format(
column, column.dataType, schema_field.dataType
)
)

else:
# dataframe contains a column not expected in the schema
if not throw_error:
return False
else:
raise ValueError("Column {0} is not expected in dataframe".format(column))
raise ValueError(
"Column {0} is not expected in dataframe".format(column)
)
return True


def conform_dataframe_to_schema(dataframe: DataFrame, schema: StructType, throw_error: bool = True) -> DataFrame:
def conform_dataframe_to_schema(
dataframe: DataFrame, schema: StructType, throw_error: bool = True
) -> DataFrame:
"""
Tries to convert all columns to the given schema
Raises Exception on non-conforming dataframe
@@ -155,7 +166,9 @@ def conform_dataframe_to_schema(dataframe: DataFrame, schema: StructType, throw_
if isinstance(column.dataType, type(schema_field.dataType)):
# column is of right type, skip it
continue
dataframe = dataframe.withColumn(c_name, dataframe[c_name].cast(schema_field.dataType))
dataframe = dataframe.withColumn(
c_name, dataframe[c_name].cast(schema_field.dataType)
)
else:
raise ValueError("Column {0} is not expected in dataframe".format(column))
return dataframe
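As a quick orientation for the two helpers changed above, a minimal usage sketch; the import path is inferred from the file header (src/sdk/python/rtdip_sdk/pipelines/_pipeline_utils/spark.py) and the data is illustrative, so treat this as an assumption rather than documented SDK usage:

from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType, StringType, StructField, StructType

# Assumed import path, derived from the file path shown above.
from rtdip_sdk.pipelines._pipeline_utils.spark import (
    conform_dataframe_to_schema,
    is_dataframe_partially_conformed_in_schema,
)

spark = SparkSession.builder.master("local[1]").getOrCreate()

target_schema = StructType(
    [
        StructField("TagName", StringType(), True),
        StructField("Value", DoubleType(), True),
    ]
)

# "Value" arrives as a string, so the dataframe does not conform to the target schema yet.
df = spark.createDataFrame([("Sensor_A", "1.5")], ["TagName", "Value"])
assert not is_dataframe_partially_conformed_in_schema(df, target_schema, throw_error=False)

# conform_dataframe_to_schema casts mismatched columns to the expected types.
conformed = conform_dataframe_to_schema(df, target_schema)
assert is_dataframe_partially_conformed_in_schema(conformed, target_schema)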
Second changed file (file path not captured in this view):
@@ -17,7 +17,12 @@

import pandas as pd
from pandas import DataFrame
from pyspark.sql import DataFrame as PySparkDataFrame, SparkSession, functions as F, DataFrame as SparkDataFrame
from pyspark.sql import (
DataFrame as PySparkDataFrame,
SparkSession,
functions as F,
DataFrame as SparkDataFrame,
)
from pyspark.sql.functions import col, lit
from pyspark.sql.types import StringType, StructField, StructType
from regex import regex
@@ -125,15 +130,16 @@ class InputStyle(Enum):
"""
Used to describe style of a dataframe
"""
COLUMN_BASED = 1 # Schema: [EventTime, FirstSource, SecondSource, ...]
SOURCE_BASED = 2 # Schema: [EventTime, NameSource, Value, OptionalStatus]

COLUMN_BASED = 1 # Schema: [EventTime, FirstSource, SecondSource, ...]
SOURCE_BASED = 2 # Schema: [EventTime, NameSource, Value, OptionalStatus]

def __init__(
self,
past_data: PySparkDataFrame,
to_extend_name: str, # either source or column
# Metadata about past_data
past_data_style : InputStyle = None,
past_data_style: InputStyle = None,
value_name: str = None,
timestamp_name: str = None,
source_name: str = None,
@@ -144,20 +150,29 @@ def __init__(
number_of_data_points_to_analyze: int = None,
order: tuple = (0, 0, 0),
seasonal_order: tuple = (0, 0, 0, 0),
trend= None,
trend=None,
enforce_stationarity: bool = True,
enforce_invertibility: bool = True,
concentrate_scale: bool = False,
trend_offset: int = 1,
missing: str = "None"
missing: str = "None",
) -> None:
self.past_data = past_data
# Convert dataframe to general column-based format for internal processing
self._initialize_self_df(past_data, past_data_style, source_name, status_name, timestamp_name, to_extend_name,
value_name)
self._initialize_self_df(
past_data,
past_data_style,
source_name,
status_name,
timestamp_name,
to_extend_name,
value_name,
)

if number_of_data_points_to_analyze > self.df.count():
raise ValueError("Number of data points to analyze exceeds the number of rows present")
raise ValueError(
"Number of data points to analyze exceeds the number of rows present"
)

self.spark_session = past_data.sparkSession
self.column_to_predict = to_extend_name
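For reference, the constructor above accepts history in either of the two layouts named by the InputStyle enum; a minimal sketch of both shapes, with column names that are illustrative and simply echo the schema comments on the enum members:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

# COLUMN_BASED: one column per source, one row per timestamp.
column_based = spark.createDataFrame(
    [("2024-01-01 00:00:00", 1.0, 10.0), ("2024-01-01 00:01:00", 2.0, 20.0)],
    ["EventTime", "FirstSource", "SecondSource"],
)

# SOURCE_BASED: one row per (timestamp, source) pair, with an optional status column.
source_based = spark.createDataFrame(
    [
        ("2024-01-01 00:00:00", "FirstSource", 1.0, "Good"),
        ("2024-01-01 00:00:00", "SecondSource", 10.0, "Good"),
    ],
    ["EventTime", "TagName", "Value", "Status"],
)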
@@ -199,22 +214,53 @@ def _is_column_type(df, column_name, data_type):

return isinstance(type_.dataType, data_type)

def _initialize_self_df(self, past_data, past_data_style, source_name, status_name, timestamp_name, to_extend_name,
value_name):
#Initialize self.df with meta parameters if not already done by previous constructor
def _initialize_self_df(
self,
past_data,
past_data_style,
source_name,
status_name,
timestamp_name,
to_extend_name,
value_name,
):
# Initialize self.df with meta parameters if not already done by previous constructor
if self.df is None:
self.past_data_style, self.value_name, self.timestamp_name, self.source_name, self.status_name = self._constructor_handle_input_metadata(
past_data, past_data_style, value_name, timestamp_name, source_name, status_name)
(
self.past_data_style,
self.value_name,
self.timestamp_name,
self.source_name,
self.status_name,
) = self._constructor_handle_input_metadata(
past_data,
past_data_style,
value_name,
timestamp_name,
source_name,
status_name,
)

if self.past_data_style == self.InputStyle.COLUMN_BASED:
self.df = past_data
elif self.past_data_style == self.InputStyle.SOURCE_BASED:
self.df = past_data.groupby(self.timestamp_name).pivot(self.source_name).agg(F.first(self.value_name))
self.df = (
past_data.groupby(self.timestamp_name)
.pivot(self.source_name)
.agg(F.first(self.value_name))
)
if not to_extend_name in self.df.columns:
raise ValueError("{} not found in the DataFrame.".format(value_name))


def _constructor_handle_input_metadata(self, past_data: PySparkDataFrame, past_data_style: InputStyle, value_name: str, timestamp_name: str, source_name:str, status_name: str) -> Tuple[InputStyle, str, str, str, str]:
def _constructor_handle_input_metadata(
self,
past_data: PySparkDataFrame,
past_data_style: InputStyle,
value_name: str,
timestamp_name: str,
source_name: str,
status_name: str,
) -> Tuple[InputStyle, str, str, str, str]:
# Infer names of columns from past_data schema. If nothing is found, leave self parameters at None.
if past_data_style is not None:
return past_data_style, value_name, timestamp_name, source_name, status_name
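The _initialize_self_df method above converts SOURCE_BASED input to the column-based layout with a PySpark pivot; a standalone sketch of that transformation on toy data (column names are illustrative):

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()

source_based = spark.createDataFrame(
    [
        ("2024-01-01 00:00:00", "FirstSource", 1.0),
        ("2024-01-01 00:00:00", "SecondSource", 10.0),
        ("2024-01-01 00:01:00", "FirstSource", 2.0),
    ],
    ["EventTime", "TagName", "Value"],
)

# One row per EventTime, one column per TagName; F.first picks the value for each pair.
column_based = source_based.groupby("EventTime").pivot("TagName").agg(F.first("Value"))
column_based.show()
# Missing (EventTime, TagName) combinations surface as nulls, e.g. SecondSource at 00:01:00.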
@@ -228,7 +274,9 @@ def _constructor_handle_input_metadata(self, past_data: PySparkDataFrame, past_d
source_name = None
status_name = None

def pickout_column(rem_columns: List[str], regex_string: str) -> (str, List[str]):
def pickout_column(
rem_columns: List[str], regex_string: str
) -> (str, List[str]):
rgx = regex.compile(regex_string)
sus_columns = list(filter(rgx.search, rem_columns))
found_column = sus_columns[0] if len(sus_columns) == 1 else None
@@ -241,7 +289,9 @@ def pickout_column(rem_columns: List[str], regex_string: str) -> (str, List[str]
# Is there a source name / tag
source_name, remaining_columns = pickout_column(schema_names, r"(?i)tag")
# Is there a timestamp column?
timestamp_name, remaining_columns = pickout_column(schema_names, r"(?i)time|index")
timestamp_name, remaining_columns = pickout_column(
schema_names, r"(?i)time|index"
)
# Is there a value column?
value_name, remaining_columns = pickout_column(schema_names, r"(?i)value")
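The pickout_column helper above infers column roles from name patterns; the same idea as a self-contained sketch, using the standard-library re module in place of the regex package (a sketch under that substitution, not the SDK's exact code):

import re
from typing import List, Optional, Tuple

def pickout_column(rem_columns: List[str], regex_string: str) -> Tuple[Optional[str], List[str]]:
    # Return the single column matching the pattern (None if zero or several match),
    # together with the columns that were not picked.
    rgx = re.compile(regex_string)
    matches = [c for c in rem_columns if rgx.search(c)]
    found = matches[0] if len(matches) == 1 else None
    remaining = [c for c in rem_columns if c != found]
    return found, remaining

columns = ["TagName", "EventTime", "Status", "Value"]
source_name, columns = pickout_column(columns, r"(?i)tag")             # -> "TagName"
timestamp_name, columns = pickout_column(columns, r"(?i)time|index")   # -> "EventTime"
value_name, columns = pickout_column(columns, r"(?i)value")            # -> "Value"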

@@ -250,10 +300,16 @@ def pickout_column(rem_columns: List[str], regex_string: str) -> (str, List[str]
else:
assumed_past_data_style = self.InputStyle.COLUMN_BASED

#if self.past_data_style is None:
# if self.past_data_style is None:
# raise ValueError(
# "Automatic determination of past_data_style failed, must be specified in parameter instead.")
return assumed_past_data_style, value_name, timestamp_name, source_name, status_name
return (
assumed_past_data_style,
value_name,
timestamp_name,
source_name,
status_name,
)

def filter(self) -> PySparkDataFrame:
"""
@@ -272,18 +328,25 @@ def filter(self) -> PySparkDataFrame:
# StructField("Status", StringType(), True),
# StructField("Value", NumericType(), True),
# ]
#)
# )
pd_df = self.df.toPandas()
pd_df.loc[:, self.timestamp_name] = pd.to_datetime(pd_df[self.timestamp_name], format="mixed").astype(
"datetime64[ns]")
pd_df.loc[:, self.column_to_predict] = pd_df.loc[:, self.column_to_predict].astype(float)
pd_df.loc[:, self.timestamp_name] = pd.to_datetime(
pd_df[self.timestamp_name], format="mixed"
).astype("datetime64[ns]")
pd_df.loc[:, self.column_to_predict] = pd_df.loc[
:, self.column_to_predict
].astype(float)
pd_df.sort_values(self.timestamp_name, inplace=True)
pd_df.reset_index(drop=True, inplace=True)
# self.validate(expected_scheme)

# limit df to specific data points
pd_to_train_on = pd_df[pd_df[self.column_to_predict].notna()].tail(self.rows_to_analyze)
pd_to_predict_on = pd_df[pd_df[self.column_to_predict].isna()].head(self.rows_to_predict)
pd_to_train_on = pd_df[pd_df[self.column_to_predict].notna()].tail(
self.rows_to_analyze
)
pd_to_predict_on = pd_df[pd_df[self.column_to_predict].isna()].head(
self.rows_to_predict
)
pd_df = pd.concat([pd_to_train_on, pd_to_predict_on])

main_signal_df = pd_df[pd_df[self.column_to_predict].notna()]
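The hunk above restricts the pandas frame to the most recent non-null rows for training and the leading null rows to fill; a small sketch of that selection (column names and sizes are illustrative):

import numpy as np
import pandas as pd

df = pd.DataFrame(
    {
        "EventTime": pd.date_range("2024-01-01", periods=10, freq="1h"),
        "FirstSource": [1.0] * 7 + [np.nan] * 3,
    }
)

rows_to_analyze, rows_to_predict = 5, 3

# Train on the latest rows that have values; predict onto the first rows that are missing them.
pd_to_train_on = df[df["FirstSource"].notna()].tail(rows_to_analyze)
pd_to_predict_on = df[df["FirstSource"].isna()].head(rows_to_predict)
pd_df = pd.concat([pd_to_train_on, pd_to_predict_on])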
@@ -296,7 +359,6 @@ def filter(self) -> PySparkDataFrame:
# signal_df = pd.concat([pd_to_train_on[column_name], pd_to_predict_on[column_name]])
# exog_data.append(signal_df)


source_model = ARIMA(
endog=input_data,
exog=exog_data,
@@ -311,42 +373,55 @@ def filter(self) -> PySparkDataFrame:
).fit()

forecast = source_model.forecast(steps=self.rows_to_predict)
inferred_freq = pd.Timedelta(value=statistics.mode(np.diff(main_signal_df[self.timestamp_name].values)))
inferred_freq = pd.Timedelta(
value=statistics.mode(np.diff(main_signal_df[self.timestamp_name].values))
)

pd_forecast_df = pd.DataFrame(
{
self.timestamp_name: pd.date_range(start=main_signal_df[self.timestamp_name].max() + inferred_freq, periods=self.rows_to_predict, freq=inferred_freq),
self.column_to_predict: forecast
}
)
{
self.timestamp_name: pd.date_range(
start=main_signal_df[self.timestamp_name].max() + inferred_freq,
periods=self.rows_to_predict,
freq=inferred_freq,
),
self.column_to_predict: forecast,
}
)

pd_df = pd.concat([pd_df, pd_forecast_df])



if self.past_data_style == self.InputStyle.COLUMN_BASED:
for obj in self.past_data.schema:
simple_string_type = obj.dataType.simpleString()
if simple_string_type == 'timestamp':
if simple_string_type == "timestamp":
continue
pd_df.loc[:, obj.name] = pd_df.loc[:, obj.name].astype(simple_string_type)
pd_df.loc[:, obj.name] = pd_df.loc[:, obj.name].astype(
simple_string_type
)
# Workaround needed for PySpark versions <3.4
pd_df = _prepare_pandas_to_convert_to_spark(pd_df)
predicted_source_pyspark_dataframe = self.spark_session.createDataFrame(
pd_df, schema=self.past_data.schema
)
return predicted_source_pyspark_dataframe
elif self.past_data_style == self.InputStyle.SOURCE_BASED:
data_to_add = pd_forecast_df[[self.column_to_predict ,self.timestamp_name]]
data_to_add = data_to_add.rename(columns={self.timestamp_name: self.timestamp_name, self.column_to_predict: self.value_name})
data_to_add = pd_forecast_df[[self.column_to_predict, self.timestamp_name]]
data_to_add = data_to_add.rename(
columns={
self.timestamp_name: self.timestamp_name,
self.column_to_predict: self.value_name,
}
)
data_to_add[self.source_name] = self.column_to_predict
data_to_add[self.timestamp_name] = data_to_add[self.timestamp_name].dt.strftime("%Y-%m-%dT%H:%M:%S.%f")
data_to_add[self.timestamp_name] = data_to_add[
self.timestamp_name
].dt.strftime("%Y-%m-%dT%H:%M:%S.%f")

pd_df_schema = StructType(
[
[
StructField(self.source_name, StringType(), True),
StructField(self.timestamp_name, StringType(), True),
StructField(self.value_name, StringType(), True)
StructField(self.value_name, StringType(), True),
]
)
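The hunk above fits a statsmodels ARIMA model, forecasts the missing rows, and rebuilds timestamps for them from the most common spacing in the history; a self-contained sketch of those steps on synthetic data (column names, ARIMA order, and row counts are illustrative, not the SDK's defaults):

import statistics

import numpy as np
import pandas as pd
from statsmodels.tsa.arima.model import ARIMA

# Synthetic, evenly spaced history for a single signal.
history = pd.DataFrame(
    {
        "EventTime": pd.date_range("2024-01-01", periods=48, freq="1h"),
        "Value": np.sin(np.linspace(0.0, 6.0, 48)) + 10.0,
    }
)

rows_to_predict = 6
model = ARIMA(endog=history["Value"], order=(1, 0, 0)).fit()
forecast = model.forecast(steps=rows_to_predict)

# Most common gap between consecutive samples becomes the assumed sampling frequency.
inferred_freq = pd.Timedelta(value=statistics.mode(np.diff(history["EventTime"].values)))

pd_forecast_df = pd.DataFrame(
    {
        "EventTime": pd.date_range(
            start=history["EventTime"].max() + inferred_freq,
            periods=rows_to_predict,
            freq=inferred_freq,
        ),
        "Value": forecast.to_numpy(),
    }
)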

@@ -358,7 +433,11 @@ def filter(self) -> PySparkDataFrame:
)

if self.status_name is not None:
predicted_source_pyspark_dataframe = predicted_source_pyspark_dataframe.withColumn(self.status_name, lit("Predicted"))
predicted_source_pyspark_dataframe = (
predicted_source_pyspark_dataframe.withColumn(
self.status_name, lit("Predicted")
)
)

return self.past_data.union(predicted_source_pyspark_dataframe)

(Diffs for the remaining changed files were not loaded in this view.)
