diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml new file mode 100644 index 0000000..8150b0b --- /dev/null +++ b/.github/workflows/pyright.yml @@ -0,0 +1,27 @@ +name: Pyright Type Checking + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + pyright: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + pip install .[dev] + + - name: Run pyright + run: pyright + env: + FORCE_COLOR: "1" diff --git a/preprocessing_utils.py b/preprocessing_utils.py index 437db69..d68045b 100644 --- a/preprocessing_utils.py +++ b/preprocessing_utils.py @@ -32,13 +32,14 @@ def create_val_set(csv_file, val_fraction): out of it specified by val_fraction. """ csv_file = Path(csv_file) - dataset = pd.read_csv(csv_file) + dataset: pd.DataFrame = pd.read_csv(csv_file) np.random.seed(0) dataset_mod = dataset[dataset.toxic != -1] indices = np.random.rand(len(dataset_mod)) > val_fraction val_set = dataset_mod[~indices] output_file = csv_file.parent / "val.csv" logger.info("Validation set saved to %s", output_file) + assert isinstance(val_set, (pd.DataFrame, pd.Series)) val_set.to_csv(output_file) diff --git a/pyproject.toml b/pyproject.toml index 09033ad..9325a86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,8 @@ dev = [ "scikit-learn >= 0.23.2", "tqdm", "pre-commit", - "numpy>=2" + "numpy>=2", + "pyright" ] [tool.ruff] diff --git a/run_prediction.py b/run_prediction.py index 2871e44..df22066 100644 --- a/run_prediction.py +++ b/run_prediction.py @@ -32,7 +32,10 @@ def run(model_name, input_obj, dest_file, from_ckpt, device="cpu"): model = Detoxify(checkpoint=from_ckpt, device=device) res = model.predict(text) - res_df = pd.DataFrame(res, index=[text] if isinstance(text, str) else text).round(5) + res_df = pd.DataFrame( + res, + index=[text] if isinstance(text, str) else text, # pyright: ignore[reportArgumentType] + ).round(5) print(res_df) if dest_file is not None: res_df.index.name = "input_text" diff --git a/src/data_loaders.py b/src/data_loaders.py index 2f58741..8796012 100644 --- a/src/data_loaders.py +++ b/src/data_loaders.py @@ -49,6 +49,8 @@ def load_data(self, csv_file): filtered_change_names = {k: v for k, v in change_names.items() if k in final_df.columns} if len(filtered_change_names) > 0: final_df.rename(columns=filtered_change_names, inplace=True) + else: + raise TypeError("Invalid input type for csv_file, must be a string or a list of strings") return final_df def load_val(self, test_csv_file, add_labels=False): @@ -155,6 +157,8 @@ def __getitem__(self, index): meta["text_id"] = text_id if self.train: + if self.weights is None: + raise Exception("self.weights must not be None") meta["weights"] = self.weights[index] toxic_weight = self.weights[index] * self.loss_weight * 1.0 / len(self.classes) identity_weight = (1 - self.loss_weight) * 1.0 / len(self.identity_classes)