diff --git a/src/sdk/python/rtdip_sdk/pipelines/_pipeline_utils/spark.py b/src/sdk/python/rtdip_sdk/pipelines/_pipeline_utils/spark.py index 0c9e804d2..bb9557208 100644 --- a/src/sdk/python/rtdip_sdk/pipelines/_pipeline_utils/spark.py +++ b/src/sdk/python/rtdip_sdk/pipelines/_pipeline_utils/spark.py @@ -118,7 +118,10 @@ def get_dbutils( # def onQueryTerminated(self, event): # logging.info("Query terminated: {} {}".format(event.id, event.name)) -def is_dataframe_partially_conformed_in_schema(dataframe: DataFrame, schema: StructType, throw_error: bool = True) -> bool: + +def is_dataframe_partially_conformed_in_schema( + dataframe: DataFrame, schema: StructType, throw_error: bool = True +) -> bool: """ Returns true if all columns in the dataframe are contained in the scheme with appropriate type """ @@ -132,18 +135,26 @@ def is_dataframe_partially_conformed_in_schema(dataframe: DataFrame, schema: Str if not throw_error: return False else: - raise ValueError("Column {0} is of Type {1}, expected Type {2}".format(column, column.dataType, schema_field.dataType)) + raise ValueError( + "Column {0} is of Type {1}, expected Type {2}".format( + column, column.dataType, schema_field.dataType + ) + ) else: # dataframe contains column not expected ins schema if not throw_error: return False else: - raise ValueError("Column {0} is not expected in dataframe".format(column)) + raise ValueError( + "Column {0} is not expected in dataframe".format(column) + ) return True -def conform_dataframe_to_schema(dataframe: DataFrame, schema: StructType, throw_error: bool = True) -> DataFrame: +def conform_dataframe_to_schema( + dataframe: DataFrame, schema: StructType, throw_error: bool = True +) -> DataFrame: """ Tries to convert all commong to the given scheme Raises Exception on non-conforming dataframe @@ -155,7 +166,9 @@ def conform_dataframe_to_schema(dataframe: DataFrame, schema: StructType, throw_ if isinstance(column.dataType, type(schema_field.dataType)): # column is of right type, skip it continue - dataframe = dataframe.withColumn(c_name, dataframe[c_name].cast(schema_field.dataType)) + dataframe = dataframe.withColumn( + c_name, dataframe[c_name].cast(schema_field.dataType) + ) else: raise ValueError("Column {0} is not expected in dataframe".format(column)) return dataframe diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/prediction/arima.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/prediction/arima.py index 6ee98d4a3..ea70120d1 100644 --- a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/prediction/arima.py +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/prediction/arima.py @@ -17,7 +17,12 @@ import pandas as pd from pandas import DataFrame -from pyspark.sql import DataFrame as PySparkDataFrame, SparkSession, functions as F, DataFrame as SparkDataFrame +from pyspark.sql import ( + DataFrame as PySparkDataFrame, + SparkSession, + functions as F, + DataFrame as SparkDataFrame, +) from pyspark.sql.functions import col, lit from pyspark.sql.types import StringType, StructField, StructType from regex import regex @@ -125,15 +130,16 @@ class InputStyle(Enum): """ Used to describe style of a dataframe """ - COLUMN_BASED = 1 # Schema: [EventTime, FirstSource, SecondSource, ...] - SOURCE_BASED = 2 # Schema: [EventTime, NameSource, Value, OptionalStatus] + + COLUMN_BASED = 1 # Schema: [EventTime, FirstSource, SecondSource, ...] 
+ SOURCE_BASED = 2 # Schema: [EventTime, NameSource, Value, OptionalStatus] def __init__( self, past_data: PySparkDataFrame, to_extend_name: str, # either source or column # Metadata about past_date - past_data_style : InputStyle = None, + past_data_style: InputStyle = None, value_name: str = None, timestamp_name: str = None, source_name: str = None, @@ -144,20 +150,29 @@ def __init__( number_of_data_points_to_analyze: int = None, order: tuple = (0, 0, 0), seasonal_order: tuple = (0, 0, 0, 0), - trend= None, + trend=None, enforce_stationarity: bool = True, enforce_invertibility: bool = True, concentrate_scale: bool = False, trend_offset: int = 1, - missing: str = "None" + missing: str = "None", ) -> None: self.past_data = past_data # Convert dataframe to general column-based format for internal processing - self._initialize_self_df(past_data, past_data_style, source_name, status_name, timestamp_name, to_extend_name, - value_name) + self._initialize_self_df( + past_data, + past_data_style, + source_name, + status_name, + timestamp_name, + to_extend_name, + value_name, + ) if number_of_data_points_to_analyze > self.df.count(): - raise ValueError("Number of data points to analyze exceeds the number of rows present") + raise ValueError( + "Number of data points to analyze exceeds the number of rows present" + ) self.spark_session = past_data.sparkSession self.column_to_predict = to_extend_name @@ -199,22 +214,53 @@ def _is_column_type(df, column_name, data_type): return isinstance(type_.dataType, data_type) - def _initialize_self_df(self, past_data, past_data_style, source_name, status_name, timestamp_name, to_extend_name, - value_name): - #Initialize self.df with meta parameters if not already done by previous constructor + def _initialize_self_df( + self, + past_data, + past_data_style, + source_name, + status_name, + timestamp_name, + to_extend_name, + value_name, + ): + # Initialize self.df with meta parameters if not already done by previous constructor if self.df is None: - self.past_data_style, self.value_name, self.timestamp_name, self.source_name, self.status_name = self._constructor_handle_input_metadata( - past_data, past_data_style, value_name, timestamp_name, source_name, status_name) + ( + self.past_data_style, + self.value_name, + self.timestamp_name, + self.source_name, + self.status_name, + ) = self._constructor_handle_input_metadata( + past_data, + past_data_style, + value_name, + timestamp_name, + source_name, + status_name, + ) if self.past_data_style == self.InputStyle.COLUMN_BASED: self.df = past_data elif self.past_data_style == self.InputStyle.SOURCE_BASED: - self.df = past_data.groupby(self.timestamp_name).pivot(self.source_name).agg(F.first(self.value_name)) + self.df = ( + past_data.groupby(self.timestamp_name) + .pivot(self.source_name) + .agg(F.first(self.value_name)) + ) if not to_extend_name in self.df.columns: raise ValueError("{} not found in the DataFrame.".format(value_name)) - - def _constructor_handle_input_metadata(self, past_data: PySparkDataFrame, past_data_style: InputStyle, value_name: str, timestamp_name: str, source_name:str, status_name: str) -> Tuple[InputStyle, str, str, str, str]: + def _constructor_handle_input_metadata( + self, + past_data: PySparkDataFrame, + past_data_style: InputStyle, + value_name: str, + timestamp_name: str, + source_name: str, + status_name: str, + ) -> Tuple[InputStyle, str, str, str, str]: # Infer names of columns from past_data schema. If nothing is found, leave self parameters at None. 
if past_data_style is not None: return past_data_style, value_name, timestamp_name, source_name, status_name @@ -228,7 +274,9 @@ def _constructor_handle_input_metadata(self, past_data: PySparkDataFrame, past_d source_name = None status_name = None - def pickout_column(rem_columns: List[str], regex_string: str) -> (str, List[str]): + def pickout_column( + rem_columns: List[str], regex_string: str + ) -> (str, List[str]): rgx = regex.compile(regex_string) sus_columns = list(filter(rgx.search, rem_columns)) found_column = sus_columns[0] if len(sus_columns) == 1 else None @@ -241,7 +289,9 @@ def pickout_column(rem_columns: List[str], regex_string: str) -> (str, List[str] # Is there a source name / tag source_name, remaining_columns = pickout_column(schema_names, r"(?i)tag") # Is there a timestamp column? - timestamp_name, remaining_columns = pickout_column(schema_names, r"(?i)time|index") + timestamp_name, remaining_columns = pickout_column( + schema_names, r"(?i)time|index" + ) # Is there a value column? value_name, remaining_columns = pickout_column(schema_names, r"(?i)value") @@ -250,10 +300,16 @@ def pickout_column(rem_columns: List[str], regex_string: str) -> (str, List[str] else: assumed_past_data_style = self.InputStyle.COLUMN_BASED - #if self.past_data_style is None: + # if self.past_data_style is None: # raise ValueError( # "Automatic determination of past_data_style failed, must be specified in parameter instead.") - return assumed_past_data_style, value_name, timestamp_name, source_name, status_name + return ( + assumed_past_data_style, + value_name, + timestamp_name, + source_name, + status_name, + ) def filter(self) -> PySparkDataFrame: """ @@ -272,18 +328,25 @@ def filter(self) -> PySparkDataFrame: # StructField("Status", StringType(), True), # StructField("Value", NumericType(), True), # ] - #) + # ) pd_df = self.df.toPandas() - pd_df.loc[:, self.timestamp_name] = pd.to_datetime(pd_df[self.timestamp_name], format="mixed").astype( - "datetime64[ns]") - pd_df.loc[:, self.column_to_predict] = pd_df.loc[:, self.column_to_predict].astype(float) + pd_df.loc[:, self.timestamp_name] = pd.to_datetime( + pd_df[self.timestamp_name], format="mixed" + ).astype("datetime64[ns]") + pd_df.loc[:, self.column_to_predict] = pd_df.loc[ + :, self.column_to_predict + ].astype(float) pd_df.sort_values(self.timestamp_name, inplace=True) pd_df.reset_index(drop=True, inplace=True) # self.validate(expected_scheme) # limit df to specific data points - pd_to_train_on = pd_df[pd_df[self.column_to_predict].notna()].tail(self.rows_to_analyze) - pd_to_predict_on = pd_df[pd_df[self.column_to_predict].isna()].head(self.rows_to_predict) + pd_to_train_on = pd_df[pd_df[self.column_to_predict].notna()].tail( + self.rows_to_analyze + ) + pd_to_predict_on = pd_df[pd_df[self.column_to_predict].isna()].head( + self.rows_to_predict + ) pd_df = pd.concat([pd_to_train_on, pd_to_predict_on]) main_signal_df = pd_df[pd_df[self.column_to_predict].notna()] @@ -296,7 +359,6 @@ def filter(self) -> PySparkDataFrame: # signal_df = pd.concat([pd_to_train_on[column_name], pd_to_predict_on[column_name]]) # exog_data.append(signal_df) - source_model = ARIMA( endog=input_data, exog=exog_data, @@ -311,25 +373,31 @@ def filter(self) -> PySparkDataFrame: ).fit() forecast = source_model.forecast(steps=self.rows_to_predict) - inferred_freq = pd.Timedelta(value=statistics.mode(np.diff(main_signal_df[self.timestamp_name].values))) + inferred_freq = pd.Timedelta( + value=statistics.mode(np.diff(main_signal_df[self.timestamp_name].values)) + ) 
pd_forecast_df = pd.DataFrame( - { - self.timestamp_name: pd.date_range(start=main_signal_df[self.timestamp_name].max() + inferred_freq, periods=self.rows_to_predict, freq=inferred_freq), - self.column_to_predict: forecast - } - ) + { + self.timestamp_name: pd.date_range( + start=main_signal_df[self.timestamp_name].max() + inferred_freq, + periods=self.rows_to_predict, + freq=inferred_freq, + ), + self.column_to_predict: forecast, + } + ) pd_df = pd.concat([pd_df, pd_forecast_df]) - - if self.past_data_style == self.InputStyle.COLUMN_BASED: for obj in self.past_data.schema: simple_string_type = obj.dataType.simpleString() - if simple_string_type == 'timestamp': + if simple_string_type == "timestamp": continue - pd_df.loc[:, obj.name] = pd_df.loc[:, obj.name].astype(simple_string_type) + pd_df.loc[:, obj.name] = pd_df.loc[:, obj.name].astype( + simple_string_type + ) # Workaround needed for PySpark versions <3.4 pd_df = _prepare_pandas_to_convert_to_spark(pd_df) predicted_source_pyspark_dataframe = self.spark_session.createDataFrame( @@ -337,16 +405,23 @@ def filter(self) -> PySparkDataFrame: ) return predicted_source_pyspark_dataframe elif self.past_data_style == self.InputStyle.SOURCE_BASED: - data_to_add = pd_forecast_df[[self.column_to_predict ,self.timestamp_name]] - data_to_add = data_to_add.rename(columns={self.timestamp_name: self.timestamp_name, self.column_to_predict: self.value_name}) + data_to_add = pd_forecast_df[[self.column_to_predict, self.timestamp_name]] + data_to_add = data_to_add.rename( + columns={ + self.timestamp_name: self.timestamp_name, + self.column_to_predict: self.value_name, + } + ) data_to_add[self.source_name] = self.column_to_predict - data_to_add[self.timestamp_name] = data_to_add[self.timestamp_name].dt.strftime("%Y-%m-%dT%H:%M:%S.%f") + data_to_add[self.timestamp_name] = data_to_add[ + self.timestamp_name + ].dt.strftime("%Y-%m-%dT%H:%M:%S.%f") pd_df_schema = StructType( - [ + [ StructField(self.source_name, StringType(), True), StructField(self.timestamp_name, StringType(), True), - StructField(self.value_name, StringType(), True) + StructField(self.value_name, StringType(), True), ] ) @@ -358,7 +433,11 @@ def filter(self) -> PySparkDataFrame: ) if self.status_name is not None: - predicted_source_pyspark_dataframe = predicted_source_pyspark_dataframe.withColumn(self.status_name, lit("Predicted")) + predicted_source_pyspark_dataframe = ( + predicted_source_pyspark_dataframe.withColumn( + self.status_name, lit("Predicted") + ) + ) return self.past_data.union(predicted_source_pyspark_dataframe) diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/prediction/auto_arima.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/prediction/auto_arima.py index 1e3a34851..2f5ef22e6 100644 --- a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/prediction/auto_arima.py +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/prediction/auto_arima.py @@ -85,10 +85,11 @@ class ArimaAutoPrediction(ArimaPrediction): trend_offset (int): ARIMA-Specific setting missing (str): ARIMA-Specific setting """ + def __init__( self, past_data: PySparkDataFrame, - past_data_style : ArimaPrediction.InputStyle = None, + past_data_style: ArimaPrediction.InputStyle = None, to_extend_name: str = None, value_name: str = None, timestamp_name: str = None, @@ -105,18 +106,27 @@ def __init__( missing: str = "None", ) -> None: # Convert source-based dataframe to column-based if necessary - 
self._initialize_self_df(past_data, past_data_style, source_name, status_name, timestamp_name, to_extend_name, - value_name) + self._initialize_self_df( + past_data, + past_data_style, + source_name, + status_name, + timestamp_name, + to_extend_name, + value_name, + ) # Prepare Input data input_data = self.df.toPandas() - input_data = input_data[input_data[to_extend_name].notna()].tail(number_of_data_points_to_analyze)[to_extend_name] + input_data = input_data[input_data[to_extend_name].notna()].tail( + number_of_data_points_to_analyze + )[to_extend_name] auto_model = auto_arima( y=input_data, seasonal=seasonal, stepwise=True, suppress_warnings=True, - trace=False, # Set to true if to debug + trace=False, # Set to true if to debug error_action="ignore", max_order=None, ) @@ -134,11 +144,10 @@ def __init__( number_of_data_points_to_analyze=number_of_data_points_to_analyze, order=auto_model.order, seasonal_order=auto_model.seasonal_order, - trend = "c" if auto_model.order[1] == 0 else "t", + trend="c" if auto_model.order[1] == 0 else "t", enforce_stationarity=enforce_stationarity, enforce_invertibility=enforce_invertibility, concentrate_scale=concentrate_scale, trend_offset=trend_offset, - missing=missing + missing=missing, ) - diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_arima.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_arima.py index 27fdf4b95..8458e1a4d 100644 --- a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_arima.py +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_arima.py @@ -28,10 +28,10 @@ ) from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.prediction.arima import ( - ArimaPrediction + ArimaPrediction, ) from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.prediction.auto_arima import ( - ArimaAutoPrediction + ArimaAutoPrediction, ) # Testcases to add: @@ -61,6 +61,7 @@ def spark_session(): .getOrCreate() ) + @pytest.fixture(scope="session") def historic_data(): hist_data = [ @@ -204,6 +205,7 @@ def historic_data(): ] return hist_data + @pytest.fixture(scope="session") def source_based_synthetic_data(): output_object = {} @@ -214,25 +216,34 @@ def source_based_synthetic_data(): arr_len = 100 h_a_l = int(arr_len / 2) - df1['Value'] = np.random.rand(arr_len) + np.sin(np.linspace(0, arr_len / 2, num=arr_len)) - df2['Value'] = df1['Value'] * 2 + np.cos(np.linspace(0, arr_len / 2, num=arr_len)) + 5 - df1['index'] = np.asarray(pd.date_range(start='1/1/2024', end='2/1/2024', periods=arr_len)).astype(str) - df2['index'] = np.asarray(pd.date_range(start='1/1/2024', end='2/1/2024', periods=arr_len)).astype(str) - df1['TagName'] = 'PrimarySensor' - df2['TagName'] = 'SecondarySensor' - df1['Status'] = 'Good' - df2['Status'] = 'Good' - - output_object['df1'] = df1 - output_object['df2'] = df2 - output_object['arr_len'] = arr_len - output_object['h_a_l'] = h_a_l - output_object['half_df1_full_df2'] = pd.concat([df1.head(h_a_l), df2]) - output_object['full_df1_full_df2'] = pd.concat([df1, df2]) - output_object['full_df1_half_df2'] = pd.concat([df1, df2.head(h_a_l)]) - output_object['half_df1_half_df2'] = pd.concat([df1.head(h_a_l), df2.head(h_a_l)]) + df1["Value"] = np.random.rand(arr_len) + np.sin( + np.linspace(0, arr_len / 2, num=arr_len) + ) + df2["Value"] = ( + df1["Value"] * 2 + np.cos(np.linspace(0, arr_len / 2, num=arr_len)) + 5 + ) + df1["index"] = np.asarray( + 
pd.date_range(start="1/1/2024", end="2/1/2024", periods=arr_len) + ).astype(str) + df2["index"] = np.asarray( + pd.date_range(start="1/1/2024", end="2/1/2024", periods=arr_len) + ).astype(str) + df1["TagName"] = "PrimarySensor" + df2["TagName"] = "SecondarySensor" + df1["Status"] = "Good" + df2["Status"] = "Good" + + output_object["df1"] = df1 + output_object["df2"] = df2 + output_object["arr_len"] = arr_len + output_object["h_a_l"] = h_a_l + output_object["half_df1_full_df2"] = pd.concat([df1.head(h_a_l), df2]) + output_object["full_df1_full_df2"] = pd.concat([df1, df2]) + output_object["full_df1_half_df2"] = pd.concat([df1, df2.head(h_a_l)]) + output_object["half_df1_half_df2"] = pd.concat([df1.head(h_a_l), df2.head(h_a_l)]) return output_object + @pytest.fixture(scope="session") def column_based_synthetic_data(): output_object = {} @@ -242,24 +253,31 @@ def column_based_synthetic_data(): arr_len = 100 h_a_l = int(arr_len / 2) - idx_start = '1/1/2024' - idx_end = '2/1/2024' - - df1['PrimarySensor'] = np.random.rand(arr_len) + np.sin(np.linspace(0, arr_len / 2, num=arr_len)) - df1['SecondarySensor'] = df1['PrimarySensor'] * 2 + np.cos(np.linspace(0, arr_len / 2, num=arr_len)) + 5 - df1['index'] = np.asarray(pd.date_range(start=idx_start, end=idx_end, periods=arr_len)).astype(str) - - output_object['df'] = df1 - output_object['arr_len'] = arr_len - output_object['h_a_l'] = h_a_l - output_object['half_df1_full_df2'] = df1.copy() - output_object['half_df1_full_df2'].loc[h_a_l:, 'PrimarySensor'] = None - output_object['full_df1_full_df2'] = df1.copy() - output_object['full_df1_half_df2'] = df1.copy() - output_object['full_df1_half_df2'].loc[h_a_l:, 'SecondarySensor'] = None - output_object['half_df1_half_df2'] = df1.copy().head(h_a_l) + idx_start = "1/1/2024" + idx_end = "2/1/2024" + + df1["PrimarySensor"] = np.random.rand(arr_len) + np.sin( + np.linspace(0, arr_len / 2, num=arr_len) + ) + df1["SecondarySensor"] = ( + df1["PrimarySensor"] * 2 + np.cos(np.linspace(0, arr_len / 2, num=arr_len)) + 5 + ) + df1["index"] = np.asarray( + pd.date_range(start=idx_start, end=idx_end, periods=arr_len) + ).astype(str) + + output_object["df"] = df1 + output_object["arr_len"] = arr_len + output_object["h_a_l"] = h_a_l + output_object["half_df1_full_df2"] = df1.copy() + output_object["half_df1_full_df2"].loc[h_a_l:, "PrimarySensor"] = None + output_object["full_df1_full_df2"] = df1.copy() + output_object["full_df1_half_df2"] = df1.copy() + output_object["full_df1_half_df2"].loc[h_a_l:, "SecondarySensor"] = None + output_object["half_df1_half_df2"] = df1.copy().head(h_a_l) return output_object + def test_nonexistent_column_arima(spark_session: SparkSession): input_df = spark_session.createDataFrame( [ @@ -272,6 +290,7 @@ def test_nonexistent_column_arima(spark_session: SparkSession): with pytest.raises(ValueError): ArimaPrediction(input_df, to_extend_name="NonexistingColumn") + def test_invalid_size_arima(spark_session: SparkSession): input_df = spark_session.createDataFrame( [ @@ -282,7 +301,13 @@ def test_invalid_size_arima(spark_session: SparkSession): ) with pytest.raises(ValueError): - ArimaPrediction(input_df, to_extend_name="Value", order=(3, 0, 0), seasonal_order=(3, 0, 0, 62), number_of_data_points_to_analyze=62) + ArimaPrediction( + input_df, + to_extend_name="Value", + order=(3, 0, 0), + seasonal_order=(3, 0, 0, 62), + number_of_data_points_to_analyze=62, + ) def test_single_column_prediction_arima(spark_session: SparkSession, historic_data): @@ -314,7 +339,7 @@ def 
test_single_column_prediction_arima(spark_session: SparkSession, historic_da seasonal_order=(3, 0, 0, 62), timestamp_name="EventTime", source_name="TagName", - status_name="Status" + status_name="Status", ) forecasted_df = arima_comp.filter() # print(forecasted_df.show(forecasted_df.count(), False)) @@ -325,7 +350,9 @@ def test_single_column_prediction_arima(spark_session: SparkSession, historic_da assert forecasted_df.count() == (input_df.count() + h_a_l) -def test_single_column_prediction_auto_arima(spark_session: SparkSession, historic_data): +def test_single_column_prediction_auto_arima( + spark_session: SparkSession, historic_data +): schema = StructType( [ @@ -346,15 +373,15 @@ def test_single_column_prediction_auto_arima(spark_session: SparkSession, histor arima_comp = ArimaAutoPrediction( past_data=input_df, - #past_data_style=ArimaPrediction.InputStyle.SOURCE_BASED, - #value_name="Value", + # past_data_style=ArimaPrediction.InputStyle.SOURCE_BASED, + # value_name="Value", to_extend_name="-4O7LSSAM_3EA02:2GT7E02I_R_MP", number_of_data_points_to_analyze=input_df.count(), number_of_data_points_to_predict=h_a_l, - #timestamp_name="EventTime", - #source_name="TagName", - #status_name="Status", - seasonal=True + # timestamp_name="EventTime", + # source_name="TagName", + # status_name="Status", + seasonal=True, ) forecasted_df = arima_comp.filter() # print(forecasted_df.show(forecasted_df.count(), False)) @@ -369,7 +396,10 @@ def test_single_column_prediction_auto_arima(spark_session: SparkSession, histor assert arima_comp.source_name == "TagName" assert arima_comp.status_name == "Status" -def test_column_based_prediction_arima(spark_session: SparkSession, column_based_synthetic_data): + +def test_column_based_prediction_arima( + spark_session: SparkSession, column_based_synthetic_data +): schema = StructType( [ @@ -388,11 +418,11 @@ def test_column_based_prediction_arima(spark_session: SparkSession, column_based to_extend_name="PrimarySource", number_of_data_points_to_analyze=input_df.count(), number_of_data_points_to_predict=input_df.count(), - seasonal=True + seasonal=True, ) forecasted_df = arima_comp.filter() - #forecasted_df.show() + # forecasted_df.show() assert isinstance(forecasted_df, DataFrame) @@ -422,14 +452,14 @@ def test_arima_large_data_set(spark_session: SparkSession): print((input_df.count(), len(input_df.columns))) - count_signal = input_df.filter("TagName = \"R0:Z24WVP.0S10L\"").count() + count_signal = input_df.filter('TagName = "R0:Z24WVP.0S10L"').count() h_a_l = int(count_signal / 2) arima_comp = ArimaAutoPrediction( input_df, to_extend_name="R0:Z24WVP.0S10L", number_of_data_points_to_analyze=count_signal, - number_of_data_points_to_predict=h_a_l + number_of_data_points_to_predict=h_a_l, ) result_df = arima_comp.filter() @@ -438,9 +468,7 @@ def test_arima_large_data_set(spark_session: SparkSession): assert isinstance(result_df, DataFrame) - assert result_df.count() == pytest.approx( - (input_df.count() + h_a_l), rel=tolerance - ) + assert result_df.count() == pytest.approx((input_df.count() + h_a_l), rel=tolerance) def test_arima_wrong_datatype(spark_session: SparkSession): @@ -473,8 +501,7 @@ def test_arima_wrong_datatype(spark_session: SparkSession): test_df, to_extend_name="A2PS64V0J.:ZUX09R", number_of_data_points_to_analyze=count_signal, - number_of_data_points_to_predict=h_a_l + number_of_data_points_to_predict=h_a_l, ) arima_comp.validate(expected_schema) -
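Not part of the patch itself: a minimal usage sketch for the two schema helpers reformatted in _pipeline_utils/spark.py above. The repository-root import path mirrors the style used by the tests in this diff, and the local SparkSession setup is an assumption; adjust both to your environment.

# Sketch only: exercises is_dataframe_partially_conformed_in_schema and
# conform_dataframe_to_schema as reformatted in the first hunk of this diff.
from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType, StringType, StructField, StructType

from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.spark import (
    conform_dataframe_to_schema,
    is_dataframe_partially_conformed_in_schema,
)

spark = SparkSession.builder.master("local[1]").getOrCreate()

target_schema = StructType(
    [
        StructField("TagName", StringType(), True),
        StructField("Value", DoubleType(), True),
    ]
)

# "Value" arrives as a string and is cast to DoubleType by the helper.
df = spark.createDataFrame([("Sensor1", "1.5")], ["TagName", "Value"])

conformed = conform_dataframe_to_schema(df, target_schema)
assert is_dataframe_partially_conformed_in_schema(conformed, target_schema)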
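The ArimaPrediction constructor reformatted above accepts either a column-based frame (one column per source) or a source-based frame (TagName/EventTime/Status/Value rows); when past_data_style is given, the column metadata is passed explicitly instead of being inferred. The following is a hypothetical column-based sketch: the names EventTime and PrimarySensor, the synthetic signal, and the local SparkSession are all assumptions, and the last ten values are left empty so that filter() forecasts them.

# Hypothetical column-based input: explicit past_data_style and timestamp_name,
# ten trailing gaps in "PrimarySensor" to be filled by the ARIMA forecast.
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession

from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.prediction.arima import (
    ArimaPrediction,
)

spark = SparkSession.builder.master("local[1]").getOrCreate()

pdf = pd.DataFrame(
    {
        "EventTime": pd.date_range("1/1/2024", periods=100, freq="h").astype(str),
        "PrimarySensor": np.sin(np.linspace(0, 50, num=100)),
    }
)
pdf.loc[90:, "PrimarySensor"] = None  # points to be predicted

arima_comp = ArimaPrediction(
    spark.createDataFrame(pdf),
    to_extend_name="PrimarySensor",
    past_data_style=ArimaPrediction.InputStyle.COLUMN_BASED,
    timestamp_name="EventTime",
    number_of_data_points_to_analyze=90,
    number_of_data_points_to_predict=10,
    order=(3, 0, 0),
)
extended_df = arima_comp.filter()  # original rows plus 10 forecasted rows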
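ArimaAutoPrediction, also touched above, selects order and seasonal_order via pmdarima's auto_arima and otherwise defers to ArimaPrediction. The sketch below mirrors the source-based test fixtures in this diff: the column names TagName, EventTime, Status and Value are the ones the regex-based inference in _constructor_handle_input_metadata recognises, so no metadata arguments are passed. The data, the sensor name and the single-node SparkSession are synthetic assumptions.

# Source-based input with auto-detected metadata columns; values are synthetic.
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession

from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.prediction.auto_arima import (
    ArimaAutoPrediction,
)

spark = SparkSession.builder.master("local[1]").getOrCreate()

values = np.sin(np.linspace(0, 50, num=100)) + np.random.rand(100)
pdf = pd.DataFrame(
    {
        "TagName": "PrimarySensor",
        "EventTime": pd.date_range("1/1/2024", periods=100, freq="h").astype(str),
        "Status": "Good",
        "Value": values.round(4).astype(str),
    }
)
past_data = spark.createDataFrame(pdf)

arima_comp = ArimaAutoPrediction(
    past_data=past_data,
    to_extend_name="PrimarySensor",
    number_of_data_points_to_analyze=past_data.count(),
    number_of_data_points_to_predict=50,
    seasonal=True,
)
forecasted_df = arima_comp.filter()  # past rows plus 50 predicted rows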