Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce ClientStates enumeration #619

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ install:

script:
- make ci

notifications:
email: false

Expand All @@ -38,6 +38,7 @@ jobs:
- pip install --upgrade pip
- bash ./scripts/install_nats.sh
install:
- pip install nkeys
- pip install -e .[fast-mail-parser]
- name: "Python: 3.12"
python: "3.12"
Expand All @@ -48,6 +49,7 @@ jobs:
- pip install --upgrade pip
- bash ./scripts/install_nats.sh
install:
- pip install nkeys
- pip install -e .[fast-mail-parser]
- name: "Python: 3.11"
python: "3.11"
Expand All @@ -58,6 +60,7 @@ jobs:
- pip install --upgrade pip
- bash ./scripts/install_nats.sh
install:
- pip install nkeys
- pip install -e .[fast-mail-parser]
- name: "Python: 3.11/uvloop"
python: "3.11"
Expand All @@ -68,8 +71,8 @@ jobs:
- pip install --upgrade pip
- bash ./scripts/install_nats.sh
install:
- pip install nkeys uvloop
- pip install -e .[fast-mail-parser]
- pip install uvloop
- name: "Python: 3.11 (nats-server@main)"
python: "3.11"
env:
Expand All @@ -81,6 +84,7 @@ jobs:
- pip install --upgrade pip
- bash ./scripts/install_nats.sh
install:
- pip install nkeys
- pip install -e .[fast-mail-parser]
allow_failures:
- name: "Python: 3.8"
Expand Down
10 changes: 7 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
REPO_OWNER=nats-io
PROJECT_NAME=nats.py
SOURCE_CODE=nats
TEST_CODE=tests


help:
Expand All @@ -22,14 +23,17 @@ deps:

format:
yapf -i --recursive $(SOURCE_CODE)
yapf -i --recursive tests
yapf -i --recursive $(TEST_CODE)


test:
lint:
yapf --recursive --diff $(SOURCE_CODE)
yapf --recursive --diff tests
yapf --recursive --diff $(TEST_CODE)
mypy
flake8 ./nats/js/


test:
pytest


Expand Down
85 changes: 44 additions & 41 deletions nats/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
import ipaddress
import json
import logging
import os
import ssl
import string
import time
from collections import UserString
from dataclasses import dataclass
from email.parser import BytesParser
from enum import Enum
from io import BytesIO
from random import shuffle
from secrets import token_hex
Expand Down Expand Up @@ -184,14 +186,10 @@ async def _default_error_callback(ex: Exception) -> None:
_logger.error("nats: encountered error", exc_info=ex)


class Client:
"""
Asyncio based client for NATS.
"""
# Client section

msg_class: type[Msg] = Msg

# FIXME: Use an enum instead.
class ClientState(Enum):
DISCONNECTED = 0
CONNECTED = 1
CLOSED = 2
Expand All @@ -200,6 +198,12 @@ class Client:
DRAINING_SUBS = 5
DRAINING_PUBS = 6


class Client:
"""Asyncio-based client for NATS."""

msg_class: type[Msg] = Msg

def __repr__(self) -> str:
return f"<nats client v{__version__}>"

Expand Down Expand Up @@ -230,7 +234,7 @@ def __init__(self) -> None:
self._client_id: Optional[int] = None
self._sid: int = 0
self._subs: Dict[int, Subscription] = {}
self._status: int = Client.DISCONNECTED
stankudrow marked this conversation as resolved.
Show resolved Hide resolved
self._status = ClientState.DISCONNECTED
self._ps: Parser = Parser(self)

# pending queue of commands that will be flushed to the server.
Expand Down Expand Up @@ -511,7 +515,7 @@ async def subscribe_handler(msg):
if not self.options["allow_reconnect"]:
raise e

await self._close(Client.DISCONNECTED, False)
await self._close(ClientState.DISCONNECTED, False)
if self._current_server is not None:
self._current_server.last_attempt = time.monotonic()
self._current_server.reconnects += 1
Expand All @@ -524,7 +528,6 @@ def _setup_nkeys_connect(self) -> None:

def _setup_nkeys_jwt_connect(self) -> None:
assert self._user_credentials, "_user_credentials required"
import os

import nkeys

Expand Down Expand Up @@ -634,12 +637,12 @@ def _setup_nkeys_seed_connect(self) -> None:
import nkeys

def _get_nkeys_seed() -> nkeys.KeyPair:
import os

