Skip to content

Commit

Permalink
[pre-commit.ci] pre-commit autoupdate (pymc-labs#1313)
Browse files Browse the repository at this point in the history
* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/astral-sh/ruff-pre-commit: v0.8.3 → v0.8.4](astral-sh/ruff-pre-commit@v0.8.3...v0.8.4)
- [github.com/pre-commit/mirrors-mypy: v1.13.0 → v1.14.0](pre-commit/mirrors-mypy@v1.13.0...v1.14.0)

* remove precision numpy typing

* minor fix

* ignore error

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Juan Orduz <[email protected]>
Co-authored-by: Juan Orduz <[email protected]>
  • Loading branch information
3 people authored Dec 28, 2024
1 parent f4fe828 commit 1630589
Show file tree
Hide file tree
Showing 8 changed files with 1,835 additions and 1,837 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
- --exclude=docs/
- --exclude=scripts/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.3
rev: v0.8.4
hooks:
- id: ruff
types_or: [python, pyi, jupyter]
Expand All @@ -21,7 +21,7 @@ repos:
types_or: [python, pyi, jupyter]
exclude: ^docs/source/notebooks/clv/dev/
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.13.0
rev: v1.14.0
hooks:
- id: mypy
args: [--ignore-missing-imports]
Expand Down
3,628 changes: 1,814 additions & 1,814 deletions docs/source/notebooks/mmm/mmm_tvp_example.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pymc_marketing/clv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,14 +379,14 @@ def rfm_summary(
customers["frequency"] = customers["count"] - 1

customers["recency"] = (
(pandas.to_datetime(customers["max"]) - pandas.to_datetime(customers["min"])) # type: ignore
/ np.timedelta64(1, time_unit)
(pandas.to_datetime(customers["max"]) - pandas.to_datetime(customers["min"]))
/ np.timedelta64(1, time_unit) # type: ignore[call-overload]
/ time_scaler
)

customers["T"] = (
(observation_period_end_ts - customers["min"])
/ np.timedelta64(1, time_unit)
/ np.timedelta64(1, time_unit) # type: ignore[call-overload]
/ time_scaler
)

Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/mmm/lift_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

Index = Sequence[int]
Indices = dict[str, Index]
Values = npt.NDArray[np.int_] | npt.NDArray[np.float64] | npt.NDArray[np.str_]
Values = npt.NDArray[np.int_] | npt.NDArray | npt.NDArray[np.str_]


def _find_unaligned_values(same_value: npt.NDArray[np.int_]) -> list[int]:
Expand Down
14 changes: 6 additions & 8 deletions pymc_marketing/mmm/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,7 @@ def create_idata_attrs(self) -> dict[str, str]:

return attrs

def forward_pass(
self, x: pt.TensorVariable | npt.NDArray[np.float64]
) -> pt.TensorVariable:
def forward_pass(self, x: pt.TensorVariable | npt.NDArray) -> pt.TensorVariable:
"""Transform channel input into target contributions of each channel.
This method handles the ordering of the adstock and saturation
Expand All @@ -322,7 +320,7 @@ def forward_pass(
Parameters
----------
x : pt.TensorVariable | npt.NDArray[np.float64]
x : pt.TensorVariable | npt.NDArray
The channel input which could be spends or impressions
Returns
Expand Down Expand Up @@ -586,9 +584,9 @@ def default_model_config(self) -> dict:

def channel_contributions_forward_pass(
self,
channel_data: npt.NDArray[np.float64],
channel_data: npt.NDArray,
disable_logger_stdout: bool | None = False,
) -> npt.NDArray[np.float64]:
) -> npt.NDArray:
"""Evaluate the channel contribution for a given channel data and a fitted model, ie. the forward pass.
Parameters
Expand Down Expand Up @@ -945,9 +943,9 @@ class MMM(

def channel_contributions_forward_pass(
self,
channel_data: npt.NDArray[np.float64],
channel_data: npt.NDArray,
disable_logger_stdout: bool | None = False,
) -> npt.NDArray[np.float64]:
) -> npt.NDArray:
"""Evaluate the channel contribution for a given channel data and a fitted model, ie. the forward pass.
We return the contribution in the original scale of the target variable.
Expand Down
10 changes: 5 additions & 5 deletions pymc_marketing/mmm/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def weibull_adstock(
return batched_convolution(x, w, axis=axis, mode=mode)


def logistic_saturation(x, lam: npt.NDArray[np.float64] | float = 0.5):
def logistic_saturation(x, lam: npt.NDArray | float = 0.5):
r"""Logistic saturation transformation.
.. math::
Expand Down Expand Up @@ -492,7 +492,7 @@ def logistic_saturation(x, lam: npt.NDArray[np.float64] | float = 0.5):


def inverse_scaled_logistic_saturation(
x, lam: npt.NDArray[np.float64] | float = 0.5, eps: float = np.log(3)
x, lam: npt.NDArray | float = 0.5, eps: float = np.log(3)
):
r"""Inverse scaled logistic saturation transformation.
Expand Down Expand Up @@ -827,9 +827,9 @@ def tanh_saturation_baselined(


def michaelis_menten(
x: float | np.ndarray | npt.NDArray[np.float64],
alpha: float | np.ndarray | npt.NDArray[np.float64],
lam: float | np.ndarray | npt.NDArray[np.float64],
x: float | np.ndarray | npt.NDArray,
alpha: float | np.ndarray | npt.NDArray,
lam: float | np.ndarray | npt.NDArray,
) -> float | Any:
r"""Evaluate the Michaelis-Menten function for given values of x, alpha, and lambda.
Expand Down
6 changes: 3 additions & 3 deletions pymc_marketing/mmm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ def transform_1d_array(


def sigmoid_saturation(
x: float | np.ndarray | npt.NDArray[np.float64],
alpha: float | np.ndarray | npt.NDArray[np.float64],
lam: float | np.ndarray | npt.NDArray[np.float64],
x: float | np.ndarray | npt.NDArray,
alpha: float | np.ndarray | npt.NDArray,
lam: float | np.ndarray | npt.NDArray,
) -> float | Any:
"""Sigmoid saturation function.
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def random_samples(
combinations = list(product(range(n_chains), range(n_draws)))

return [
tuple(pair) for pair in rng.choice(combinations, size=n, replace=False).tolist()
tuple(pair) for pair in list(rng.choice(combinations, size=n, replace=False))
]


Expand Down

0 comments on commit 1630589

Please sign in to comment.