Skip to content

Commit

Permalink
Add Ruff formatting & linting (#169)
Browse files Browse the repository at this point in the history
* add ruff to development dependencies and run in CI

* address `ruff check` lints

* run `uv run ruff format`

* enable Ruff lints for import statements, and format them by running `uv run ruff check --fix`

* build-n-publish needs check

* fix branch name in CI

* update readme, remove black add ruff
  • Loading branch information
NiklasRosenstein authored Jan 31, 2025
1 parent 8b400ea commit 92b9895
Show file tree
Hide file tree
Showing 17 changed files with 123 additions and 270 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/publish-pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ jobs:
version: "0.5.21"
- name: Pytest
run: uv run pytest
- name: Ruff check
run: uv run ruff check
- name: Ruff format
run: uv run ruff format --check

build-n-publish:
name: Build and publish Python distributions to PyPI
Expand Down
11 changes: 5 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,15 @@ If no arguments are supplied pytr will look for them in the file `~/.pytr/creden

## Linting and Code Formatting

This project uses [black](https://github.com/psf/black) for code linting and auto-formatting. You can auto-format the code by running:
This project uses [Ruff](https://astral.sh/ruff) for code linting and auto-formatting. You can auto-format the code by running:

```bash
# Install black if not already installed
pip install black

# Auto-format code
black ./pytr
uv run ruff format # Format code
uv run ruff check --fix-only # Remove unneeded imports, order imports, etc.
```

Ruff runs as part of CI and your Pull Request cannot be merged unless it satisfies the linting and formatting checks.

## Setting Up a Development Environment

1. Clone the repository:
Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,12 @@ locale_dir = "pytr/locale"

[dependency-groups]
dev = [
"ruff>=0.9.4",
"pytest>=8.3.4",
]

[tool.ruff]
line-length = 120

[tool.ruff.lint]
extend-select = ["I"]
19 changes: 7 additions & 12 deletions pytr/account.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import json
import sys
from pygments import highlight, lexers, formatters
import time
from getpass import getpass

from pytr.api import TradeRepublicApi, CREDENTIALS_FILE
from pygments import formatters, highlight, lexers

from pytr.api import CREDENTIALS_FILE, TradeRepublicApi
from pytr.utils import get_logger


def get_settings(tr):
formatted_json = json.dumps(tr.settings(), indent=2)
if sys.stdout.isatty():
colorful_json = highlight(
formatted_json, lexers.JsonLexer(), formatters.TerminalFormatter()
)
colorful_json = highlight(formatted_json, lexers.JsonLexer(), formatters.TerminalFormatter())
return colorful_json
else:
return formatted_json
Expand Down Expand Up @@ -41,9 +40,7 @@ def login(phone_no=None, pin=None, web=True, store_credentials=False):
CREDENTIALS_FILE.parent.mkdir(parents=True, exist_ok=True)
if phone_no is None:
log.info("Credentials file not found")
print(
"Please enter your TradeRepublic phone number in the format +4912345678:"
)
print("Please enter your TradeRepublic phone number in the format +4912345678:")
phone_no = input()
else:
log.info("Phone number provided as argument")
Expand Down Expand Up @@ -74,15 +71,13 @@ def login(phone_no=None, pin=None, web=True, store_credentials=False):
exit(1)
request_time = time.time()
print("Enter the code you received to your mobile app as a notification.")
print(
f"Enter nothing if you want to receive the (same) code as SMS. (Countdown: {countdown})"
)
print(f"Enter nothing if you want to receive the (same) code as SMS. (Countdown: {countdown})")
code = input("Code: ")
if code == "":
countdown = countdown - (time.time() - request_time)
for remaining in range(int(countdown)):
print(
f"Need to wait {int(countdown-remaining)} seconds before requesting SMS...",
f"Need to wait {int(countdown - remaining)} seconds before requesting SMS...",
end="\r",
)
time.sleep(1)
Expand Down
20 changes: 5 additions & 15 deletions pytr/alarms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from datetime import datetime

from pytr.utils import preview, get_logger
from pytr.utils import get_logger, preview


class Alarms:
Expand All @@ -19,9 +19,7 @@ async def alarms_loop(self):
recv += 1
self.alarms = response
else:
print(
f"unmatched subscription of type '{subscription['type']}':\n{preview(response)}"
)
print(f"unmatched subscription of type '{subscription['type']}':\n{preview(response)}")

if recv == 1:
return
Expand All @@ -36,9 +34,7 @@ async def ticker_loop(self):
recv += 1
self.alarms = response
else:
print(
f"unmatched subscription of type '{subscription['type']}':\n{preview(response)}"
)
print(f"unmatched subscription of type '{subscription['type']}':\n{preview(response)}")

if recv == 1:
return
Expand All @@ -47,11 +43,7 @@ def overview(self):
print("ISIN status created target diff% createdAt triggeredAT")
self.log.debug(f"Processing {len(self.alarms)} alarms")

for (
a
) in (
self.alarms
): # sorted(positions, key=lambda x: x['netValue'], reverse=True):
for a in self.alarms: # sorted(positions, key=lambda x: x['netValue'], reverse=True):
self.log.debug(f" Processing {a} alarm")
ts = int(a["createdAt"]) / 1000.0
target_price = float(a["targetPrice"])
Expand All @@ -61,9 +53,7 @@ def overview(self):
triggered = "-"
else:
ts = int(a["triggeredAt"]) / 1000.0
triggered = datetime.fromtimestamp(ts).isoformat(
sep=" ", timespec="minutes"
)
triggered = datetime.fromtimestamp(ts).isoformat(sep=" ", timespec="minutes")

if a["createdPrice"] == 0:
diffP = 0.0
Expand Down
91 changes: 23 additions & 68 deletions pytr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,20 @@
import hashlib
import json
import pathlib
import ssl
import time
import urllib.parse
import uuid
from http.cookiejar import MozillaCookieJar

import certifi
import ssl
import requests
import websockets
from ecdsa import NIST256p, SigningKey
from ecdsa.util import sigencode_der
from http.cookiejar import MozillaCookieJar

from pytr.utils import get_logger


home = pathlib.Path.home()
BASE_DIR = home / ".pytr"
CREDENTIALS_FILE = BASE_DIR / "credentials"
Expand Down Expand Up @@ -96,9 +96,7 @@ def __init__(
self._locale = locale
self._save_cookies = save_cookies

self._credentials_file = (
pathlib.Path(credentials_file) if credentials_file else CREDENTIALS_FILE
)
self._credentials_file = pathlib.Path(credentials_file) if credentials_file else CREDENTIALS_FILE

if not (phone_no and pin):
try:
Expand All @@ -107,18 +105,12 @@ def __init__(
self.phone_no = lines[0].strip()
self.pin = lines[1].strip()
except FileNotFoundError:
raise ValueError(
f"phone_no and pin must be specified explicitly or via {self._credentials_file}"
)
raise ValueError(f"phone_no and pin must be specified explicitly or via {self._credentials_file}")
else:
self.phone_no = phone_no
self.pin = pin

self._cookies_file = (
pathlib.Path(cookies_file)
if cookies_file
else BASE_DIR / f"cookies.{self.phone_no}.txt"
)
self._cookies_file = pathlib.Path(cookies_file) if cookies_file else BASE_DIR / f"cookies.{self.phone_no}.txt"

self.keyfile = keyfile if keyfile else KEY_FILE
try:
Expand Down Expand Up @@ -231,9 +223,7 @@ def complete_weblogin(self, verify_code):
if not self._process_id and not self._websession:
raise ValueError("Initiate web login first.")

r = self._websession.post(
f"{self._host}/api/v1/auth/web/login/{self._process_id}/{verify_code}"
)
r = self._websession.post(f"{self._host}/api/v1/auth/web/login/{self._process_id}/{verify_code}")
r.raise_for_status()
self.save_websession()
self._weblogin = True
Expand Down Expand Up @@ -270,9 +260,7 @@ def _web_request(self, url_path, payload=None, method="GET"):
r = self._websession.get(f"{self._host}/api/v1/auth/web/session")
r.raise_for_status()
self._web_session_token_expires_at = time.time() + 290
return self._websession.request(
method=method, url=f"{self._host}{url_path}", data=payload
)
return self._websession.request(method=method, url=f"{self._host}{url_path}", data=payload)

async def _get_ws(self):
if self._ws and self._ws.open:
Expand Down Expand Up @@ -301,9 +289,7 @@ async def _get_ws(self):
}
connect_id = 31

self._ws = await websockets.connect(
"wss://api.traderepublic.com", ssl=ssl_context, extra_headers=extra_headers
)
self._ws = await websockets.connect("wss://api.traderepublic.com", ssl=ssl_context, extra_headers=extra_headers)
await self._ws.send(f"connect {connect_id} {json.dumps(connection_message)}")
response = await self._ws.recv()

Expand Down Expand Up @@ -354,9 +340,7 @@ async def recv(self):

if subscription_id not in self.subscriptions:
if code != "C":
self.log.debug(
f"No active subscription for id {subscription_id}, dropping message"
)
self.log.debug(f"No active subscription for id {subscription_id}, dropping message")
continue
subscription = self.subscriptions[subscription_id]

Expand Down Expand Up @@ -408,16 +392,12 @@ async def _receive_one(self, fut, timeout):
subscription_id = await fut

try:
return await asyncio.wait_for(
self._recv_subscription(subscription_id), timeout
)
return await asyncio.wait_for(self._recv_subscription(subscription_id), timeout)
finally:
await self.unsubscribe(subscription_id)

def run_blocking(self, fut, timeout=5.0):
return asyncio.get_event_loop().run_until_complete(
self._receive_one(fut, timeout=timeout)
)
return asyncio.get_event_loop().run_until_complete(self._receive_one(fut, timeout=timeout))

async def portfolio(self):
return await self.subscribe({"type": "portfolio"})
Expand All @@ -437,21 +417,14 @@ async def cash(self):
async def available_cash_for_payout(self):
return await self.subscribe({"type": "availableCashForPayout"})

async def portfolio_status(self):
return await self.subscribe({"type": "portfolioStatus"})

async def portfolio_history(self, timeframe):
return await self.subscribe(
{"type": "portfolioAggregateHistory", "range": timeframe}
)
return await self.subscribe({"type": "portfolioAggregateHistory", "range": timeframe})

async def instrument_details(self, isin):
return await self.subscribe({"type": "instrument", "id": isin})

async def instrument_suitability(self, isin):
return await self.subscribe(
{"type": "instrumentSuitability", "instrumentId": isin}
)
return await self.subscribe({"type": "instrumentSuitability", "instrumentId": isin})

async def stock_details(self, isin):
return await self.subscribe({"type": "stockDetails", "id": isin})
Expand All @@ -460,19 +433,15 @@ async def add_watchlist(self, isin):
return await self.subscribe({"type": "addToWatchlist", "instrumentId": isin})

async def remove_watchlist(self, isin):
return await self.subscribe(
{"type": "removeFromWatchlist", "instrumentId": isin}
)
return await self.subscribe({"type": "removeFromWatchlist", "instrumentId": isin})

async def ticker(self, isin, exchange="LSX"):
return await self.subscribe({"type": "ticker", "id": f"{isin}.{exchange}"})

async def performance(self, isin, exchange="LSX"):
return await self.subscribe({"type": "performance", "id": f"{isin}.{exchange}"})

async def performance_history(
self, isin, timeframe, exchange="LSX", resolution=None
):
async def performance_history(self, isin, timeframe, exchange="LSX", resolution=None):
parameters = {
"type": "aggregateHistory",
"id": f"{isin}.{exchange}",
Expand Down Expand Up @@ -501,9 +470,7 @@ async def timeline_detail_order(self, order_id):
return await self.subscribe({"type": "timelineDetail", "orderId": order_id})

async def timeline_detail_savings_plan(self, savings_plan_id):
return await self.subscribe(
{"type": "timelineDetail", "savingsPlanId": savings_plan_id}
)
return await self.subscribe({"type": "timelineDetail", "savingsPlanId": savings_plan_id})

async def timeline_transactions(self, after=None):
return await self.subscribe({"type": "timelineTransactions", "after": after})
Expand All @@ -518,9 +485,7 @@ async def search_tags(self):
return await self.subscribe({"type": "neonSearchTags"})

async def search_suggested_tags(self, query):
return await self.subscribe(
{"type": "neonSearchSuggestedTags", "data": {"q": query}}
)
return await self.subscribe({"type": "neonSearchSuggestedTags", "data": {"q": query}})

async def search(
self,
Expand All @@ -546,17 +511,11 @@ async def search(
if filter_index:
search_parameters["filter"].append({"key": "index", "value": filter_index})
if filter_country:
search_parameters["filter"].append(
{"key": "country", "value": filter_country}
)
search_parameters["filter"].append({"key": "country", "value": filter_country})
if filter_region:
search_parameters["filter"].append(
{"key": "region", "value": filter_region}
)
search_parameters["filter"].append({"key": "region", "value": filter_region})
if filter_sector:
search_parameters["filter"].append(
{"key": "sector", "value": filter_sector}
)
search_parameters["filter"].append({"key": "sector", "value": filter_sector})

search_type = "neonSearch" if not aggregate else "neonSearchAggregations"
return await self.subscribe({"type": search_type, "data": search_parameters})
Expand Down Expand Up @@ -750,17 +709,13 @@ async def change_savings_plan(
return await self.subscribe(parameters)

async def cancel_savings_plan(self, savings_plan_id):
return await self.subscribe(
{"type": "cancelSavingsPlan", "id": savings_plan_id}
)
return await self.subscribe({"type": "cancelSavingsPlan", "id": savings_plan_id})

async def price_alarm_overview(self):
return await self.subscribe({"type": "priceAlarms"})

async def create_price_alarm(self, isin, price):
return await self.subscribe(
{"type": "createPriceAlarm", "instrumentId": isin, "targetPrice": price}
)
return await self.subscribe({"type": "createPriceAlarm", "instrumentId": isin, "targetPrice": price})

async def cancel_price_alarm(self, price_alarm_id):
return await self.subscribe({"type": "cancelPriceAlarm", "id": price_alarm_id})
Expand Down
Loading

0 comments on commit 92b9895

Please sign in to comment.