if self._nkeys_seed_str:
seed = bytearray(self._nkeys_seed_str.encode())
else:
creds = self._nkeys_seed
if creds is None:
raise ValueError("cannot extract nkeys seed")
with open(creds, "rb") as f:
seed = bytearray(os.fstat(f.fileno()).st_size)
f.readinto(seed) # type: ignore[attr-defined]
Expand Down Expand Up @@ -670,13 +673,12 @@ async def close(self) -> None:
sets the client to be in the CLOSED state.
No further reconnections occur once reaching this point.
"""
await self._close(Client.CLOSED)
await self._close(ClientState.CLOSED)

async def _close(self, status: int, do_cbs: bool = True) -> None:
async def _close(self, status: ClientState, do_cbs: bool = True) -> None:
if self.is_closed:
self._status = status
return
self._status = Client.CLOSED

# Kick the flusher once again so that Task breaks and avoid pending futures.
await self._flush_pending()
Expand Down Expand Up @@ -748,6 +750,8 @@ async def _close(self, status: int, do_cbs: bool = True) -> None:
if self._closed_cb is not None:
await self._closed_cb()

self._status = ClientState.CLOSED

# Set the client_id and subscription prefix back to None
self._client_id = None
self._resp_sub_prefix = None
Expand Down Expand Up @@ -780,7 +784,7 @@ async def drain(self) -> None:
# Relinquish CPU to allow drain tasks to start in the background,
# before setting state to draining.
await asyncio.sleep(0)
self._status = Client.DRAINING_SUBS
self._status = ClientState.DRAINING_SUBS

try:
await asyncio.wait_for(
Expand All @@ -793,9 +797,9 @@ async def drain(self) -> None:
except asyncio.CancelledError:
pass
finally:
self._status = Client.DRAINING_PUBS
self._status = ClientState.DRAINING_PUBS
await self.flush()
await self._close(Client.CLOSED)
await self._close(ClientState.CLOSED)

async def publish(
self,
Expand Down Expand Up @@ -1180,30 +1184,30 @@ def pending_data_size(self) -> int:

@property
def is_closed(self) -> bool:
return self._status == Client.CLOSED
return self._status == ClientState.CLOSED

@property
def is_reconnecting(self) -> bool:
return self._status == Client.RECONNECTING
return self._status == ClientState.RECONNECTING

@property
def is_connected(self) -> bool:
return (self._status == Client.CONNECTED) or self.is_draining
return (self._status == ClientState.CONNECTED) or self.is_draining

@property
def is_connecting(self) -> bool:
return self._status == Client.CONNECTING
return self._status == ClientState.CONNECTING

@property
def is_draining(self) -> bool:
return (
self._status == Client.DRAINING_SUBS
or self._status == Client.DRAINING_PUBS
self._status == ClientState.DRAINING_SUBS
or self._status == ClientState.DRAINING_PUBS
)

@property
def is_draining_pubs(self) -> bool:
return self._status == Client.DRAINING_PUBS
return self._status == ClientState.DRAINING_PUBS

@property
def connected_server_version(self) -> ServerVersion:
Expand Down Expand Up @@ -1261,7 +1265,7 @@ async def _flush_pending(
except asyncio.CancelledError:
pass

def _setup_server_pool(self, connect_url: Union[List[str]]) -> None:
def _setup_server_pool(self, connect_url: Union[str | List[str]]) -> None:
if isinstance(connect_url, str):
try:
if "nats://" in connect_url or "tls://" in connect_url:
Expand Down Expand Up @@ -1393,7 +1397,7 @@ async def _process_err(self, err_msg: str) -> None:
# FIXME: Some errors such as 'Invalid Subscription'
# do not cause the server to close the connection.
# For now we handle similar as other clients and close.
asyncio.create_task(self._close(Client.CLOSED, do_cbs))
asyncio.create_task(self._close(ClientState.CLOSED, do_cbs))

async def _process_op_err(self, e: Exception) -> None:
"""
Expand All @@ -1406,7 +1410,7 @@ async def _process_op_err(self, e: Exception) -> None:
return

if self.options["allow_reconnect"] and self.is_connected:
self._status = Client.RECONNECTING
self._status = ClientState.RECONNECTING
self._ps.reset()

