107 moving average done #125

Merged: 3 commits, merged on Jan 21, 2025
@@ -0,0 +1 @@
::: src.sdk.python.rtdip_sdk.pipelines.data_quality.monitoring.spark.moving_average
3 changes: 2 additions & 1 deletion mkdocs.yml
@@ -244,7 +244,8 @@ nav:
- Identify Missing Data:
- Interval Based: sdk/code-reference/pipelines/data_quality/monitoring/spark/identify_missing_data_interval.md
- Pattern Based: sdk/code-reference/pipelines/data_quality/monitoring/spark/identify_missing_data_pattern.md
-  - Data Manipulation:
+  - Moving Average: sdk/code-reference/pipelines/data_quality/monitoring/spark/moving_average.md
+  - Data Manipulation:
- Duplicate Detection: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/duplicate_detection.md
- Filter Out of Range Values: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/out_of_range_value_filter.md
- Flatline Filter: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/flatline_filter.md
@@ -0,0 +1,135 @@
import logging
from pyspark.sql import DataFrame as PySparkDataFrame
from pyspark.sql.functions import col, avg
from pyspark.sql.window import Window
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    TimestampType,
    FloatType,
)

from src.sdk.python.rtdip_sdk.pipelines.data_quality.monitoring.interfaces import (
    MonitoringBaseInterface,
)
from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import (
    Libraries,
    SystemType,
)
from ...input_validator import InputValidator


class MovingAverage(MonitoringBaseInterface, InputValidator):
    """
    Computes and logs a moving average over a window of a specified size for a given PySpark DataFrame.

    Args:
        df (pyspark.sql.DataFrame): The DataFrame to process.
        window_size (int): The size of the moving window.

    Example:
        ```python
        from pyspark.sql import SparkSession
        from rtdip_sdk.pipelines.data_quality.monitoring.spark.moving_average import MovingAverage

        spark = SparkSession.builder.master("local[1]").appName("MovingAverageExample").getOrCreate()

        data = [
            ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:45.000", "Good", 1.0),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 07:53:11.000", "Good", 2.0),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 11:56:42.000", "Good", 3.0),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 16:00:12.000", "Good", 4.0),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 20:03:46.000", "Good", 5.0),
        ]

        columns = ["TagName", "EventTime", "Status", "Value"]

        df = spark.createDataFrame(data, columns)

        moving_avg = MovingAverage(
            df=df,
            window_size=3,
        )

        moving_avg.check()
        ```
    """

    df: PySparkDataFrame
    window_size: int
    EXPECTED_SCHEMA = StructType(
        [
            StructField("TagName", StringType(), True),
            StructField("EventTime", TimestampType(), True),
            StructField("Status", StringType(), True),
            StructField("Value", FloatType(), True),
        ]
    )

    def __init__(
        self,
        df: PySparkDataFrame,
        window_size: int,
    ) -> None:
        if not isinstance(window_size, int) or window_size <= 0:
            raise ValueError("window_size must be a positive integer.")

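        # Validate the DataFrame against the expected schema (via the inherited
        # InputValidator) before storing the window size.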
        self.df = df
        self.validate(self.EXPECTED_SCHEMA)
        self.window_size = window_size

        self.logger = logging.getLogger(self.__class__.__name__)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
            self.logger.setLevel(logging.INFO)

    @staticmethod
    def system_type():
        """
        Attributes:
            SystemType (Environment): Requires PYSPARK
        """
        return SystemType.PYSPARK

    @staticmethod
    def libraries():
        libraries = Libraries()
        return libraries

    @staticmethod
    def settings() -> dict:
        return {}

    def check(self) -> None:
        """
        Computes and logs the moving average using a specified window size.
        """

        self._validate_inputs()

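        # Trailing window per TagName, ordered by EventTime: the current row
        # plus the preceding (window_size - 1) rows.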
        window_spec = (
            Window.partitionBy("TagName")
            .orderBy("EventTime")
            .rowsBetween(-(self.window_size - 1), 0)
        )

        self.logger.info("Computing moving averages:")

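        # collect() brings the windowed results back to the driver so each row
        # can be logged individually.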
        for row in (
            self.df.withColumn("MovingAverage", avg(col("Value")).over(window_spec))
            .select("TagName", "EventTime", "Value", "MovingAverage")
            .collect()
        ):
            self.logger.info(
                f"Tag: {row.TagName}, Time: {row.EventTime}, Value: {row.Value}, Moving Avg: {row.MovingAverage}"
            )

    def _validate_inputs(self):
        if not isinstance(self.window_size, int) or self.window_size <= 0:
            raise ValueError("window_size must be a positive integer.")
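
For readers scanning the diff, here is a minimal standalone sketch of the same trailing-window computation; the SparkSession setup and sample values are illustrative only and not part of this PR:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import avg, col
from pyspark.sql.window import Window

spark = SparkSession.builder.master("local[1]").appName("MovingAverageSketch").getOrCreate()

df = spark.createDataFrame(
    [
        ("Tag1", "2024-01-02 03:49:45", 1.0),
        ("Tag1", "2024-01-02 07:53:11", 2.0),
        ("Tag1", "2024-01-02 11:56:42", 3.0),
    ],
    ["TagName", "EventTime", "Value"],
)

window_size = 3
# Same window definition as MovingAverage.check(): the current row plus the
# preceding (window_size - 1) rows per tag, ordered by event time.
window_spec = (
    Window.partitionBy("TagName")
    .orderBy("EventTime")
    .rowsBetween(-(window_size - 1), 0)
)

# Expected MovingAverage values: 1.0, 1.5, 2.0 (the window shrinks at the start).
df.withColumn("MovingAverage", avg(col("Value")).over(window_spec)).show(truncate=False)
```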
@@ -0,0 +1,91 @@
import pytest
import os
from pyspark.sql import SparkSession
from src.sdk.python.rtdip_sdk.pipelines.data_quality.monitoring.spark.moving_average import (
    MovingAverage,
)
import logging
from io import StringIO


@pytest.fixture(scope="session")
def spark():
    spark = (
        SparkSession.builder.master("local[2]")
        .appName("MovingAverageTest")
        .getOrCreate()
    )
    yield spark
    spark.stop()


@pytest.fixture
def log_capture():
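    # MovingAverage logs via logging.getLogger(self.__class__.__name__), so a
    # handler attached to the "MovingAverage" logger captures its check() output.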
    log_stream = StringIO()
    logger = logging.getLogger("MovingAverage")
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler(log_stream)
    formatter = logging.Formatter("%(message)s")
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    yield log_stream
    logger.removeHandler(handler)
    handler.close()


def test_moving_average_basic(spark, log_capture):
    df = spark.createDataFrame(
        [
            ("Tag1", "2024-01-02 03:49:45.000", "Good", 1.0),
            ("Tag1", "2024-01-02 07:53:11.000", "Good", 2.0),
            ("Tag1", "2024-01-02 11:56:42.000", "Good", 3.0),
            ("Tag1", "2024-01-02 16:00:12.000", "Good", 4.0),
            ("Tag1", "2024-01-02 20:03:46.000", "Good", 5.0),
        ],
        ["TagName", "EventTime", "Status", "Value"],
    )

    detector = MovingAverage(df, window_size=3)
    detector.check()

    expected_logs = [
        "Computing moving averages:",
        "Tag: Tag1, Time: 2024-01-02 03:49:45, Value: 1.0, Moving Avg: 1.0",
        "Tag: Tag1, Time: 2024-01-02 07:53:11, Value: 2.0, Moving Avg: 1.5",
        "Tag: Tag1, Time: 2024-01-02 11:56:42, Value: 3.0, Moving Avg: 2.0",
        "Tag: Tag1, Time: 2024-01-02 16:00:12, Value: 4.0, Moving Avg: 3.0",
        "Tag: Tag1, Time: 2024-01-02 20:03:46, Value: 5.0, Moving Avg: 4.0",
    ]

    actual_logs = log_capture.getvalue().strip().split("\n")

    assert len(expected_logs) == len(
        actual_logs
    ), f"Expected {len(expected_logs)} logs, got {len(actual_logs)}"

    for expected, actual in zip(expected_logs, actual_logs):
        assert expected in actual, f"Expected: '{expected}', got: '{actual}'"


def test_moving_average_invalid_window_size(spark):
    df = spark.createDataFrame(
        [
            ("Tag1", "2024-01-02 03:49:45.000", "Good", 1.0),
            ("Tag1", "2024-01-02 07:53:11.000", "Good", 2.0),
        ],
        ["TagName", "EventTime", "Status", "Value"],
    )

    with pytest.raises(ValueError, match="window_size must be a positive integer."):
        MovingAverage(df, window_size=-2)


def test_large_dataset(spark):
    base_path = os.path.dirname(__file__)
    file_path = os.path.join(base_path, "../../test_data.csv")
    df = spark.read.option("header", "true").csv(file_path)

    assert df.count() > 0, "DataFrame was not loaded."

    detector = MovingAverage(df, window_size=5)
    detector.check()