Skip to content

Commit

Permalink
Merge pull request #4 from moritzwilksch/fixes
Browse files Browse the repository at this point in the history
fixes + env
  • Loading branch information
moritzwilksch authored Mar 17, 2024
2 parents a0f2178 + 277bc1c commit 0403249
Show file tree
Hide file tree
Showing 7 changed files with 1,509 additions and 11 deletions.
3 changes: 3 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# GitHub syntax highlighting
pixi.lock linguist-language=YAML

3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,6 @@ dmypy.json

# Pyre type checker
.pyre/
# pixi environments
.pixi

1,468 changes: 1,468 additions & 0 deletions pixi.lock

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions pixi.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[project]
name = "pyfin-sentiment"
version = "0.1.0"
description = "Add a short description here"
authors = ["Moritz Wilksch <[email protected]>"]
channels = ["conda-forge"]
platforms = ["linux-64"]

[tasks]
postinstall = "pip install --no-build-isolation -e ."

[dependencies]
pip = ">=24.0,<25"
polars = ">=0.20.15,<0.21"
numpy = ">=1.26.4,<1.27"
scikit-learn = ">=1.4.1.post1,<1.4.2"


[feature.dev.dependencies]
pytest = ">=7.2.0,<8.2"
twine = ">=5.0.0,<5.1"

[environments]
dev = ["dev"]
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ isort
toml
joblib
scikit-learn
polars==0.12.0
polars==0.20.*
6 changes: 3 additions & 3 deletions src/pyfin_sentiment/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def predict(self, texts: Union[list, np.ndarray]) -> np.ndarray:
if len(texts) == 0:
raise ValueError(f"Please provide at least one text. Got {texts}")

df = pl.DataFrame(texts, columns=["text"])
df = pl.DataFrame(texts, schema=["text"])
df = self.preprocessor.process(df)

return self.model.predict(df["text"].to_list())
Expand All @@ -157,7 +157,7 @@ def predict_proba(self, texts: Union[list, np.ndarray]) -> np.ndarray:
f"Please provide a list or np.ndarray of texts. Got {type(texts)}."
)

df = pl.DataFrame(texts, columns=["text"])
df = pl.DataFrame(texts, schema=["text"])
df = self.preprocessor.process(df)

return self.model.predict_proba(df["text"].to_list())
return self.model.predict_proba(df["text"].to_list())
14 changes: 7 additions & 7 deletions src/pyfin_sentiment/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,34 @@ def __init__(
self.multi_spaces = multi_spaces

def fix_symbols(self, df: pl.DataFrame) -> pl.DataFrame:
return df.with_column(
return df.with_columns(
pl.col(self.TEXTCOL)
.str.replace_all(r"&gt;", ">")
.str.replace_all(r"&lt;", "<")
.str.replace_all(r"&amp;", "&")
)

def prep_numbers(self, df: pl.DataFrame) -> pl.DataFrame:
return df.with_column(pl.col(self.TEXTCOL).str.replace_all(r"\d", "9"))
return df.with_columns(pl.col(self.TEXTCOL).str.replace_all(r"\d", "9"))

def prep_cashtags(self, df: pl.DataFrame) -> pl.DataFrame:
return df.with_column(
return df.with_columns(
pl.col(self.TEXTCOL).str.replace_all(self.CASHTAG_REGEX, "TICKER")
)

def prep_mentions(self, df: pl.DataFrame) -> pl.DataFrame:
return df.with_column(
return df.with_columns(
pl.col(self.TEXTCOL).str.replace_all(self.MENTION_REGEX, "@USER")
)

def prep_lowercase(self, df: pl.DataFrame) -> pl.DataFrame:
return df.with_column(pl.col(self.TEXTCOL).str.to_lowercase())
return df.with_columns(pl.col(self.TEXTCOL).str.to_lowercase())

def prep_newlines(self, df: pl.DataFrame) -> pl.DataFrame:
return df.with_column(pl.col(self.TEXTCOL).str.replace_all("\n", " "))
return df.with_columns(pl.col(self.TEXTCOL).str.replace_all("\n", " "))

def prep_multi_spaces(self, df: pl.DataFrame) -> pl.DataFrame:
return df.with_column(pl.col(self.TEXTCOL).str.replace_all(r"\s+", " "))
return df.with_columns(pl.col(self.TEXTCOL).str.replace_all(r"\s+", " "))

def process(self, df: pl.DataFrame) -> pl.DataFrame:
if self.symbols:
Expand Down

0 comments on commit 0403249

Please sign in to comment.