if (self._reconnection_task is not None
Expand All @@ -1420,7 +1424,7 @@ async def _process_op_err(self, e: Exception) -> None:
else:
self._process_disconnect()
self._err = e
await self._close(Client.CLOSED, True)
await self._close(ClientState.CLOSED, True)

async def _attempt_reconnect(self) -> None:
assert self._current_server, "Client.connect must be called first"
Expand Down Expand Up @@ -1506,7 +1510,7 @@ async def _attempt_reconnect(self) -> None:
# to bail earlier in case there are errors in the connection.
# await self._flush_pending(force_flush=True)
await self._flush_pending()
self._status = Client.CONNECTED
self._status = ClientState.CONNECTED
await self.flush()
if self._reconnected_cb is not None:
await self._reconnected_cb()
Expand All @@ -1519,7 +1523,7 @@ async def _attempt_reconnect(self) -> None:
except (OSError, errors.Error, asyncio.TimeoutError) as e:
self._err = e
await self._error_cb(e)
self._status = Client.RECONNECTING
self._status = ClientState.RECONNECTING
self._current_server.last_attempt = time.monotonic()
self._current_server.reconnects += 1
except asyncio.CancelledError:
Expand Down Expand Up @@ -1582,9 +1586,8 @@ def _connect_command(self) -> bytes:
return b"".join([CONNECT_OP + _SPC_ + connect_opts.encode() + _CRLF_])

async def _process_ping(self) -> None:
"""
Process PING sent by server.
"""
"""Process PING sent by server."""

await self._send_command(PONG)
await self._flush_pending()

Expand All @@ -1611,7 +1614,7 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]:
if not headers:
return None

hdr: Optional[Dict[str, str]] = None
hdr: Dict[str, str] = {}
raw_headers = headers[NATS_HDR_LINE_SIZE:]

# If the first character is an empty space, then this is
Expand Down Expand Up @@ -1642,7 +1645,7 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]:
i = raw_headers.find(_CRLF_)
raw_headers = raw_headers[i + _CRLF_LEN_:]

if len(desc) > 0:
if len(desc):
# Heartbeat messages can have both headers and inline status,
# check that there are no pending headers to be parsed.
i = desc.find(_CRLF_)
Expand All @@ -1657,7 +1660,7 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]:
# Just inline status...
hdr[nats.js.api.Header.DESCRIPTION] = desc.decode()

if not len(raw_headers) > _CRLF_LEN_:
if len(raw_headers) <= _CRLF_LEN_:
return hdr

#
Expand Down Expand Up @@ -1850,7 +1853,7 @@ def _process_disconnect(self) -> None:
Process disconnection from the server and set client status
to DISCONNECTED.
"""
self._status = Client.DISCONNECTED
self._status = ClientState.DISCONNECTED

def _process_info(
self, info: Dict[str, Any], initial_connection: bool = False
Expand Down Expand Up @@ -1914,7 +1917,7 @@ async def _process_connect_init(self) -> None:
"""
assert self._transport, "must be called only from Client.connect"
assert self._current_server, "must be called only from Client.connect"
self._status = Client.CONNECTING
self._status = ClientState.CONNECTING

# Check whether to reuse the original hostname for an implicit route.
hostname = None
Expand Down Expand Up @@ -2015,7 +2018,7 @@ async def _process_connect_init(self) -> None:
)

if PONG_PROTO in next_op:
self._status = Client.CONNECTED
self._status = ClientState.CONNECTED
elif ERR_OP in next_op:
err_line = next_op.decode()
_, err_msg = err_line.split(" ", 1)
Expand All @@ -2026,7 +2029,7 @@ async def _process_connect_init(self) -> None:
raise errors.Error("nats: " + err_msg.rstrip("\r\n"))

if PONG_PROTO in next_op:
self._status = Client.CONNECTED
self._status = ClientState.CONNECTED

self._reading_task = asyncio.get_running_loop().create_task(
self._read_loop()
Expand Down Expand Up @@ -2139,7 +2142,7 @@ async def __aenter__(self) -> "Client":

async def __aexit__(self, *exc_info) -> None:
"""Close connection to NATS when used in a context manager"""
await self._close(Client.CLOSED, do_cbs=True)
await self._close(ClientState.CLOSED, do_cbs=True)

def jetstream(self, **opts) -> nats.js.JetStreamContext:
"""
Expand Down
2 changes: 1 addition & 1 deletion nats/js/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ class StreamsListIterator(Iterable):
"""

def __init__(
self, offset: int, total: int, streams: List[Dict[str, any]]
self, offset: int, total: int, streams: List[Dict[str, Any]]
) -> None:
self.offset = offset
self.total = total
Expand Down
Loading
Loading