Skip to content

Commit

Permalink
Add pyright github action (#118)
Browse files Browse the repository at this point in the history
* Add pyright github action

Adds type checking with pyright

* Force color in pyright

* Fix pyright issues

* Fix overly long line

* Fix pyright issues

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jamt9000 and pre-commit-ci[bot] authored Jan 21, 2025
1 parent 89b4471 commit 0a75709
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 3 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/pyright.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Type-check the project with pyright on every push to, and every pull
# request against, the master branch.
name: Pyright Type Checking

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  pyright:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          # Quoted so YAML keeps the version as a string rather than a float.
          python-version: '3.11'

      - name: Install dependencies
        # Installs the project together with its "dev" extra.
        # The argument is quoted so the shell never treats `[dev]` as a
        # glob character class.
        run: |
          pip install ".[dev]"

      - name: Run pyright
        run: pyright
        env:
          # Force colored output in the CI log.
          FORCE_COLOR: "1"
3 changes: 2 additions & 1 deletion preprocessing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ def create_val_set(csv_file, val_fraction):
out of it specified by val_fraction.
"""
csv_file = Path(csv_file)
dataset = pd.read_csv(csv_file)
dataset: pd.DataFrame = pd.read_csv(csv_file)
np.random.seed(0)
dataset_mod = dataset[dataset.toxic != -1]
indices = np.random.rand(len(dataset_mod)) > val_fraction
val_set = dataset_mod[~indices]
output_file = csv_file.parent / "val.csv"
logger.info("Validation set saved to %s", output_file)
assert isinstance(val_set, (pd.DataFrame, pd.Series))
val_set.to_csv(output_file)


Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ dev = [
"scikit-learn >= 0.23.2",
"tqdm",
"pre-commit",
"numpy>=2"
"numpy>=2",
"pyright"
]

[tool.ruff]
Expand Down
5 changes: 4 additions & 1 deletion run_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ def run(model_name, input_obj, dest_file, from_ckpt, device="cpu"):
model = Detoxify(checkpoint=from_ckpt, device=device)
res = model.predict(text)

res_df = pd.DataFrame(res, index=[text] if isinstance(text, str) else text).round(5)
res_df = pd.DataFrame(
res,
index=[text] if isinstance(text, str) else text, # pyright: ignore[reportArgumentType]
).round(5)
print(res_df)
if dest_file is not None:
res_df.index.name = "input_text"
Expand Down
4 changes: 4 additions & 0 deletions src/data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def load_data(self, csv_file):
filtered_change_names = {k: v for k, v in change_names.items() if k in final_df.columns}
if len(filtered_change_names) > 0:
final_df.rename(columns=filtered_change_names, inplace=True)
else:
raise TypeError("Invalid input type for csv_file, must be a string or a list of strings")
return final_df

def load_val(self, test_csv_file, add_labels=False):
Expand Down Expand Up @@ -155,6 +157,8 @@ def __getitem__(self, index):
meta["text_id"] = text_id

if self.train:
if self.weights is None:
raise Exception("self.weights must not be None")
meta["weights"] = self.weights[index]
toxic_weight = self.weights[index] * self.loss_weight * 1.0 / len(self.classes)
identity_weight = (1 - self.loss_weight) * 1.0 / len(self.identity_classes)
Expand Down

0 comments on commit 0a75709

Please sign in to comment.