Skip to content

Commit

Permalink
Moved n_components from cluster estimation into config.yml (#234)
Browse files Browse the repository at this point in the history
* moved n_components in cluster estimation to config.yaml

* moved n_components from cluster estimation to config.yaml

* added check verifying max_num_components > 0

* Refined changes to pass tests

---------

Co-authored-by: Jane <[email protected]>
  • Loading branch information
lilywang899 and janez45 authored Jan 18, 2025
1 parent 86cc4f2 commit 8633b43
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 6 deletions.
1 change: 1 addition & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ geolocation:
cluster_estimation:
min_activation_threshold: 25
min_new_points_to_run: 5
max_num_components: 10
random_state: 0

communications:
Expand Down
8 changes: 7 additions & 1 deletion main_2024.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def main() -> int:

MIN_ACTIVATION_THRESHOLD = config["cluster_estimation"]["min_activation_threshold"]
MIN_NEW_POINTS_TO_RUN = config["cluster_estimation"]["min_new_points_to_run"]
MAX_NUM_COMPONENTS = config["cluster_estimation"]["max_num_components"]
RANDOM_STATE = config["cluster_estimation"]["random_state"]

COMMUNICATIONS_TIMEOUT = config["communications"]["timeout"]
Expand Down Expand Up @@ -327,7 +328,12 @@ def main() -> int:
result, cluster_estimation_worker_properties = worker_manager.WorkerProperties.create(
count=1,
target=cluster_estimation_worker.cluster_estimation_worker,
work_arguments=(MIN_ACTIVATION_THRESHOLD, MIN_NEW_POINTS_TO_RUN, RANDOM_STATE),
work_arguments=(
MIN_ACTIVATION_THRESHOLD,
MIN_NEW_POINTS_TO_RUN,
MAX_NUM_COMPONENTS,
RANDOM_STATE,
),
input_queues=[geolocation_to_cluster_estimation_queue],
output_queues=[cluster_estimation_to_communications_queue],
controller=controller,
Expand Down
15 changes: 11 additions & 4 deletions modules/cluster_estimation/cluster_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class ClusterEstimation:
min_new_points_to_run: int
Minimum number of new data points that must be collected before running model.
max_num_components: int
Max number of real landing pads.
random_state: int
Seed for randomizer, to get consistent results.
Expand Down Expand Up @@ -62,9 +65,6 @@ class ClusterEstimation:
__MEAN_PRECISION_PRIOR = 1e-6
__MAX_MODEL_ITERATIONS = 1000

# Real-world scenario Hyperparameters
__MAX_NUM_COMPONENTS = 10 # assumed maximum number of real landing pads

# Hyperparameters to clean up model outputs
__WEIGHT_DROP_THRESHOLD = 0.1
__MAX_COVARIANCE_THRESHOLD = 10
Expand All @@ -74,6 +74,7 @@ def create(
cls,
min_activation_threshold: int,
min_new_points_to_run: int,
max_num_components: int,
random_state: int,
local_logger: logger.Logger,
) -> "tuple[bool, ClusterEstimation | None]":
Expand All @@ -88,10 +89,15 @@ def create(
if min_activation_threshold < 1:
return False, None

# This must be greater than 0
if max_num_components < 0:
return False, None

return True, ClusterEstimation(
cls.__create_key,
min_activation_threshold,
min_new_points_to_run,
max_num_components,
random_state,
local_logger,
)
Expand All @@ -101,6 +107,7 @@ def __init__(
class_private_create_key: object,
min_activation_threshold: int,
min_new_points_to_run: int,
max_num_components: int,
random_state: int,
local_logger: logger.Logger,
) -> None:
Expand All @@ -112,7 +119,7 @@ def __init__(
# Initializes VGMM
self.__vgmm = sklearn.mixture.BayesianGaussianMixture(
covariance_type=self.__COVAR_TYPE,
n_components=self.__MAX_NUM_COMPONENTS,
n_components=max_num_components,
init_params=self.__MODEL_INIT_PARAM,
weight_concentration_prior=self.__WEIGHT_CONCENTRATION_PRIOR,
mean_precision_prior=self.__MEAN_PRECISION_PRIOR,
Expand Down
5 changes: 5 additions & 0 deletions modules/cluster_estimation/cluster_estimation_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
def cluster_estimation_worker(
min_activation_threshold: int,
min_new_points_to_run: int,
max_num_components: int,
random_state: int,
input_queue: queue_proxy_wrapper.QueueProxyWrapper,
output_queue: queue_proxy_wrapper.QueueProxyWrapper,
Expand All @@ -30,6 +31,9 @@ def cluster_estimation_worker(
min_new_points_to_run: int
Minimum number of new data points that must be collected before running model.
max_num_components: int
Max number of real landing pads.
random_state: int
Seed for randomizer, to get consistent results.
Expand All @@ -56,6 +60,7 @@ def cluster_estimation_worker(
result, estimator = cluster_estimation.ClusterEstimation.create(
min_activation_threshold,
min_new_points_to_run,
max_num_components,
random_state,
local_logger,
)
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_cluster_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@

MIN_TOTAL_POINTS_THRESHOLD = 100
MIN_NEW_POINTS_TO_RUN = 10
MAX_NUM_COMPONENTS = 10
RNG_SEED = 0
CENTRE_BOX_SIZE = 500


# Test functions use test fixture signature names and access class privates
# No enable
# pylint: disable=protected-access,redefined-outer-name
Expand All @@ -34,6 +34,7 @@ def cluster_model() -> cluster_estimation.ClusterEstimation: # type: ignore
result, model = cluster_estimation.ClusterEstimation.create(
MIN_TOTAL_POINTS_THRESHOLD,
MIN_NEW_POINTS_TO_RUN,
MAX_NUM_COMPONENTS,
RNG_SEED,
test_logger,
)
Expand Down

0 comments on commit 8633b43

Please sign in to comment.