Skip to content

Commit

Permalink
Merge pull request #101 from amosproj/refactor/63-anomaly-detection-t…
Browse files Browse the repository at this point in the history
…ests

refactor/63 anomaly detection tests
  • Loading branch information
dh1542 authored Dec 17, 2024
2 parents c7658a0 + 8668431 commit cb036bc
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
Libraries,
SystemType,
)
from pyspark.sql.types import (
DoubleType,
StructType,
StructField,
)


class KSigmaAnomalyDetection(DataManipulationBaseInterface, InputValidator):
Expand Down Expand Up @@ -63,13 +68,19 @@ def __init__(
raise Exception("You must provide at least one column name")
if len(column_names) > 1:
raise NotImplemented("Multiple columns are not supported yet")
self.column_names = column_names

self.column_names = column_names
self.use_median = use_median
self.spark = spark
self.df = df
self.k_value = k_value

self.validate(
StructType(
[StructField(column, DoubleType(), True) for column in column_names]
)
)

@staticmethod
def system_type():
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from pyspark.sql import SparkSession

import pytest
from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.k_sigma_anomaly_detection import (
KSigmaAnomalyDetection,
)
import os

# Normal data mean=10 stddev=5 + 3 anomalies
# fmt: off
Expand Down Expand Up @@ -73,9 +74,6 @@ def test_filter_with_median(spark_session: SparkSession):
use_median=True,
).filter()

normal_expected_df.show()
normal_filtered_df.show()

assert normal_expected_df.collect() == normal_filtered_df.collect()

# Test with data that has an anomaly that shifts the mean significantly
Expand All @@ -91,3 +89,39 @@ def test_filter_with_median(spark_session: SparkSession):
).filter()

assert expected_df.collect() == filtered_df.collect()


def test_filter_with_wrong_types(spark_session: SparkSession):
wrong_column_type_df = spark_session.createDataFrame(
[(f"string {i}",) for i in range(10)], schema=["value"]
)

# wrong value type
with pytest.raises(ValueError):
KSigmaAnomalyDetection(
spark_session,
wrong_column_type_df,
column_names=["value"],
k_value=3,
use_median=True,
).filter()

# missing column
with pytest.raises(ValueError):
KSigmaAnomalyDetection(
spark_session,
wrong_column_type_df,
column_names=["$value"],
k_value=3,
use_median=True,
).filter()


def test_large_dataset(spark_session):
base_path = os.path.dirname(__file__)
file_path = os.path.join(base_path, "../../test_data.csv")
df = spark_session.read.option("header", "true").csv(file_path)

assert df.count() > 0, "Dataframe was not loaded correct"

KSigmaAnomalyDetection(spark_session, df, column_names=["Value"]).filter()

0 comments on commit cb036bc

Please sign in to comment.