From 618d5c8d35f0fdf5e6215eccdd31f59e03181cb5 Mon Sep 17 00:00:00 2001 From: Casper Beyer Date: Wed, 25 Sep 2024 13:09:33 +0200 Subject: [PATCH] Enforce formatting --- .github/workflows/check.yml | 30 + .github/workflows/test.yml | 2 +- Makefile | 2 +- benchmark/inbox_perf.py | 10 +- benchmark/latency_perf.py | 111 ++-- benchmark/parser_perf.py | 37 +- benchmark/pub_perf.py | 46 +- benchmark/pub_sub_perf.py | 54 +- benchmark/sub_perf.py | 20 +- examples/advanced.py | 34 +- examples/basic.py | 25 +- examples/client.py | 21 +- examples/clustered.py | 13 +- examples/common/args.py | 25 +- examples/component.py | 26 +- examples/connect.py | 40 +- examples/context-manager.py | 14 +- examples/drain-sub.py | 34 +- examples/example.py | 22 +- examples/jetstream.py | 4 +- examples/kv.py | 21 +- examples/micro/service.py | 3 +- examples/nats-pub/__main__.py | 20 +- examples/nats-req/__main__.py | 14 +- examples/nats-sub/__main__.py | 20 +- examples/publish.py | 16 +- examples/service.py | 30 +- examples/simple.py | 5 +- examples/subscribe.py | 25 +- examples/tls.py | 40 +- examples/wildcard.py | 17 +- nats/aio/client.py | 219 ++++--- nats/aio/errors.py | 20 + nats/aio/msg.py | 40 +- nats/aio/subscription.py | 17 +- nats/aio/transport.py | 13 +- nats/errors.py | 6 +- nats/js/api.py | 123 ++-- nats/js/client.py | 98 +-- nats/js/errors.py | 21 +- nats/js/kv.py | 11 +- nats/js/manager.py | 103 +-- nats/js/object_store.py | 36 +- nats/micro/request.py | 18 +- nats/micro/service.py | 106 +++- nats/nuid.py | 2 +- nats/protocol/command.py | 28 +- nats/protocol/parser.py | 53 +- setup.cfg | 1 + setup.py | 14 +- tests/test_client.py | 703 ++++++++++----------- tests/test_client_async_await.py | 66 +- tests/test_client_nkeys.py | 49 +- tests/test_client_v2.py | 49 +- tests/test_client_websocket.py | 110 ++-- tests/test_compatibility.py | 10 +- tests/test_js.py | 1010 ++++++++++++++++-------------- tests/test_micro_service.py | 293 +++++---- tests/test_parser.py | 58 +- tests/utils.py | 87 +-- 60 files changed, 2302 insertions(+), 1843 deletions(-) create mode 100644 .github/workflows/check.yml diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml new file mode 100644 index 00000000..8489e4e3 --- /dev/null +++ b/.github/workflows/check.yml @@ -0,0 +1,30 @@ +name: Check +on: + push: + branches: + - main + pull_request: + branches: + - "*" +jobs: + format: + runs-on: ubuntu-latest + name: Format + steps: + - name: Check out repository + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.8" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pipenv + pipenv install --dev + + - name: Run format check + run: | + pipenv run yapf --diff --recursive . diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c659d315..4164f9e2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,7 +33,7 @@ jobs: - name: Run tests run: | - pipenv run flake8 --ignore=W391 ./nats/js/ + pipenv run flake8 --ignore="W391, W503, W504" ./nats/js/ pipenv run pytest -x -vv -s --continue-on-collection-errors env: PATH: $HOME/nats-server:$PATH diff --git a/Makefile b/Makefile index 1d94698e..d5807e7e 100644 --- a/Makefile +++ b/Makefile @@ -37,7 +37,7 @@ ci: deps # pipenv run yapf --recursive --diff $(SOURCE_CODE) # pipenv run yapf --recursive --diff tests # pipenv run mypy - pipenv run flake8 --ignore=W391 ./nats/js/ + pipenv run flake8 --ignore="W391, W503, W504" ./nats/js/ pipenv run pytest -x -vv -s --continue-on-collection-errors watch: diff --git a/benchmark/inbox_perf.py b/benchmark/inbox_perf.py index 0ef756ae..a882c39b 100644 --- a/benchmark/inbox_perf.py +++ b/benchmark/inbox_perf.py @@ -2,21 +2,23 @@ from nats.nuid import NUID -INBOX_PREFIX = bytearray(b'_INBOX.') +INBOX_PREFIX = bytearray(b"_INBOX.") + def gen_inboxes_nuid(n): nuid = NUID() for i in range(0, n): inbox = INBOX_PREFIX[:] inbox.extend(nuid.next()) - inbox.extend(b'.') + inbox.extend(b".") inbox.extend(nuid.next()) -if __name__ == '__main__': + +if __name__ == "__main__": benchs = [ "gen_inboxes_nuid(100000)", "gen_inboxes_nuid(1000000)", - ] + ] for bench in benchs: print(f"=== {bench}") prof.run(bench) diff --git a/benchmark/latency_perf.py b/benchmark/latency_perf.py index 7ed3da7b..20ae6c4e 100644 --- a/benchmark/latency_perf.py +++ b/benchmark/latency_perf.py @@ -10,66 +10,77 @@ HASH_MODULO = 250 try: - import uvloop - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + import uvloop + + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except: - pass + pass + def show_usage(): - message = """ + message = """ Usage: latency_perf [options] options: -n ITERATIONS Iterations to spec (default: 1000) -S SUBJECT Send subject (default: (test) """ - print(message) + print(message) + def show_usage_and_die(): - show_usage() - sys.exit(1) + show_usage() + sys.exit(1) + async def main(): - parser = argparse.ArgumentParser() - parser.add_argument('-n', '--iterations', default=DEFAULT_ITERATIONS, type=int) - parser.add_argument('-S', '--subject', default='test') - parser.add_argument('--servers', default=[], action='append') - args = parser.parse_args() - - servers = args.servers - if len(args.servers) < 1: - servers = ["nats://127.0.0.1:4222"] - - try: - nc = await nats.connect(servers) - except Exception as e: - sys.stderr.write(f"ERROR: {e}") - show_usage_and_die() - - async def handler(msg): - await nc.publish(msg.reply, b'') - await nc.subscribe(args.subject, cb=handler) - - # Start the benchmark - start = time.monotonic() - to_send = args.iterations - - print("Sending {} request/responses on [{}]".format( - args.iterations, args.subject)) - while to_send > 0: - to_send -= 1 - if to_send == 0: - break - - await nc.request(args.subject, b'') - if (to_send % HASH_MODULO) == 0: - sys.stdout.write("#") - sys.stdout.flush() - - duration = time.monotonic() - start - ms = "%.3f" % ((duration/args.iterations) * 1000) - print(f"\nTest completed : {ms} ms avg request/response latency") - await nc.close() - -if __name__ == '__main__': - asyncio.run(main()) + parser = argparse.ArgumentParser() + parser.add_argument( + "-n", "--iterations", default=DEFAULT_ITERATIONS, type=int + ) + parser.add_argument("-S", "--subject", default="test") + parser.add_argument("--servers", default=[], action="append") + args = parser.parse_args() + + servers = args.servers + if len(args.servers) < 1: + servers = ["nats://127.0.0.1:4222"] + + try: + nc = await nats.connect(servers) + except Exception as e: + sys.stderr.write(f"ERROR: {e}") + show_usage_and_die() + + async def handler(msg): + await nc.publish(msg.reply, b"") + + await nc.subscribe(args.subject, cb=handler) + + # Start the benchmark + start = time.monotonic() + to_send = args.iterations + + print( + "Sending {} request/responses on [{}]".format( + args.iterations, args.subject + ) + ) + while to_send > 0: + to_send -= 1 + if to_send == 0: + break + + await nc.request(args.subject, b"") + if (to_send % HASH_MODULO) == 0: + sys.stdout.write("#") + sys.stdout.flush() + + duration = time.monotonic() - start + ms = "%.3f" % ((duration / args.iterations) * 1000) + print(f"\nTest completed : {ms} ms avg request/response latency") + await nc.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmark/parser_perf.py b/benchmark/parser_perf.py index 9ef0f53b..11e86963 100644 --- a/benchmark/parser_perf.py +++ b/benchmark/parser_perf.py @@ -11,15 +11,15 @@ def __init__(self): self._pongs = [] self._pings_outstanding = 0 self._pongs_received = 0 - self._server_info = {"max_payload": 1048576, "auth_required": False } + self._server_info = {"max_payload": 1048576, "auth_required": False} self.stats = { - 'in_msgs': 0, - 'out_msgs': 0, - 'in_bytes': 0, - 'out_bytes': 0, - 'reconnects': 0, - 'errors_received': 0 - } + "in_msgs": 0, + "out_msgs": 0, + "in_bytes": 0, + "out_bytes": 0, + "reconnects": 0, + "errors_received": 0, + } async def _send_command(self, cmd): pass @@ -31,31 +31,34 @@ async def _process_ping(self): pass async def _process_msg(self, sid, subject, reply, data): - self.stats['in_msgs'] += 1 - self.stats['in_bytes'] += len(data) + self.stats["in_msgs"] += 1 + self.stats["in_bytes"] += len(data) async def _process_err(self, err=None): pass + def generate_msg(subject, nbytes, reply=""): msg = [] protocol_line = "MSG {subject} 1 {reply} {nbytes}\r\n".format( - subject=subject, reply=reply, nbytes=nbytes).encode() + subject=subject, reply=reply, nbytes=nbytes + ).encode() msg.append(protocol_line) - msg.append(b'A' * nbytes) - msg.append(b'r\n') - return b''.join(msg) + msg.append(b"A" * nbytes) + msg.append(b"r\n") + return b"".join(msg) + def parse_msgs(max_msgs=1, nbytes=1): - buf = b''.join([generate_msg("foo", nbytes) for i in range(0, max_msgs)]) + buf = b"".join([generate_msg("foo", nbytes) for i in range(0, max_msgs)]) print("--- buffer size: {}".format(len(buf))) loop = asyncio.get_event_loop() ps = Parser(DummyNatsClient()) loop.run_until_complete(ps.parse(buf)) print("--- stats: ", ps.nc.stats) -if __name__ == '__main__': +if __name__ == "__main__": benchs = [ "parse_msgs(max_msgs=10000, nbytes=1)", "parse_msgs(max_msgs=100000, nbytes=1)", @@ -71,7 +74,7 @@ def parse_msgs(max_msgs=1, nbytes=1): "parse_msgs(max_msgs=100000, nbytes=8192)", "parse_msgs(max_msgs=10000, nbytes=16384)", "parse_msgs(max_msgs=100000, nbytes=16384)", - ] + ] for bench in benchs: print(f"=== {bench}") diff --git a/benchmark/pub_perf.py b/benchmark/pub_perf.py index b90065c1..6e2c4f11 100644 --- a/benchmark/pub_perf.py +++ b/benchmark/pub_perf.py @@ -7,10 +7,11 @@ import nats try: - import uvloop - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + import uvloop + + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except: - pass + pass DEFAULT_FLUSH_TIMEOUT = 30 DEFAULT_NUM_MSGS = 100000 @@ -18,6 +19,7 @@ DEFAULT_BATCH_SIZE = 100 HASH_MODULO = 1000 + def show_usage(): message = """ Usage: pub_perf [options] @@ -30,24 +32,26 @@ def show_usage(): """ print(message) + def show_usage_and_die(): show_usage() sys.exit(1) + async def main(): parser = argparse.ArgumentParser() - parser.add_argument('-n', '--count', default=DEFAULT_NUM_MSGS, type=int) - parser.add_argument('-s', '--size', default=DEFAULT_MSG_SIZE, type=int) - parser.add_argument('-S', '--subject', default='test') - parser.add_argument('-b', '--batch', default=DEFAULT_BATCH_SIZE, type=int) - parser.add_argument('--servers', default=[], action='append') + parser.add_argument("-n", "--count", default=DEFAULT_NUM_MSGS, type=int) + parser.add_argument("-s", "--size", default=DEFAULT_MSG_SIZE, type=int) + parser.add_argument("-S", "--subject", default="test") + parser.add_argument("-b", "--batch", default=DEFAULT_BATCH_SIZE, type=int) + parser.add_argument("--servers", default=[], action="append") args = parser.parse_args() data = [] for i in range(0, args.size): s = "%01x" % randint(0, 15) data.append(s.encode()) - payload = b''.join(data) + payload = b"".join(data) servers = args.servers if len(args.servers) < 1: @@ -55,7 +59,7 @@ async def main(): # Make sure we're connected to a server first.. try: - nc = await nats.connect(servers, pending_size=1024*1024) + nc = await nats.connect(servers, pending_size=1024 * 1024) except Exception as e: sys.stderr.write(f"ERROR: {e}") show_usage_and_die() @@ -64,8 +68,11 @@ async def main(): start = time.time() to_send = args.count - print("Sending {} messages of size {} bytes on [{}]".format( - args.count, args.size, args.subject)) + print( + "Sending {} messages of size {} bytes on [{}]".format( + args.count, args.size, args.subject + ) + ) while to_send > 0: for i in range(0, args.batch): to_send -= 1 @@ -86,11 +93,14 @@ async def main(): print(f"Server flush timeout after {DEFAULT_FLUSH_TIMEOUT}") elapsed = time.time() - start - mbytes = "%.1f" % (((args.size * args.count)/elapsed) / (1024*1024)) - print("\nTest completed : {} msgs/sec ({}) MB/sec".format( - args.count/elapsed, - mbytes)) + mbytes = "%.1f" % (((args.size * args.count) / elapsed) / (1024 * 1024)) + print( + "\nTest completed : {} msgs/sec ({}) MB/sec".format( + args.count / elapsed, mbytes + ) + ) await nc.close() -if __name__ == '__main__': - asyncio.run(main()) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmark/pub_sub_perf.py b/benchmark/pub_sub_perf.py index 996b5a76..87198b14 100644 --- a/benchmark/pub_sub_perf.py +++ b/benchmark/pub_sub_perf.py @@ -7,10 +7,11 @@ import nats try: - import uvloop - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + import uvloop + + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except: - pass + pass DEFAULT_FLUSH_TIMEOUT = 30 DEFAULT_NUM_MSGS = 100000 @@ -18,6 +19,7 @@ DEFAULT_BATCH_SIZE = 100 HASH_MODULO = 1000 + def show_usage(): message = """ Usage: pub_sub_perf [options] @@ -30,24 +32,26 @@ def show_usage(): """ print(message) + def show_usage_and_die(): show_usage() sys.exit(1) + async def main(): parser = argparse.ArgumentParser() - parser.add_argument('-n', '--count', default=DEFAULT_NUM_MSGS, type=int) - parser.add_argument('-s', '--size', default=DEFAULT_MSG_SIZE, type=int) - parser.add_argument('-S', '--subject', default='test') - parser.add_argument('-b', '--batch', default=DEFAULT_BATCH_SIZE, type=int) - parser.add_argument('--servers', default=[], action='append') + parser.add_argument("-n", "--count", default=DEFAULT_NUM_MSGS, type=int) + parser.add_argument("-s", "--size", default=DEFAULT_MSG_SIZE, type=int) + parser.add_argument("-S", "--subject", default="test") + parser.add_argument("-b", "--batch", default=DEFAULT_BATCH_SIZE, type=int) + parser.add_argument("--servers", default=[], action="append") args = parser.parse_args() data = [] for i in range(0, args.size): s = "%01x" % randint(0, 15) data.append(s.encode()) - payload = b''.join(data) + payload = b"".join(data) servers = args.servers if len(args.servers) < 1: @@ -61,20 +65,25 @@ async def main(): show_usage_and_die() received = 0 + async def handler(msg): nonlocal received received += 1 if (received % HASH_MODULO) == 0: sys.stdout.write("*") sys.stdout.flush() + await nc.subscribe(args.subject, cb=handler) # Start the benchmark start = time.time() to_send = args.count - print("Sending {} messages of size {} bytes on [{}]".format( - args.count, args.size, args.subject)) + print( + "Sending {} messages of size {} bytes on [{}]".format( + args.count, args.size, args.subject + ) + ) while to_send > 0: for i in range(0, args.batch): to_send -= 1 @@ -97,13 +106,20 @@ async def handler(msg): print(f"Server flush timeout after {DEFAULT_FLUSH_TIMEOUT}") elapsed = time.time() - start - mbytes = "%.1f" % (((args.size * args.count)/elapsed) / (1024*1024)) - print("\nTest completed : {} msgs/sec sent ({}) MB/sec".format( - args.count/elapsed, - mbytes)) - - print("Received {} messages ({} msgs/sec)".format(received, received/elapsed)) + mbytes = "%.1f" % (((args.size * args.count) / elapsed) / (1024 * 1024)) + print( + "\nTest completed : {} msgs/sec sent ({}) MB/sec".format( + args.count / elapsed, mbytes + ) + ) + + print( + "Received {} messages ({} msgs/sec)".format( + received, received / elapsed + ) + ) await nc.close() -if __name__ == '__main__': - asyncio.run(main()) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmark/sub_perf.py b/benchmark/sub_perf.py index 8ca47697..1c68c49d 100644 --- a/benchmark/sub_perf.py +++ b/benchmark/sub_perf.py @@ -12,6 +12,7 @@ DEFAULT_BATCH_SIZE = 100 HASH_MODULO = 1000 + def show_usage(): message = """ Usage: sub_perf [options] @@ -22,15 +23,17 @@ def show_usage(): """ print(message) + def show_usage_and_die(): show_usage() sys.exit(1) + async def main(): parser = argparse.ArgumentParser() - parser.add_argument('-n', '--count', default=DEFAULT_NUM_MSGS, type=int) - parser.add_argument('-S', '--subject', default='test') - parser.add_argument('--servers', default=[], action='append') + parser.add_argument("-n", "--count", default=DEFAULT_NUM_MSGS, type=int) + parser.add_argument("-S", "--subject", default="test") + parser.add_argument("--servers", default=[], action="append") args = parser.parse_args() servers = args.servers @@ -73,10 +76,15 @@ async def handler(msg): await asyncio.sleep(0.1) elapsed = time.monotonic() - start - print("\nTest completed : {} msgs/sec sent".format(args.count/elapsed)) + print("\nTest completed : {} msgs/sec sent".format(args.count / elapsed)) - print("Received {} messages ({} msgs/sec)".format(received, received/elapsed)) + print( + "Received {} messages ({} msgs/sec)".format( + received, received / elapsed + ) + ) await nc.close() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/advanced.py b/examples/advanced.py index 38e6ac8e..47aa136f 100644 --- a/examples/advanced.py +++ b/examples/advanced.py @@ -4,26 +4,28 @@ async def main(): + async def disconnected_cb(): - print('Got disconnected!') + print("Got disconnected!") async def reconnected_cb(): - print(f'Got reconnected to {nc.connected_url.netloc}') + print(f"Got reconnected to {nc.connected_url.netloc}") async def error_cb(e): - print(f'There was an error: {e}') + print(f"There was an error: {e}") async def closed_cb(): - print('Connection is closed') + print("Connection is closed") try: # Setting explicit list of servers in a cluster. - nc = await nats.connect("localhost:4222", - error_cb=error_cb, - reconnected_cb=reconnected_cb, - disconnected_cb=disconnected_cb, - closed_cb=closed_cb, - ) + nc = await nats.connect( + "localhost:4222", + error_cb=error_cb, + reconnected_cb=reconnected_cb, + disconnected_cb=disconnected_cb, + closed_cb=closed_cb, + ) except NoServersError as e: print(e) return @@ -38,11 +40,14 @@ async def request_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) # Signal the server to stop sending messages after we got 10 already. - resp = await nc.request("help.please", b'help') + resp = await nc.request("help.please", b"help") print("Response:", resp) try: @@ -58,5 +63,6 @@ async def request_handler(msg): # handle any pending messages inflight that the server may have sent. await nc.drain() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/basic.py b/examples/basic.py index 5f5a6f07..9099ef1f 100644 --- a/examples/basic.py +++ b/examples/basic.py @@ -19,25 +19,29 @@ async def message_handler(msg): # Stop receiving after 2 messages. await sub.unsubscribe(limit=2) - await nc.publish("foo", b'Hello') - await nc.publish("foo", b'World') - await nc.publish("foo", b'!!!!!') + await nc.publish("foo", b"Hello") + await nc.publish("foo", b"World") + await nc.publish("foo", b"!!!!!") # Synchronous style with iterator also supported. sub = await nc.subscribe("bar") - await nc.publish("bar", b'First') - await nc.publish("bar", b'Second') + await nc.publish("bar", b"First") + await nc.publish("bar", b"Second") try: async for msg in sub.messages: - print(f"Received a message on '{msg.subject} {msg.reply}': {msg.data.decode()}") + print( + f"Received a message on '{msg.subject} {msg.reply}': {msg.data.decode()}" + ) await sub.unsubscribe() except Exception as e: pass async def help_request(msg): - print(f"Received a message on '{msg.subject} {msg.reply}': {msg.data.decode()}") - await nc.publish(msg.reply, b'I can help') + print( + f"Received a message on '{msg.subject} {msg.reply}': {msg.data.decode()}" + ) + await nc.publish(msg.reply, b"I can help") # Use queue named 'workers' for distributing requests # among subscribers. @@ -46,7 +50,7 @@ async def help_request(msg): # Send a request and expect a single response # and trigger timeout if not faster than 500 ms. try: - response = await nc.request("help", b'help me', timeout=0.5) + response = await nc.request("help", b"help me", timeout=0.5) print(f"Received response: {response.data.decode()}") except TimeoutError: print("Request timed out") @@ -57,5 +61,6 @@ async def help_request(msg): # Terminate connection to NATS. await nc.drain() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/client.py b/examples/client.py index ba59287d..e316d4b1 100644 --- a/examples/client.py +++ b/examples/client.py @@ -6,6 +6,7 @@ class Client: + def __init__(self, nc): self.nc = nc @@ -13,8 +14,11 @@ async def message_handler(self, msg): print(f"[Received on '{msg.subject}']: {msg.data.decode()}") async def request_handler(self, msg): - print("[Request on '{} {}']: {}".format(msg.subject, msg.reply, - msg.data.decode())) + print( + "[Request on '{} {}']: {}".format( + msg.subject, msg.reply, msg.data.decode() + ) + ) await self.nc.publish(msg.reply, b"I can help!") async def start(self): @@ -30,17 +34,16 @@ async def start(self): sub = await nc.subscribe("discover", "", self.message_handler) await sub.unsubscribe(2) - await nc.publish("discover", b'hello') - await nc.publish("discover", b'world') + await nc.publish("discover", b"hello") + await nc.publish("discover", b"world") # Following 2 messages won't be received. - await nc.publish("discover", b'again') - await nc.publish("discover", b'!!!!!') + await nc.publish("discover", b"again") + await nc.publish("discover", b"!!!!!") except ConnectionClosedError: print("Connection closed prematurely") if nc.is_connected: - # Subscription using a 'workers' queue so that only a single subscriber # gets a request at a time. await nc.subscribe("help", "workers", self.request_handler) @@ -49,7 +52,7 @@ async def start(self): # Make a request expecting a single response within 500 ms, # otherwise raising a timeout error. start_time = datetime.now() - response = await nc.request("help", b'help please', 0.500) + response = await nc.request("help", b"help please", 0.500) end_time = datetime.now() print(f"[Response]: {response.data}") print("[Duration]: {}".format(end_time - start_time)) @@ -73,6 +76,6 @@ async def start(self): print("Disconnected.") -if __name__ == '__main__': +if __name__ == "__main__": c = Client(NATS()) asyncio.run(c.start()) diff --git a/examples/clustered.py b/examples/clustered.py index 8356a117..1070efa6 100644 --- a/examples/clustered.py +++ b/examples/clustered.py @@ -6,7 +6,6 @@ async def run(): - # Setup pool of servers from a NATS cluster. options = { "servers": [ @@ -65,14 +64,17 @@ async def subscribe_handler(msg): for i in range(0, max_messages): try: - await nc.publish(f"help.{i}", b'A') + await nc.publish(f"help.{i}", b"A") await nc.flush(0.500) except ErrConnectionClosed as e: print("Connection closed prematurely.") break except ErrTimeout as e: - print("Timeout occurred when publishing msg i={}: {}".format( - i, e)) + print( + "Timeout occurred when publishing msg i={}: {}".format( + i, e + ) + ) end_time = datetime.now() await nc.drain() @@ -88,7 +90,8 @@ async def subscribe_handler(msg): if err is not None: print(f"Last Error: {err}") -if __name__ == '__main__': + +if __name__ == "__main__": loop = asyncio.get_event_loop() loop.run_until_complete(run()) loop.close() diff --git a/examples/common/args.py b/examples/common/args.py index 71ef938e..530b9e7e 100644 --- a/examples/common/args.py +++ b/examples/common/args.py @@ -3,16 +3,19 @@ def get_args(description, epilog=""): parser = argparse.ArgumentParser(description=description, epilog=epilog) - parser.add_argument("-s", "--servers", nargs='*', - default=["nats://localhost:4222"], - help="List of servers to connect. A demo server " - "is available at nats://demo.nats.io:4222. " - "However, it is very likely that the demo server " - "will see traffic from clients other than yours." - "To avoid this, start your own locally and pass " - "its address. " - "This option is default set to nats://localhost:4222 " - "which mean that, by default, you should run your " - "own server locally" + parser.add_argument( + "-s", + "--servers", + nargs="*", + default=["nats://localhost:4222"], + help="List of servers to connect. A demo server " + "is available at nats://demo.nats.io:4222. " + "However, it is very likely that the demo server " + "will see traffic from clients other than yours." + "To avoid this, start your own locally and pass " + "its address. " + "This option is default set to nats://localhost:4222 " + "which mean that, by default, you should run your " + "own server locally", ) return parser.parse_known_args() diff --git a/examples/component.py b/examples/component.py index 0cff75a8..6459aa0d 100644 --- a/examples/component.py +++ b/examples/component.py @@ -3,7 +3,9 @@ import signal import nats + class Component: + def __init__(self): self._nc = None self._done = asyncio.Future() @@ -21,9 +23,9 @@ async def close(self): asyncio.get_running_loop().stop() async def subscribe(self, *args, **kwargs): - if 'max_tasks' in kwargs: - sem = asyncio.Semaphore(kwargs['max_tasks']) - callback = kwargs['cb'] + if "max_tasks" in kwargs: + sem = asyncio.Semaphore(kwargs["max_tasks"]) + callback = kwargs["cb"] async def release(msg): try: @@ -35,16 +37,17 @@ async def cb(msg): await sem.acquire() asyncio.create_task(release(msg)) - kwargs['cb'] = cb - + kwargs["cb"] = cb + # Add some concurrency to the handler. - del kwargs['max_tasks'] + del kwargs["max_tasks"] return await self._nc.subscribe(*args, **kwargs) async def publish(self, *args): await self._nc.publish(*args) + async def main(): c = Component() await c.connect("nats://localhost:4222") @@ -53,7 +56,7 @@ async def handler(msg): # Cause some head of line blocking await asyncio.sleep(0.5) print(time.ctime(), f"Received on 'test': {msg.data.decode()}") - + async def wildcard_handler(msg): print(time.ctime(), f"Received on '>': {msg.data.decode()}") @@ -84,12 +87,15 @@ def signal_handler(): task.cancel() asyncio.create_task(c.close()) - for sig in ('SIGINT', 'SIGTERM'): - asyncio.get_running_loop().add_signal_handler(getattr(signal, sig), signal_handler) + for sig in ("SIGINT", "SIGTERM"): + asyncio.get_running_loop().add_signal_handler( + getattr(signal, sig), signal_handler + ) await c.run_forever() -if __name__ == '__main__': + +if __name__ == "__main__": try: asyncio.run(main()) except: diff --git a/examples/connect.py b/examples/connect.py index 90bd8f29..96ea0841 100644 --- a/examples/connect.py +++ b/examples/connect.py @@ -4,33 +4,37 @@ async def main(): + async def disconnected_cb(): - print('Got disconnected!') + print("Got disconnected!") async def reconnected_cb(): - print('Got reconnected to NATS...') + print("Got reconnected to NATS...") async def error_cb(e): - print(f'There was an error: {e}') + print(f"There was an error: {e}") async def closed_cb(): - print('Connection is closed') + print("Connection is closed") arguments, _ = args.get_args("Run a connect example.") - nc = await nats.connect(arguments.servers, - error_cb=error_cb, - reconnected_cb=reconnected_cb, - disconnected_cb=disconnected_cb, - closed_cb=closed_cb, - ) + nc = await nats.connect( + arguments.servers, + error_cb=error_cb, + reconnected_cb=reconnected_cb, + disconnected_cb=disconnected_cb, + closed_cb=closed_cb, + ) async def handler(msg): - print(f'Received a message on {msg.subject} {msg.reply}: {msg.data}') - await msg.respond(b'OK') - sub = await nc.subscribe('help.please', cb=handler) - - resp = await nc.request('help.please', b'help') - print('Response:', resp) - -if __name__ == '__main__': + print(f"Received a message on {msg.subject} {msg.reply}: {msg.data}") + await msg.respond(b"OK") + + sub = await nc.subscribe("help.please", cb=handler) + + resp = await nc.request("help.please", b"help") + print("Response:", resp) + + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/context-manager.py b/examples/context-manager.py index 1e86b16d..f11dadb4 100644 --- a/examples/context-manager.py +++ b/examples/context-manager.py @@ -4,7 +4,6 @@ async def main(): - is_done = asyncio.Future() async def closed_cb(): @@ -12,15 +11,19 @@ async def closed_cb(): is_done.set_result(True) arguments, _ = args.get_args("Run a context manager example.") - async with (await nats.connect(arguments.servers, closed_cb=closed_cb)) as nc: + async with await nats.connect(arguments.servers, + closed_cb=closed_cb) as nc: print(f"Connected to NATS at {nc.connected_url.netloc}...") async def subscribe_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) await nc.subscribe("discover", cb=subscribe_handler) await nc.flush() @@ -40,5 +43,6 @@ async def subscribe_handler(msg): await asyncio.wait_for(is_done, 60.0) -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/drain-sub.py b/examples/drain-sub.py index 7cd5fab4..fd58b8d5 100644 --- a/examples/drain-sub.py +++ b/examples/drain-sub.py @@ -13,25 +13,31 @@ async def message_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) # Simple publisher and async subscriber via coroutine. sub = await nc.subscribe("foo", cb=message_handler) # Stop receiving after 2 messages. await sub.unsubscribe(2) - await nc.publish("foo", b'Hello') - await nc.publish("foo", b'World') - await nc.publish("foo", b'!!!!!') + await nc.publish("foo", b"Hello") + await nc.publish("foo", b"World") + await nc.publish("foo", b"!!!!!") async def help_request(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) - await nc.publish(reply, b'I can help') + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) + await nc.publish(reply, b"I can help") # Use queue named 'workers' for distributing requests # among subscribers. @@ -52,7 +58,7 @@ async def drain_sub(): # Send multiple requests and drain the subscription. requests = [] for i in range(0, 100): - request = nc.request("help", b'help me', 0.5) + request = nc.request("help", b"help me", 0.5) requests.append(request) # Wait for all the responses @@ -62,13 +68,17 @@ async def drain_sub(): print("Received {count} responses!".format(count=len(responses))) for response in responses[:5]: - print("Received response: {message}".format( - message=response.data.decode())) + print( + "Received response: {message}".format( + message=response.data.decode() + ) + ) except: pass # Terminate connection to NATS. await nc.close() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/example.py b/examples/example.py index 91020a1c..a9df0392 100644 --- a/examples/example.py +++ b/examples/example.py @@ -22,22 +22,24 @@ async def message_handler(msg): sub = await nc.subscribe("discover", "", message_handler) await sub.unsubscribe(2) - await nc.publish("discover", b'hello') - await nc.publish("discover", b'world') + await nc.publish("discover", b"hello") + await nc.publish("discover", b"world") # Following 2 messages won't be received. - await nc.publish("discover", b'again') - await nc.publish("discover", b'!!!!!') + await nc.publish("discover", b"again") + await nc.publish("discover", b"!!!!!") except ConnectionClosedError: print("Connection closed prematurely") async def request_handler(msg): - print("[Request on '{} {}']: {}".format(msg.subject, msg.reply, - msg.data.decode())) - await nc.publish(msg.reply, b'OK') + print( + "[Request on '{} {}']: {}".format( + msg.subject, msg.reply, msg.data.decode() + ) + ) + await nc.publish(msg.reply, b"OK") if nc.is_connected: - # Subscription using a 'workers' queue so that only a single subscriber # gets a request at a time. await nc.subscribe("help", "workers", cb=request_handler) @@ -45,7 +47,7 @@ async def request_handler(msg): try: # Make a request expecting a single response within 500 ms, # otherwise raising a timeout error. - msg = await nc.request("help", b'help please', 0.500) + msg = await nc.request("help", b"help please", 0.500) print(f"[Response]: {msg.data}") # Make a roundtrip to the server to ensure messages @@ -67,5 +69,5 @@ async def request_handler(msg): print("Disconnected.") -if __name__ == '__main__': +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/jetstream.py b/examples/jetstream.py index ca24a78b..3ee43be8 100644 --- a/examples/jetstream.py +++ b/examples/jetstream.py @@ -43,6 +43,7 @@ async def qsub_a(msg): async def qsub_b(msg): print("QSUB B:", msg) await msg.ack() + await js.subscribe("foo", "workers", cb=qsub_a) await js.subscribe("foo", "workers", cb=qsub_b) @@ -65,5 +66,6 @@ async def qsub_b(msg): await nc.close() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/kv.py b/examples/kv.py index faae5330..e8b54e86 100644 --- a/examples/kv.py +++ b/examples/kv.py @@ -7,25 +7,26 @@ async def main(): js = nc.jetstream() # Create a KV - kv = await js.create_key_value(bucket='MY_KV') + kv = await js.create_key_value(bucket="MY_KV") # Set and retrieve a value - await kv.put('hello', b'world') - await kv.put('goodbye', b'farewell') - await kv.put('greetings', b'hi') - await kv.put('greeting', b'hey') + await kv.put("hello", b"world") + await kv.put("goodbye", b"farewell") + await kv.put("greetings", b"hi") + await kv.put("greeting", b"hey") # Retrieve and print the value of 'hello' - entry = await kv.get('hello') - print(f'KeyValue.Entry: key={entry.key}, value={entry.value}') + entry = await kv.get("hello") + print(f"KeyValue.Entry: key={entry.key}, value={entry.value}") # KeyValue.Entry: key=hello, value=world # Retrieve keys with filters - filtered_keys = await kv.keys(filters=['hello', 'greet']) - print(f'Filtered Keys: {filtered_keys}') + filtered_keys = await kv.keys(filters=["hello", "greet"]) + print(f"Filtered Keys: {filtered_keys}") # Expected Output: ['hello', 'greetings'] await nc.close() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/micro/service.py b/examples/micro/service.py index da4b8880..0e426c51 100644 --- a/examples/micro/service.py +++ b/examples/micro/service.py @@ -27,7 +27,8 @@ async def main(): # Add the service service = await stack.enter_async_context( - await nats.micro.add_service(nc, name="demo_service", version="0.0.1") + await + nats.micro.add_service(nc, name="demo_service", version="0.0.1") ) group = service.add_group(name="demo") diff --git a/examples/nats-pub/__main__.py b/examples/nats-pub/__main__.py index 42432556..e57823ef 100644 --- a/examples/nats-pub/__main__.py +++ b/examples/nats-pub/__main__.py @@ -38,11 +38,11 @@ async def run(): parser = argparse.ArgumentParser() # e.g. nats-pub -s demo.nats.io hello "world" - parser.add_argument('subject', default='hello', nargs='?') - parser.add_argument('-d', '--data', default="hello world") - parser.add_argument('-s', '--servers', default="") - parser.add_argument('--creds', default="") - parser.add_argument('--token', default="") + parser.add_argument("subject", default="hello", nargs="?") + parser.add_argument("-d", "--data", default="hello world") + parser.add_argument("-s", "--servers", default="") + parser.add_argument("--creds", default="") + parser.add_argument("--token", default="") args, unknown = parser.parse_known_args() data = args.data @@ -55,10 +55,7 @@ async def error_cb(e): async def reconnected_cb(): print("Got reconnected to NATS...") - options = { - "error_cb": error_cb, - "reconnected_cb": reconnected_cb - } + options = {"error_cb": error_cb, "reconnected_cb": reconnected_cb} if len(args.creds) > 0: options["user_credentials"] = args.creds @@ -68,7 +65,7 @@ async def reconnected_cb(): try: if len(args.servers) > 0: - options['servers'] = args.servers + options["servers"] = args.servers nc = await nats.connect(**options) except Exception as e: @@ -80,7 +77,8 @@ async def reconnected_cb(): await nc.flush() await nc.drain() -if __name__ == '__main__': + +if __name__ == "__main__": loop = asyncio.get_event_loop() try: loop.run_until_complete(run()) diff --git a/examples/nats-req/__main__.py b/examples/nats-req/__main__.py index d1d59432..dd69806e 100644 --- a/examples/nats-req/__main__.py +++ b/examples/nats-req/__main__.py @@ -38,10 +38,10 @@ async def run(): parser = argparse.ArgumentParser() # e.g. nats-req -s demo.nats.io hello world - parser.add_argument('subject', default='hello', nargs='?') - parser.add_argument('-d', '--data', default="hello world") - parser.add_argument('-s', '--servers', default="") - parser.add_argument('--creds', default="") + parser.add_argument("subject", default="hello", nargs="?") + parser.add_argument("-d", "--data", default="hello world") + parser.add_argument("-s", "--servers", default="") + parser.add_argument("--creds", default="") args, unknown = parser.parse_known_args() data = args.data @@ -62,7 +62,7 @@ async def reconnected_cb(): nc = None try: if len(args.servers) > 0: - options['servers'] = args.servers + options["servers"] = args.servers nc = await nats.connect(**options) except Exception as e: @@ -70,7 +70,7 @@ async def reconnected_cb(): show_usage_and_die() async def req_callback(msg): - await msg.respond(b'a response') + await msg.respond(b"a response") await nc.subscribe(args.subject, cb=req_callback) @@ -87,7 +87,7 @@ async def req_callback(msg): await nc.drain() -if __name__ == '__main__': +if __name__ == "__main__": loop = asyncio.get_event_loop() try: loop.run_until_complete(run()) diff --git a/examples/nats-sub/__main__.py b/examples/nats-sub/__main__.py index 4ae7136e..13510c8f 100644 --- a/examples/nats-sub/__main__.py +++ b/examples/nats-sub/__main__.py @@ -39,10 +39,10 @@ async def run(): parser = argparse.ArgumentParser() # e.g. nats-sub hello -s nats://127.0.0.1:4222 - parser.add_argument('subject', default='hello', nargs='?') - parser.add_argument('-s', '--servers', default="") - parser.add_argument('-q', '--queue', default="") - parser.add_argument('--creds', default="") + parser.add_argument("subject", default="hello", nargs="?") + parser.add_argument("-s", "--servers", default="") + parser.add_argument("-q", "--queue", default="") + parser.add_argument("--creds", default="") args = parser.parse_args() async def error_cb(e): @@ -69,7 +69,7 @@ async def subscribe_handler(msg): options = { "error_cb": error_cb, "closed_cb": closed_cb, - "reconnected_cb": reconnected_cb + "reconnected_cb": reconnected_cb, } if len(args.creds) > 0: @@ -78,7 +78,7 @@ async def subscribe_handler(msg): nc = None try: if len(args.servers) > 0: - options['servers'] = args.servers + options["servers"] = args.servers nc = await nats.connect(**options) except Exception as e: @@ -92,13 +92,15 @@ def signal_handler(): return asyncio.create_task(nc.drain()) - for sig in ('SIGINT', 'SIGTERM'): - asyncio.get_running_loop().add_signal_handler(getattr(signal, sig), signal_handler) + for sig in ("SIGINT", "SIGTERM"): + asyncio.get_running_loop().add_signal_handler( + getattr(signal, sig), signal_handler + ) await nc.subscribe(args.subject, args.queue, subscribe_handler) -if __name__ == '__main__': +if __name__ == "__main__": loop = asyncio.get_event_loop() loop.run_until_complete(run()) try: diff --git a/examples/publish.py b/examples/publish.py index b8cc3f21..ce933c57 100644 --- a/examples/publish.py +++ b/examples/publish.py @@ -4,8 +4,9 @@ async def main(): - arguments, _ = args.get_args("Run a publish example.", - "Usage: python examples/publish.py") + arguments, _ = args.get_args( + "Run a publish example.", "Usage: python examples/publish.py" + ) nc = await nats.connect(arguments.servers) # Publish as message with an inbox. @@ -13,13 +14,13 @@ async def main(): sub = await nc.subscribe("hello") # Simple publishing - await nc.publish("hello", b'Hello World!') + await nc.publish("hello", b"Hello World!") # Publish with a reply - await nc.publish("hello", b'Hello World!', reply=inbox) - + await nc.publish("hello", b"Hello World!", reply=inbox) + # Publish with a reply - await nc.publish("hello", b'With Headers', headers={'Foo':'Bar'}) + await nc.publish("hello", b"With Headers", headers={"Foo": "Bar"}) while True: try: @@ -32,5 +33,6 @@ async def main(): print("Data :", msg.data) print("Headers:", msg.header) -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/service.py b/examples/service.py index 70bfe79f..56a7c5f4 100644 --- a/examples/service.py +++ b/examples/service.py @@ -17,8 +17,10 @@ def signal_handler(): asyncio.create_task(nc.close()) asyncio.create_task(stop()) - for sig in ('SIGINT', 'SIGTERM'): - asyncio.get_running_loop().add_signal_handler(getattr(signal, sig), signal_handler) + for sig in ("SIGINT", "SIGTERM"): + asyncio.get_running_loop().add_signal_handler( + getattr(signal, sig), signal_handler + ) async def disconnected_cb(): print("Got disconnected...") @@ -26,18 +28,23 @@ async def disconnected_cb(): async def reconnected_cb(): print("Got reconnected...") - await nc.connect("127.0.0.1", - reconnected_cb=reconnected_cb, - disconnected_cb=disconnected_cb, - max_reconnect_attempts=-1) + await nc.connect( + "127.0.0.1", + reconnected_cb=reconnected_cb, + disconnected_cb=disconnected_cb, + max_reconnect_attempts=-1, + ) async def help_request(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) - await nc.publish(reply, b'I can help') + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) + await nc.publish(reply, b"I can help") # Use queue named 'workers' for distributing requests # among subscribers. @@ -47,12 +54,13 @@ async def help_request(msg): for i in range(1, 1000000): await asyncio.sleep(1) try: - response = await nc.request("help", b'hi') + response = await nc.request("help", b"hi") print(response) except Exception as e: print("Error:", e) -if __name__ == '__main__': + +if __name__ == "__main__": loop = asyncio.get_event_loop() try: loop.run_until_complete(main()) diff --git a/examples/simple.py b/examples/simple.py index 220ef1f3..ee7c71cb 100644 --- a/examples/simple.py +++ b/examples/simple.py @@ -10,7 +10,7 @@ async def main(): sub = await nc.subscribe("foo") # Publish a message to 'foo' - await nc.publish("foo", b'Hello from Python!') + await nc.publish("foo", b"Hello from Python!") # Process a message msg = await sub.next_msg() @@ -19,5 +19,6 @@ async def main(): # Close NATS connection await nc.close() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/subscribe.py b/examples/subscribe.py index 64b0fe60..b98ec162 100644 --- a/examples/subscribe.py +++ b/examples/subscribe.py @@ -13,10 +13,7 @@ async def closed_cb(): asyncio.get_running_loop().stop() arguments, _ = args.get_args("Run a subscription example.") - options = { - "servers": arguments.servers, - "closed_cb": closed_cb - } + options = {"servers": arguments.servers, "closed_cb": closed_cb} await nc.connect(**options) print(f"Connected to NATS at {nc.connected_url.netloc}...") @@ -25,9 +22,12 @@ async def subscribe_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) - await msg.respond(b'I can help!') + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) + await msg.respond(b"I can help!") # Basic subscription to receive all published messages # which are being sent to a single topic 'discover' @@ -43,14 +43,17 @@ def signal_handler(): print("Disconnecting...") asyncio.create_task(nc.close()) - for sig in ('SIGINT', 'SIGTERM'): - asyncio.get_running_loop().add_signal_handler(getattr(signal, sig), signal_handler) + for sig in ("SIGINT", "SIGTERM"): + asyncio.get_running_loop().add_signal_handler( + getattr(signal, sig), signal_handler + ) - await nc.request("help", b'help') + await nc.request("help", b"help") await asyncio.sleep(30) -if __name__ == '__main__': + +if __name__ == "__main__": try: asyncio.run(main()) except: diff --git a/examples/tls.py b/examples/tls.py index 26762988..60a81b18 100644 --- a/examples/tls.py +++ b/examples/tls.py @@ -10,35 +10,42 @@ async def main(): # Your local server needs to configured with the server certificate in the ../tests/certs directory. ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) ssl_ctx.minimum_version = ssl.PROTOCOL_TLSv1_2 - ssl_ctx.load_verify_locations('../tests/certs/ca.pem') + ssl_ctx.load_verify_locations("../tests/certs/ca.pem") ssl_ctx.load_cert_chain( - certfile='../tests/certs/client-cert.pem', - keyfile='../tests/certs/client-key.pem') + certfile="../tests/certs/client-cert.pem", + keyfile="../tests/certs/client-key.pem", + ) await nc.connect(servers=["nats://127.0.0.1:4222"], tls=ssl_ctx) async def message_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) # Simple publisher and async subscriber via coroutine. sid = await nc.subscribe("foo", cb=message_handler) # Stop receiving after 2 messages. await nc.auto_unsubscribe(sid, 2) - await nc.publish("foo", b'Hello') - await nc.publish("foo", b'World') - await nc.publish("foo", b'!!!!!') + await nc.publish("foo", b"Hello") + await nc.publish("foo", b"World") + await nc.publish("foo", b"!!!!!") async def help_request(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) - await nc.publish(reply, b'I can help') + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) + await nc.publish(reply, b"I can help") # Use queue named 'workers' for distributing requests # among subscribers. @@ -47,9 +54,12 @@ async def help_request(msg): # Send a request and expect a single response # and trigger timeout if not faster than 50 ms. try: - response = await nc.timed_request("help", b'help me', 0.050) - print("Received response: {message}".format( - message=response.data.decode())) + response = await nc.timed_request("help", b"help me", 0.050) + print( + "Received response: {message}".format( + message=response.data.decode() + ) + ) except TimeoutError: print("Request timed out") @@ -57,7 +67,7 @@ async def help_request(msg): await nc.close() -if __name__ == '__main__': +if __name__ == "__main__": loop = asyncio.get_event_loop() loop.run_until_complete(main()) loop.close() diff --git a/examples/wildcard.py b/examples/wildcard.py index e0c308a0..a82e7109 100644 --- a/examples/wildcard.py +++ b/examples/wildcard.py @@ -6,16 +6,20 @@ async def run(loop): nc = NATS() - arguments, _ = args.get_args("Run the wildcard example.", - "Usage: python examples/wildcard.py") + arguments, _ = args.get_args( + "Run the wildcard example.", "Usage: python examples/wildcard.py" + ) await nc.connect(arguments.servers) async def message_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) # "*" matches any token, at any level of the subject. await nc.subscribe("foo.*.baz", cb=message_handler) @@ -26,12 +30,13 @@ async def message_handler(msg): await nc.subscribe("foo.>", cb=message_handler) # Matches all of the above. - await nc.publish("foo.bar.baz", b'Hello World') + await nc.publish("foo.bar.baz", b"Hello World") # Gracefully close the connection. await nc.drain() -if __name__ == '__main__': + +if __name__ == "__main__": loop = asyncio.get_event_loop() loop.run_until_complete(run(loop)) loop.close() diff --git a/nats/aio/client.py b/nats/aio/client.py index 76fa0480..c2bfd00b 100644 --- a/nats/aio/client.py +++ b/nats/aio/client.py @@ -69,26 +69,26 @@ ) from .transport import TcpTransport, Transport, WebSocketTransport -__version__ = '2.9.0' -__lang__ = 'python3' +__version__ = "2.9.0" +__lang__ = "python3" _logger = logging.getLogger(__name__) PROTOCOL = 1 -INFO_OP = b'INFO' -CONNECT_OP = b'CONNECT' -PING_OP = b'PING' -PONG_OP = b'PONG' -OK_OP = b'+OK' -ERR_OP = b'-ERR' -_CRLF_ = b'\r\n' +INFO_OP = b"INFO" +CONNECT_OP = b"CONNECT" +PING_OP = b"PING" +PONG_OP = b"PONG" +OK_OP = b"+OK" +ERR_OP = b"-ERR" +_CRLF_ = b"\r\n" _CRLF_LEN_ = len(_CRLF_) -_SPC_ = b' ' +_SPC_ = b" " _SPC_BYTE_ = 32 EMPTY = "" PING_PROTO = PING_OP + _CRLF_ PONG_PROTO = PONG_OP + _CRLF_ -DEFAULT_INBOX_PREFIX = b'_INBOX' +DEFAULT_INBOX_PREFIX = b"_INBOX" DEFAULT_PENDING_SIZE = 2 * 1024 * 1024 DEFAULT_BUFFER_SIZE = 32768 @@ -103,7 +103,7 @@ DEFAULT_DRAIN_TIMEOUT = 30 # in seconds MAX_CONTROL_LINE_SIZE = 1024 -NATS_HDR_LINE = bytearray(b'NATS/1.0') +NATS_HDR_LINE = bytearray(b"NATS/1.0") NATS_HDR_LINE_SIZE = len(NATS_HDR_LINE) NO_RESPONDERS_STATUS = "503" CTRL_STATUS = "100" @@ -148,10 +148,10 @@ def __init__(self, server_version: str) -> None: # TODO(@orsinium): use cached_property def parse_version(self) -> None: - v = (self._server_version).split('-') + v = (self._server_version).split("-") if len(v) > 1: self._dev_version = v[1] - tokens = v[0].split('.') + tokens = v[0].split(".") n = len(tokens) if n > 1: self._major_version = int(tokens[0]) @@ -182,7 +182,7 @@ def patch(self) -> int: def dev(self) -> str: if not self._dev_version: self.parse_version() - return self._dev_version or '' + return self._dev_version or "" def __repr__(self) -> str: return f"" @@ -193,7 +193,7 @@ async def _default_error_callback(ex: Exception) -> None: Provides a default way to handle async errors if the user does not provide one. """ - _logger.error('nats: encountered error', exc_info=ex) + _logger.error("nats: encountered error", exc_info=ex) class Client: @@ -287,12 +287,12 @@ def __init__(self) -> None: self.options: Dict[str, Any] = {} self.stats = { - 'in_msgs': 0, - 'out_msgs': 0, - 'in_bytes': 0, - 'out_bytes': 0, - 'reconnects': 0, - 'errors_received': 0, + "in_msgs": 0, + "out_msgs": 0, + "in_bytes": 0, + "out_bytes": 0, + "reconnects": 0, + "errors_received": 0, } async def connect( @@ -420,8 +420,13 @@ async def subscribe_handler(msg): """ - for cb in [error_cb, disconnected_cb, closed_cb, reconnected_cb, - discovered_server_cb]: + for cb in [ + error_cb, + disconnected_cb, + closed_cb, + reconnected_cb, + discovered_server_cb, + ]: if cb and not asyncio.iscoroutinefunction(cb): raise errors.InvalidCallbackTypeError @@ -461,12 +466,12 @@ async def subscribe_handler(msg): self.options["token"] = token self.options["connect_timeout"] = connect_timeout self.options["drain_timeout"] = drain_timeout - self.options['tls_handshake_first'] = tls_handshake_first + self.options["tls_handshake_first"] = tls_handshake_first if tls: - self.options['tls'] = tls + self.options["tls"] = tls if tls_hostname: - self.options['tls_hostname'] = tls_hostname + self.options["tls_hostname"] = tls_hostname # Check if the username or password was set in the server URI server_auth_configured = False @@ -478,11 +483,8 @@ async def subscribe_handler(msg): if user or password or token or server_auth_configured: self._auth_configured = True - if ( - self._user_credentials is not None - or self._nkeys_seed is not None - or self._nkeys_seed_str is not None - ): + if (self._user_credentials is not None or self._nkeys_seed is not None + or self._nkeys_seed_str is not None): self._auth_configured = True self._setup_nkeys_connect() @@ -502,7 +504,9 @@ async def subscribe_handler(msg): try: await self._select_next_server() await self._process_connect_init() - assert self._current_server, "the current server must be set by _select_next_server" + assert ( + self._current_server + ), "the current server must be set by _select_next_server" self._current_server.reconnects = 0 break except errors.NoServersError as e: @@ -542,7 +546,7 @@ def _setup_nkeys_jwt_connect(self) -> None: def user_cb() -> bytearray: contents = None - with open(creds[0], 'rb') as f: + with open(creds[0], "rb") as f: contents = bytearray(os.fstat(f.fileno()).st_size) f.readinto(contents) # type: ignore[attr-defined] return contents @@ -551,7 +555,7 @@ def user_cb() -> bytearray: def sig_cb(nonce: str) -> bytes: seed = None - with open(creds[1], 'rb') as f: + with open(creds[1], "rb") as f: seed = bytearray(os.fstat(f.fileno()).st_size) f.readinto(seed) # type: ignore[attr-defined] kp = nkeys.from_seed(seed) @@ -634,7 +638,9 @@ def get_user_jwt(f): return get_user_jwt(f) def _setup_nkeys_seed_connect(self) -> None: - assert self._nkeys_seed or self._nkeys_seed_str, "Client.connect must be called first" + assert ( + self._nkeys_seed or self._nkeys_seed_str + ), "Client.connect must be called first" import os import nkeys @@ -645,7 +651,7 @@ def _get_nkeys_seed() -> nkeys.KeyPair: seed = bytearray(self._nkeys_seed_str.encode()) else: creds = self._nkeys_seed - with open(creds, 'rb') as f: + with open(creds, "rb") as f: seed = bytearray(os.fstat(f.fileno()).st_size) f.readinto(seed) # type: ignore[attr-defined] key_pair = nkeys.from_seed(seed) @@ -690,8 +696,8 @@ async def _close(self, status: int, do_cbs: bool = True) -> None: ): self._reading_task.cancel() - if self._ping_interval_task is not None and not self._ping_interval_task.cancelled( - ): + if (self._ping_interval_task is not None + and not self._ping_interval_task.cancelled()): self._ping_interval_task.cancel() if self._flusher_task is not None and not self._flusher_task.cancelled( @@ -704,8 +710,8 @@ async def _close(self, status: int, do_cbs: bool = True) -> None: # Wait for the reconnection task to be done which should be soon. try: - if self._reconnection_task_future is not None and not self._reconnection_task_future.cancelled( - ): + if (self._reconnection_task_future is not None + and not self._reconnection_task_future.cancelled()): await asyncio.wait_for( self._reconnection_task_future, self.options["reconnect_time_wait"], @@ -805,9 +811,9 @@ async def drain(self) -> None: async def publish( self, subject: str, - payload: bytes = b'', - reply: str = '', - headers: Optional[Dict[str, str]] = None + payload: bytes = b"", + reply: str = "", + headers: Optional[Dict[str, str]] = None, ) -> None: """ Publishes a NATS message. @@ -861,7 +867,9 @@ async def main(): payload_size = len(payload) if not self.is_connected: - if self._max_pending_size <= 0 or payload_size + self._pending_data_size > self._max_pending_size: + if (self._max_pending_size <= 0 + or payload_size + self._pending_data_size + > self._max_pending_size): # Cannot publish during a reconnection when the buffering is disabled, # or if pending buffer is already full. raise errors.OutboundBufferLimitError @@ -900,15 +908,15 @@ async def _send_publish( # Skip empty keys continue hdr.extend(key.encode()) - hdr.extend(b': ') + hdr.extend(b": ") value = v.strip() hdr.extend(value.encode()) hdr.extend(_CRLF_) hdr.extend(_CRLF_) pub_cmd = prot_command.hpub_cmd(subject, reply, hdr, payload) - self.stats['out_msgs'] += 1 - self.stats['out_bytes'] += payload_size + self.stats["out_msgs"] += 1 + self.stats["out_bytes"] += payload_size await self._send_command(pub_cmd) if self._flush_queue is not None and self._flush_queue.empty(): await self._flush_pending() @@ -931,10 +939,10 @@ async def subscribe( If a callback isn't provided, messages can be retrieved via an asynchronous iterator on the returned subscription object. """ - if not subject or (' ' in subject): + if not subject or (" " in subject): raise errors.BadSubjectError - if queue and (' ' in queue): + if queue and (" " in queue): raise errors.BadSubjectError if self.is_closed: @@ -979,11 +987,11 @@ async def _init_request_sub(self) -> None: self._resp_map = {} self._resp_sub_prefix = self._inbox_prefix[:] - self._resp_sub_prefix.extend(b'.') + self._resp_sub_prefix.extend(b".") self._resp_sub_prefix.extend(self._nuid.next()) - self._resp_sub_prefix.extend(b'.') + self._resp_sub_prefix.extend(b".") resp_mux_subject = self._resp_sub_prefix[:] - resp_mux_subject.extend(b'*') + resp_mux_subject.extend(b"*") await self.subscribe( resp_mux_subject.decode(), cb=self._request_sub_callback ) @@ -1001,7 +1009,7 @@ async def _request_sub_callback(self, msg: Msg) -> None: async def request( self, subject: str, - payload: bytes = b'', + payload: bytes = b"", timeout: float = 0.5, old_style: bool = False, headers: Optional[Dict[str, Any]] = None, @@ -1021,8 +1029,8 @@ async def request( msg = await self._request_new_style( subject, payload, timeout=timeout, headers=headers ) - if msg.headers and msg.headers.get(nats.js.api.Header.STATUS - ) == NO_RESPONDERS_STATUS: + if (msg.headers and msg.headers.get(nats.js.api.Header.STATUS) + == NO_RESPONDERS_STATUS): raise errors.NoRespondersError return msg @@ -1048,7 +1056,9 @@ async def _request_new_style( # Then use the future to get the response. future: asyncio.Future = asyncio.Future() - future.add_done_callback(lambda f: self._resp_map.pop(token.decode(), None)) + future.add_done_callback( + lambda f: self._resp_map.pop(token.decode(), None) + ) self._resp_map[token.decode()] = future # Publish the request @@ -1074,7 +1084,7 @@ def new_inbox(self) -> str: msg = sub.next_msg() """ next_inbox = self._inbox_prefix[:] - next_inbox.extend(b'.') + next_inbox.extend(b".") next_inbox.extend(self._nuid.next()) return next_inbox.decode() @@ -1220,11 +1230,11 @@ def connected_server_version(self) -> ServerVersion: def ssl_context(self) -> ssl.SSLContext: ssl_context: Optional[ssl.SSLContext] = None if "tls" in self.options: - ssl_context = self.options.get('tls') + ssl_context = self.options.get("tls") else: ssl_context = ssl.create_default_context() if ssl_context is None: - raise errors.Error('nats: no ssl context provided') + raise errors.Error("nats: no ssl context provided") return ssl_context async def _send_command(self, cmd: bytes, priority: bool = False) -> None: @@ -1233,7 +1243,8 @@ async def _send_command(self, cmd: bytes, priority: bool = False) -> None: else: self._pending.append(cmd) self._pending_data_size += len(cmd) - if self._max_pending_size > 0 and self._pending_data_size > self._max_pending_size: + if (self._max_pending_size > 0 + and self._pending_data_size > self._max_pending_size): # Only flush force timeout on publish await self._flush_pending(force_flush=True) @@ -1329,8 +1340,8 @@ async def _select_next_server(self) -> None: # Not yet exceeded max_reconnect_attempts so can still use # this server in the future. self._server_pool.append(s) - if s.last_attempt is not None and now < s.last_attempt + self.options[ - "reconnect_time_wait"]: + if (s.last_attempt is not None and now + < s.last_attempt + self.options["reconnect_time_wait"]): # Backoff connecting to server if we attempted recently. await asyncio.sleep(self.options["reconnect_time_wait"]) try: @@ -1347,13 +1358,13 @@ async def _select_next_server(self) -> None: s.uri, ssl_context=self.ssl_context, buffer_size=DEFAULT_BUFFER_SIZE, - connect_timeout=self.options['connect_timeout'] + connect_timeout=self.options["connect_timeout"], ) else: await self._transport.connect( s.uri, buffer_size=DEFAULT_BUFFER_SIZE, - connect_timeout=self.options['connect_timeout'] + connect_timeout=self.options["connect_timeout"], ) self._current_server = s break @@ -1409,8 +1420,8 @@ async def _process_op_err(self, e: Exception) -> None: self._status = Client.RECONNECTING self._ps.reset() - if self._reconnection_task is not None and not self._reconnection_task.cancelled( - ): + if (self._reconnection_task is not None + and not self._reconnection_task.cancelled()): # Cancel the previous task in case it may still be running. self._reconnection_task.cancel() @@ -1428,8 +1439,8 @@ async def _attempt_reconnect(self) -> None: ): self._reading_task.cancel() - if self._ping_interval_task is not None and not self._ping_interval_task.cancelled( - ): + if (self._ping_interval_task is not None + and not self._ping_interval_task.cancelled()): self._ping_interval_task.cancel() if self._flusher_task is not None and not self._flusher_task.cancelled( @@ -1525,24 +1536,24 @@ async def _attempt_reconnect(self) -> None: except asyncio.CancelledError: break - if self._reconnection_task_future is not None and not self._reconnection_task_future.cancelled( - ): + if (self._reconnection_task_future is not None + and not self._reconnection_task_future.cancelled()): self._reconnection_task_future.set_result(True) def _connect_command(self) -> bytes: - ''' + """ Generates a JSON string with the params to be used when sending CONNECT to the server. ->> CONNECT {"lang": "python3"} - ''' + """ options = { "verbose": self.options["verbose"], "pedantic": self.options["pedantic"], "lang": __lang__, "version": __version__, - "protocol": PROTOCOL + "protocol": PROTOCOL, } if "headers" in self._server_info: options["headers"] = self._server_info["headers"] @@ -1560,8 +1571,8 @@ def _connect_command(self) -> bytes: options["nkey"] = self._public_nkey # In case there is no password, then consider handle # sending a token instead. - elif self.options["user"] is not None and self.options[ - "password"] is not None: + elif (self.options["user"] is not None + and self.options["password"] is not None): options["user"] = self.options["user"] options["pass"] = self.options["password"] elif self.options["token"] is not None: @@ -1579,7 +1590,7 @@ def _connect_command(self) -> bytes: options["echo"] = not self.options["no_echo"] connect_opts = json.dumps(options, sort_keys=True) - return b''.join([CONNECT_OP + _SPC_ + connect_opts.encode() + _CRLF_]) + return b"".join([CONNECT_OP + _SPC_ + connect_opts.encode() + _CRLF_]) async def _process_ping(self) -> None: """ @@ -1706,8 +1717,8 @@ async def _process_msg( Process MSG sent by server. """ payload_size = len(data) - self.stats['in_msgs'] += 1 - self.stats['in_bytes'] += payload_size + self.stats["in_msgs"] += 1 + self.stats["in_bytes"] += payload_size sub = self._subs.get(sid) if not sub: @@ -1780,7 +1791,8 @@ async def _process_msg( try: sub._pending_size += payload_size # allow setting pending_bytes_limit to 0 to disable - if sub._pending_bytes_limit > 0 and sub._pending_size >= sub._pending_bytes_limit: + if (sub._pending_bytes_limit > 0 + and sub._pending_size >= sub._pending_bytes_limit): # Subtract the bytes since the message will be thrown away # so it would not be pending data. sub._pending_size -= payload_size @@ -1859,23 +1871,24 @@ def _process_info( with latest updates from cluster to enable server discovery. """ assert self._current_server, "Client.connect must be called first" - if 'connect_urls' in info: - if info['connect_urls']: + if "connect_urls" in info: + if info["connect_urls"]: connect_urls = [] - for connect_url in info['connect_urls']: - scheme = '' - if self._current_server.uri.scheme == 'tls': - scheme = 'tls' + for connect_url in info["connect_urls"]: + scheme = "" + if self._current_server.uri.scheme == "tls": + scheme = "tls" else: - scheme = 'nats' + scheme = "nats" uri = urlparse(f"{scheme}://{connect_url}") srv = Srv(uri) srv.discovered = True # Check whether we should reuse the original hostname. - if 'tls_required' in self._server_info and self._server_info['tls_required'] \ - and self._host_is_ip(uri.hostname): + if ("tls_required" in self._server_info + and self._server_info["tls_required"] + and self._host_is_ip(uri.hostname)): srv.tls_name = self._current_server.uri.hostname # Filter for any similar server in the server pool already. @@ -1891,7 +1904,8 @@ def _process_info( for srv in connect_urls: self._server_pool.append(srv) - if not initial_connection and connect_urls and self._discovered_server_cb: + if (not initial_connection and connect_urls + and self._discovered_server_cb): self._discovered_server_cb() def _host_is_ip(self, connect_url: Optional[str]) -> bool: @@ -1922,13 +1936,13 @@ async def _process_connect_init(self) -> None: else: hostname = self._current_server.uri.hostname - handshake_first = self.options['tls_handshake_first'] + handshake_first = self.options["tls_handshake_first"] if handshake_first: await self._transport.connect_tls( hostname, self.ssl_context, DEFAULT_BUFFER_SIZE, - self.options['connect_timeout'], + self.options["connect_timeout"], ) connection_completed = self._transport.readline() @@ -1955,17 +1969,18 @@ async def _process_connect_init(self) -> None: self._process_info(srv_info, initial_connection=True) - if 'version' in self._server_info: - self._current_server.server_version = self._server_info['version'] + if "version" in self._server_info: + self._current_server.server_version = self._server_info["version"] - if 'max_payload' in self._server_info: + if "max_payload" in self._server_info: self._max_payload = self._server_info["max_payload"] - if 'client_id' in self._server_info: + if "client_id" in self._server_info: self._client_id = self._server_info["client_id"] - if 'tls_required' in self._server_info and self._server_info[ - 'tls_required'] and self._current_server.uri.scheme != "ws": + if ("tls_required" in self._server_info + and self._server_info["tls_required"] + and self._current_server.uri.scheme != "ws"): if not handshake_first: await self._transport.drain() # just in case something is left @@ -1974,7 +1989,7 @@ async def _process_connect_init(self) -> None: hostname, self.ssl_context, DEFAULT_BUFFER_SIZE, - self.options['connect_timeout'], + self.options["connect_timeout"], ) # Refresh state of parser upon reconnect. @@ -2000,7 +2015,7 @@ async def _process_connect_init(self) -> None: # FIXME: Maybe handling could be more special here, # checking for errors.AuthorizationError for example. # await self._process_err(err_msg) - raise errors.Error("nats: " + err_msg.rstrip('\r\n')) + raise errors.Error("nats: " + err_msg.rstrip("\r\n")) self._transport.write(PING_PROTO) await self._transport.drain() @@ -2019,7 +2034,7 @@ async def _process_connect_init(self) -> None: # FIXME: Maybe handling could be more special here, # checking for ErrAuthorization for example. # await self._process_err(err_msg) - raise errors.Error("nats: " + err_msg.rstrip('\r\n')) + raise errors.Error("nats: " + err_msg.rstrip("\r\n")) if PONG_PROTO in next_op: self._status = Client.CONNECTED @@ -2124,7 +2139,7 @@ async def _read_loop(self) -> None: except asyncio.CancelledError: break except Exception as ex: - _logger.error('nats: encountered error', exc_info=ex) + _logger.error("nats: encountered error", exc_info=ex) break # except asyncio.InvalidStateError: # pass diff --git a/nats/aio/errors.py b/nats/aio/errors.py index d7cc7aca..66fd68d4 100644 --- a/nats/aio/errors.py +++ b/nats/aio/errors.py @@ -24,6 +24,7 @@ class NatsError(nats.errors.Error): Please use `nats.errors.Error` instead. """ + pass @@ -34,6 +35,7 @@ class ErrConnectionClosed(nats.errors.ConnectionClosedError): Please use `nats.errors.ConnectionClosedError` instead. """ + pass @@ -44,6 +46,7 @@ class ErrDrainTimeout(nats.errors.DrainTimeoutError): Please use `nats.errors.DrainTimeoutError` instead. """ + pass @@ -54,6 +57,7 @@ class ErrInvalidUserCredentials(nats.errors.InvalidUserCredentialsError): Please use `nats.errors.InvalidUserCredentialsError` instead. """ + pass @@ -64,6 +68,7 @@ class ErrInvalidCallbackType(nats.errors.InvalidCallbackTypeError): Please use `nats.errors.InvalidCallbackTypeError` instead. """ + pass @@ -74,6 +79,7 @@ class ErrConnectionReconnecting(nats.errors.ConnectionReconnectingError): Please use `nats.errors.ConnectionReconnectingError` instead. """ + pass @@ -84,6 +90,7 @@ class ErrConnectionDraining(nats.errors.ConnectionDrainingError): Please use `nats.errors.ConnectionDrainingError` instead. """ + pass @@ -94,6 +101,7 @@ class ErrMaxPayload(nats.errors.MaxPayloadError): Please use `nats.errors.MaxPayloadError` instead. """ + pass @@ -104,6 +112,7 @@ class ErrStaleConnection(nats.errors.StaleConnectionError): Please use `nats.errors.StaleConnectionError` instead. """ + pass @@ -114,6 +123,7 @@ class ErrJsonParse(nats.errors.JsonParseError): Please use `nats.errors.JsonParseError` instead. """ + pass @@ -124,6 +134,7 @@ class ErrSecureConnRequired(nats.errors.SecureConnRequiredError): Please use `nats.errors.SecureConnRequiredError` instead. """ + pass @@ -134,6 +145,7 @@ class ErrSecureConnWanted(nats.errors.SecureConnWantedError): Please use `nats.errors.SecureConnWantedError` instead. """ + pass @@ -144,6 +156,7 @@ class ErrSecureConnFailed(nats.errors.SecureConnFailedError): Please use `nats.errors.SecureConnFailedError` instead. """ + pass @@ -154,6 +167,7 @@ class ErrBadSubscription(nats.errors.BadSubscriptionError): Please use `nats.errors.BadSubscriptionError` instead. """ + pass @@ -164,6 +178,7 @@ class ErrBadSubject(nats.errors.BadSubjectError): Please use `nats.errors.BadSubjectError` instead. """ + pass @@ -174,6 +189,7 @@ class ErrSlowConsumer(nats.errors.SlowConsumerError): Please use `nats.errors.SlowConsumerError` instead. """ + pass @@ -184,6 +200,7 @@ class ErrTimeout(nats.errors.TimeoutError): Please use `nats.errors.TimeoutError` instead. """ + pass @@ -194,6 +211,7 @@ class ErrBadTimeout(nats.errors.BadTimeoutError): Please use `nats.errors.BadTimeoutError` instead. """ + pass @@ -204,6 +222,7 @@ class ErrAuthorization(nats.errors.AuthorizationError): Please use `nats.errors.AuthorizationError` instead. """ + pass @@ -214,4 +233,5 @@ class ErrNoServers(nats.errors.NoServersError): Please use `nats.errors.NoServersError` instead. """ + pass diff --git a/nats/aio/msg.py b/nats/aio/msg.py index 8eccd991..0905129f 100644 --- a/nats/aio/msg.py +++ b/nats/aio/msg.py @@ -40,10 +40,11 @@ class Msg: """ Msg represents a message delivered by NATS. """ + _client: NATS - subject: str = '' - reply: str = '' - data: bytes = b'' + subject: str = "" + reply: str = "" + data: bytes = b"" headers: Optional[Dict[str, str]] = None _metadata: Optional[Metadata] = None @@ -57,8 +58,8 @@ class Ack: Term = b"+TERM" # Reply metadata... - Prefix0 = '$JS' - Prefix1 = 'ACK' + Prefix0 = "$JS" + Prefix1 = "ACK" Domain = 2 AccHash = 3 Stream = 4 @@ -82,7 +83,7 @@ def sid(self) -> int: sid returns the subscription ID from a message. """ if self._sid is None: - raise Error('sid not set') + raise Error("sid not set") return self._sid async def respond(self, data: bytes) -> None: @@ -90,9 +91,9 @@ async def respond(self, data: bytes) -> None: respond replies to the inbox of the message if there is one. """ if not self.reply: - raise Error('no reply subject available') + raise Error("no reply subject available") if not self._client: - raise Error('client not set') + raise Error("client not set") await self._client.publish(self.reply, data, headers=self.headers) @@ -122,9 +123,9 @@ async def nak(self, delay: Union[int, float, None] = None) -> None: payload = Msg.Ack.Nak json_args = dict() if delay: - json_args['delay'] = int(delay * 10**9) # from seconds to ns + json_args["delay"] = int(delay * 10**9) # from seconds to ns if json_args: - payload += (b' ' + json.dumps(json_args).encode()) + payload += b" " + json.dumps(json_args).encode() await self._client.publish(self.reply, payload) self._ackd = True @@ -133,7 +134,7 @@ async def in_progress(self) -> None: in_progress acknowledges a message delivered by JetStream is still being worked on. Unlike other types of acks, an in-progress ack (+WPI) can be done multiple times. """ - if self.reply is None or self.reply == '': + if self.reply is None or self.reply == "": raise NotJSMessageError await self._client.publish(self.reply, Msg.Ack.Progress) @@ -165,7 +166,7 @@ def _get_metadata_fields(self, reply: Optional[str]) -> List[str]: return Msg.Metadata._get_metadata_fields(reply) def _check_reply(self) -> None: - if self.reply is None or self.reply == '': + if self.reply is None or self.reply == "": raise NotJSMessageError if self._ackd: raise MsgAlreadyAckdError(self) @@ -184,6 +185,7 @@ class Metadata: - consumer is the name of the consumer. """ + sequence: SequencePair num_pending: int num_delivered: int @@ -197,6 +199,7 @@ class SequencePair: """ SequencePair represents a pair of consumer and stream sequence. """ + consumer: int stream: int @@ -204,18 +207,17 @@ class SequencePair: def _get_metadata_fields(cls, reply: Optional[str]) -> List[str]: if not reply: raise NotJSMessageError - tokens = reply.split('.') - if (len(tokens) == _V1_TOKEN_COUNT or - len(tokens) >= _V2_TOKEN_COUNT-1) and \ - tokens[0] == Msg.Ack.Prefix0 and \ - tokens[1] == Msg.Ack.Prefix1: + tokens = reply.split(".") + if ((len(tokens) == _V1_TOKEN_COUNT + or len(tokens) >= _V2_TOKEN_COUNT - 1) + and tokens[0] == Msg.Ack.Prefix0 + and tokens[1] == Msg.Ack.Prefix1): return tokens raise NotJSMessageError @classmethod def _from_reply(cls, reply: str) -> Msg.Metadata: - """Construct the metadata from the reply string - """ + """Construct the metadata from the reply string""" tokens = cls._get_metadata_fields(reply) if len(tokens) == _V1_TOKEN_COUNT: t = datetime.datetime.fromtimestamp( diff --git a/nats/aio/subscription.py b/nats/aio/subscription.py index a322c68f..1238ca89 100644 --- a/nats/aio/subscription.py +++ b/nats/aio/subscription.py @@ -26,6 +26,7 @@ from uuid import uuid4 from nats import errors + # Default Pending Limits of Subscriptions from nats.aio.msg import Msg @@ -63,8 +64,8 @@ def __init__( self, conn, id: int = 0, - subject: str = '', - queue: str = '', + subject: str = "", + queue: str = "", cb: Optional[Callable[[Msg], Awaitable[None]]] = None, future: Optional[asyncio.Future] = None, max_msgs: int = 0, @@ -123,7 +124,7 @@ def messages(self) -> AsyncIterator[Msg]: subscription. :: - + nc = await nats.connect() sub = await nc.subscribe('foo') # Use `async for` which implicitly awaits messages @@ -177,7 +178,7 @@ async def next_msg(self, timeout: Optional[float] = 1.0) -> Msg: if self._cb: raise errors.Error( - 'nats: next_msg cannot be used in async subscriptions' + "nats: next_msg cannot be used in async subscriptions" ) task_name = str(uuid4()) @@ -210,8 +211,9 @@ def _start(self, error_cb): Creates the resources for the subscription to start processing messages. """ if self._cb: - if not asyncio.iscoroutinefunction(self._cb) and \ - not (hasattr(self._cb, "func") and asyncio.iscoroutinefunction(self._cb.func)): + if not asyncio.iscoroutinefunction(self._cb) and not ( + hasattr(self._cb, "func") + and asyncio.iscoroutinefunction(self._cb.func)): raise errors.Error( "nats: must use coroutine for subscriptions" ) @@ -328,7 +330,8 @@ async def _wait_for_msgs(self, error_cb) -> None: self._pending_queue.task_done() # Apply auto unsubscribe checks after having processed last msg. - if self._max_msgs > 0 and self._received >= self._max_msgs and self._pending_queue.empty: + if (self._max_msgs > 0 and self._received >= self._max_msgs + and self._pending_queue.empty): self._stop_processing() except asyncio.CancelledError: break diff --git a/nats/aio/transport.py b/nats/aio/transport.py index 9bde8475..74a597ad 100644 --- a/nats/aio/transport.py +++ b/nats/aio/transport.py @@ -123,7 +123,8 @@ async def connect( host=uri.hostname, port=uri.port, limit=buffer_size, - ), connect_timeout + ), + connect_timeout, ) # We keep a reference to the initial transport we used when # establishing the connection in case we later upgrade to TLS @@ -142,7 +143,7 @@ async def connect_tls( buffer_size: int, connect_timeout: int, ) -> None: - assert self._io_writer, f'{type(self).__name__}.connect must be called first' + assert self._io_writer, f"{type(self).__name__}.connect must be called first" # manually recreate the stream reader/writer with a tls upgraded transport reader = asyncio.StreamReader() @@ -152,7 +153,7 @@ async def connect_tls( protocol, ssl_context, # hostname here will be passed directly as string - server_hostname=uri if isinstance(uri, str) else uri.hostname + server_hostname=uri if isinstance(uri, str) else uri.hostname, ) transport = await asyncio.wait_for(transport_future, connect_timeout) writer = asyncio.StreamWriter( @@ -167,7 +168,7 @@ def writelines(self, payload): return self._io_writer.writelines(payload) async def read(self, buffer_size: int): - assert self._io_reader, f'{type(self).__name__}.connect must be called first' + assert self._io_reader, f"{type(self).__name__}.connect must be called first" return await self._io_reader.read(buffer_size) async def readline(self): @@ -226,7 +227,7 @@ async def connect_tls( self._ws = await self._client.ws_connect( uri if isinstance(uri, str) else uri.geturl(), ssl=ssl_context, - timeout=connect_timeout + timeout=connect_timeout, ) self._using_tls = True @@ -244,7 +245,7 @@ async def readline(self): data = await self._ws.receive() if data.type == aiohttp.WSMsgType.CLOSED: # if the connection terminated abruptly, return empty binary data to raise unexpected EOF - return b'' + return b"" return data.data async def drain(self): diff --git a/nats/errors.py b/nats/errors.py index 953a4ca2..a5d29287 100644 --- a/nats/errors.py +++ b/nats/errors.py @@ -107,8 +107,10 @@ def __init__( self.sub = sub def __str__(self) -> str: - return "nats: slow consumer, messages dropped subject: " \ - f"{self.subject}, sid: {self.sid}, sub: {self.sub}" + return ( + "nats: slow consumer, messages dropped subject: " + f"{self.subject}, sid: {self.sid}, sub: {self.sub}" + ) class BadTimeoutError(Error): diff --git a/nats/js/api.py b/nats/js/api.py index 3205b024..43b2fd5f 100644 --- a/nats/js/api.py +++ b/nats/js/api.py @@ -36,7 +36,7 @@ class Header(str, Enum): DEFAULT_PREFIX = "$JS.API" -INBOX_PREFIX = b'_INBOX.' +INBOX_PREFIX = b"_INBOX." class StatusCode(str, Enum): @@ -58,8 +58,7 @@ class Base: @staticmethod def _convert(resp: Dict[str, Any], field: str, type: type[Base]) -> None: - """Convert the field into the given type in place. - """ + """Convert the field into the given type in place.""" data = resp.get(field, None) if data is None: resp[field] = None @@ -70,8 +69,7 @@ def _convert(resp: Dict[str, Any], field: str, type: type[Base]) -> None: @staticmethod def _convert_nanoseconds(resp: Dict[str, Any], field: str) -> None: - """Convert the given field from nanoseconds to seconds in place. - """ + """Convert the given field from nanoseconds to seconds in place.""" val = resp.get(field, None) if val is not None: val = val / _NANOSECOND @@ -79,8 +77,7 @@ def _convert_nanoseconds(resp: Dict[str, Any], field: str) -> None: @staticmethod def _to_nanoseconds(val: Optional[float]) -> Optional[int]: - """Convert the value from seconds to nanoseconds. - """ + """Convert the value from seconds to nanoseconds.""" if val is None: # We use 0 to avoid sending null to Go servers. return 0 @@ -99,13 +96,11 @@ def from_response(cls: type[_B], resp: Dict[str, Any]) -> _B: return cls(**params) def evolve(self: _B, **params) -> _B: - """Return a copy of the instance with the passed values replaced. - """ + """Return a copy of the instance with the passed values replaced.""" return replace(self, **params) def as_dict(self) -> Dict[str, object]: - """Return the object converted into an API-friendly dict. - """ + """Return the object converted into an API-friendly dict.""" result = {} for field in fields(self): val = getattr(self, field.name) @@ -125,6 +120,7 @@ class PubAck(Base): """ PubAck is the response of publishing a message to JetStream. """ + stream: str seq: int domain: Optional[str] = None @@ -161,14 +157,14 @@ class StreamSource(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'external', ExternalStream) - cls._convert(resp, 'subject_transforms', SubjectTransform) + cls._convert(resp, "external", ExternalStream) + cls._convert(resp, "subject_transforms", SubjectTransform) return super().from_response(resp) def as_dict(self) -> Dict[str, object]: result = super().as_dict() if self.subject_transforms: - result['subject_transform'] = [ + result["subject_transform"] = [ tr.as_dict() for tr in self.subject_transforms ] return result @@ -202,7 +198,7 @@ class StreamState(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'lost', LostStreamData) + cls._convert(resp, "lost", LostStreamData) return super().from_response(resp) @@ -247,6 +243,7 @@ class RePublish(Base): RePublish is for republishing messages once committed to a stream. The original subject cis remapped from the subject pattern to the destination pattern. """ + src: Optional[str] = None dest: Optional[str] = None headers_only: Optional[bool] = None @@ -255,6 +252,7 @@ class RePublish(Base): @dataclass class SubjectTransform(Base): """Subject transform to apply to matching messages.""" + src: str dest: str @@ -268,6 +266,7 @@ class StreamConfig(Base): """ StreamConfig represents the configuration of a stream. """ + name: Optional[str] = None description: Optional[str] = None subjects: Optional[List[str]] = None @@ -311,24 +310,25 @@ class StreamConfig(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert_nanoseconds(resp, 'max_age') - cls._convert_nanoseconds(resp, 'duplicate_window') - cls._convert(resp, 'placement', Placement) - cls._convert(resp, 'mirror', StreamSource) - cls._convert(resp, 'sources', StreamSource) - cls._convert(resp, 'republish', RePublish) - cls._convert(resp, 'subject_transform', SubjectTransform) + cls._convert_nanoseconds(resp, "max_age") + cls._convert_nanoseconds(resp, "duplicate_window") + cls._convert(resp, "placement", Placement) + cls._convert(resp, "mirror", StreamSource) + cls._convert(resp, "sources", StreamSource) + cls._convert(resp, "republish", RePublish) + cls._convert(resp, "subject_transform", SubjectTransform) return super().from_response(resp) def as_dict(self) -> Dict[str, object]: result = super().as_dict() - result['duplicate_window'] = self._to_nanoseconds( + result["duplicate_window"] = self._to_nanoseconds( self.duplicate_window ) - result['max_age'] = self._to_nanoseconds(self.max_age) + result["max_age"] = self._to_nanoseconds(self.max_age) if self.sources: - result['sources'] = [src.as_dict() for src in self.sources] - if self.compression and (self.compression != StoreCompression.NONE and self.compression != StoreCompression.S2): + result["sources"] = [src.as_dict() for src in self.sources] + if self.compression and (self.compression != StoreCompression.NONE + and self.compression != StoreCompression.S2): raise ValueError( "nats: invalid store compression type: %s" % self.compression ) @@ -354,7 +354,7 @@ class ClusterInfo(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'replicas', PeerInfo) + cls._convert(resp, "replicas", PeerInfo) return super().from_response(resp) @@ -363,6 +363,7 @@ class StreamInfo(Base): """ StreamInfo is the latest information about a stream from JetStream. """ + config: StreamConfig state: StreamState mirror: Optional[StreamSourceInfo] = None @@ -372,11 +373,11 @@ class StreamInfo(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'config', StreamConfig) - cls._convert(resp, 'state', StreamState) - cls._convert(resp, 'mirror', StreamSourceInfo) - cls._convert(resp, 'sources', StreamSourceInfo) - cls._convert(resp, 'cluster', ClusterInfo) + cls._convert(resp, "config", StreamConfig) + cls._convert(resp, "state", StreamState) + cls._convert(resp, "mirror", StreamSourceInfo) + cls._convert(resp, "sources", StreamSourceInfo) + cls._convert(resp, "cluster", ClusterInfo) return super().from_response(resp) @@ -385,7 +386,10 @@ class StreamsListIterator(Iterable): """ StreamsListIterator is an iterator for streams list responses from JetStream. """ - def __init__(self, offset: int, total: int, streams: List[Dict[str, any]]) -> None: + + def __init__( + self, offset: int, total: int, streams: List[Dict[str, any]] + ) -> None: self.offset = offset self.total = total self.streams = streams @@ -456,6 +460,7 @@ class ConsumerConfig(Base): References: * `Consumers `_ """ + name: Optional[str] = None durable_name: Optional[str] = None description: Optional[str] = None @@ -497,22 +502,22 @@ class ConsumerConfig(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert_nanoseconds(resp, 'ack_wait') - cls._convert_nanoseconds(resp, 'idle_heartbeat') - cls._convert_nanoseconds(resp, 'inactive_threshold') - if 'backoff' in resp: - resp['backoff'] = [val / _NANOSECOND for val in resp['backoff']] + cls._convert_nanoseconds(resp, "ack_wait") + cls._convert_nanoseconds(resp, "idle_heartbeat") + cls._convert_nanoseconds(resp, "inactive_threshold") + if "backoff" in resp: + resp["backoff"] = [val / _NANOSECOND for val in resp["backoff"]] return super().from_response(resp) def as_dict(self) -> Dict[str, object]: result = super().as_dict() - result['ack_wait'] = self._to_nanoseconds(self.ack_wait) - result['idle_heartbeat'] = self._to_nanoseconds(self.idle_heartbeat) - result['inactive_threshold'] = self._to_nanoseconds( + result["ack_wait"] = self._to_nanoseconds(self.ack_wait) + result["idle_heartbeat"] = self._to_nanoseconds(self.idle_heartbeat) + result["inactive_threshold"] = self._to_nanoseconds( self.inactive_threshold ) if self.backoff: - result['backoff'] = [self._to_nanoseconds(i) for i in self.backoff] + result["backoff"] = [self._to_nanoseconds(i) for i in self.backoff] return result @@ -529,6 +534,7 @@ class ConsumerInfo(Base): """ ConsumerInfo represents the info about the consumer. """ + name: str stream_name: str config: ConsumerConfig @@ -545,10 +551,10 @@ class ConsumerInfo(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'delivered', SequenceInfo) - cls._convert(resp, 'ack_floor', SequenceInfo) - cls._convert(resp, 'config', ConsumerConfig) - cls._convert(resp, 'cluster', ClusterInfo) + cls._convert(resp, "delivered", SequenceInfo) + cls._convert(resp, "ack_floor", SequenceInfo) + cls._convert(resp, "config", ConsumerConfig) + cls._convert(resp, "cluster", ClusterInfo) return super().from_response(resp) @@ -580,7 +586,7 @@ class Tier(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'limits', AccountLimits) + cls._convert(resp, "limits", AccountLimits) return super().from_response(resp) @@ -599,6 +605,7 @@ class AccountInfo(Base): References: * `Account Information `_ """ + # NOTE: These fields are shared with Tier type as well. memory: int storage: int @@ -612,10 +619,10 @@ class AccountInfo(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'limits', AccountLimits) - cls._convert(resp, 'api', APIStats) + cls._convert(resp, "limits", AccountLimits) + cls._convert(resp, "api", APIStats) info = super().from_response(resp) - tiers = resp.get('tiers', None) + tiers = resp.get("tiers", None) if tiers: result = {} for k, v in tiers.items(): @@ -651,6 +658,7 @@ class KeyValueConfig(Base): """ KeyValueConfig is the configuration of a KeyValue store. """ + bucket: str description: Optional[str] = None max_value_size: Optional[int] = None @@ -665,7 +673,7 @@ class KeyValueConfig(Base): def as_dict(self) -> Dict[str, object]: result = super().as_dict() - result['ttl'] = self._to_nanoseconds(self.ttl) + result["ttl"] = self._to_nanoseconds(self.ttl) return result @@ -674,6 +682,7 @@ class StreamPurgeRequest(Base): """ StreamPurgeRequest is optional request information to the purge API. """ + # Purge up to but not including sequence. seq: Optional[int] = None # Subject to match against messages for the purge command. @@ -687,6 +696,7 @@ class ObjectStoreConfig(Base): """ ObjectStoreConfig is the configurigation of an ObjectStore. """ + bucket: Optional[str] = None description: Optional[str] = None ttl: Optional[float] = None @@ -697,7 +707,7 @@ class ObjectStoreConfig(Base): def as_dict(self) -> Dict[str, object]: result = super().as_dict() - result['ttl'] = self._to_nanoseconds(self.ttl) + result["ttl"] = self._to_nanoseconds(self.ttl) return result @@ -706,6 +716,7 @@ class ObjectLink(Base): """ ObjectLink is used to embed links to other buckets and objects. """ + # Bucket is the name of the other object store. bucket: str # Name can be used to link to a single object. @@ -724,7 +735,7 @@ class ObjectMetaOptions(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'link', ObjectLink) + cls._convert(resp, "link", ObjectLink) return super().from_response(resp) @@ -733,6 +744,7 @@ class ObjectMeta(Base): """ ObjectMeta is high level information about an object. """ + name: Optional[str] = None description: Optional[str] = None headers: Optional[dict] = None @@ -741,7 +753,7 @@ class ObjectMeta(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'options', ObjectMetaOptions) + cls._convert(resp, "options", ObjectMetaOptions) return super().from_response(resp) @@ -750,6 +762,7 @@ class ObjectInfo(Base): """ ObjectInfo is meta plus instance information. """ + name: str bucket: str nuid: str @@ -779,5 +792,5 @@ def is_link(self) -> bool: @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'options', ObjectMetaOptions) + cls._convert(resp, "options", ObjectMetaOptions) return super().from_response(resp) diff --git a/nats/js/client.py b/nats/js/client.py index 0ee13654..bc951be1 100644 --- a/nats/js/client.py +++ b/nats/js/client.py @@ -26,12 +26,21 @@ from nats.aio.msg import Msg from nats.aio.subscription import Subscription from nats.js import api -from nats.js.errors import BadBucketError, BucketNotFoundError, InvalidBucketNameError, NotFoundError, FetchTimeoutError +from nats.js.errors import ( + BadBucketError, + BucketNotFoundError, + InvalidBucketNameError, + NotFoundError, + FetchTimeoutError, +) from nats.js.kv import KeyValue from nats.js.manager import JetStreamManager from nats.js.object_store import ( - VALID_BUCKET_RE, OBJ_ALL_CHUNKS_PRE_TEMPLATE, OBJ_ALL_META_PRE_TEMPLATE, - OBJ_STREAM_TEMPLATE, ObjectStore + VALID_BUCKET_RE, + OBJ_ALL_CHUNKS_PRE_TEMPLATE, + OBJ_ALL_META_PRE_TEMPLATE, + OBJ_STREAM_TEMPLATE, + ObjectStore, ) if TYPE_CHECKING: @@ -39,13 +48,13 @@ NO_RESPONDERS_STATUS = "503" -NATS_HDR_LINE = bytearray(b'NATS/1.0') +NATS_HDR_LINE = bytearray(b"NATS/1.0") NATS_HDR_LINE_SIZE = len(NATS_HDR_LINE) -_CRLF_ = b'\r\n' +_CRLF_ = b"\r\n" _CRLF_LEN_ = len(_CRLF_) KV_STREAM_TEMPLATE = "KV_{bucket}" KV_PRE_TEMPLATE = "$KV.{bucket}." -Callback = Callable[['Msg'], Awaitable[None]] +Callback = Callable[["Msg"], Awaitable[None]] # For JetStream the default pending limits are larger. DEFAULT_JS_SUB_PENDING_MSGS_LIMIT = 512 * 1024 @@ -106,7 +115,9 @@ def __init__( self._publish_async_completed_event = asyncio.Event() self._publish_async_completed_event.set() - self._publish_async_pending_semaphore = asyncio.Semaphore(publish_async_max_pending) + self._publish_async_pending_semaphore = asyncio.Semaphore( + publish_async_max_pending + ) @property def _jsm(self) -> JetStreamManager: @@ -120,12 +131,12 @@ async def _init_async_reply(self) -> None: self._publish_async_futures = {} self._async_reply_prefix = self._nc._inbox_prefix[:] - self._async_reply_prefix.extend(b'.') + self._async_reply_prefix.extend(b".") self._async_reply_prefix.extend(self._nc._nuid.next()) - self._async_reply_prefix.extend(b'.') + self._async_reply_prefix.extend(b".") async_reply_subject = self._async_reply_prefix[:] - async_reply_subject.extend(b'*') + async_reply_subject.extend(b"*") await self._nc.subscribe( async_reply_subject.decode(), cb=self._handle_async_reply @@ -142,15 +153,16 @@ async def _handle_async_reply(self, msg: Msg) -> None: return # Handle no responders - if msg.headers and msg.headers.get(api.Header.STATUS) == NO_RESPONDERS_STATUS: + if msg.headers and msg.headers.get(api.Header.STATUS + ) == NO_RESPONDERS_STATUS: future.set_exception(nats.js.errors.NoStreamResponseError) return # Handle response errors try: resp = json.loads(msg.data) - if 'error' in resp: - err = nats.js.errors.APIError.from_error(resp['error']) + if "error" in resp: + err = nats.js.errors.APIError.from_error(resp["error"]) future.set_exception(err) return @@ -162,10 +174,10 @@ async def _handle_async_reply(self, msg: Msg) -> None: async def publish( self, subject: str, - payload: bytes = b'', + payload: bytes = b"", timeout: Optional[float] = None, stream: Optional[str] = None, - headers: Optional[Dict[str, Any]] = None + headers: Optional[Dict[str, Any]] = None, ) -> api.PubAck: """ publish emits a new message to JetStream and waits for acknowledgement. @@ -188,14 +200,14 @@ async def publish( raise nats.js.errors.NoStreamResponseError resp = json.loads(msg.data) - if 'error' in resp: - raise nats.js.errors.APIError.from_error(resp['error']) + if "error" in resp: + raise nats.js.errors.APIError.from_error(resp["error"]) return api.PubAck.from_response(resp) async def publish_async( self, subject: str, - payload: bytes = b'', + payload: bytes = b"", wait_stall: Optional[float] = None, stream: Optional[str] = None, headers: Optional[Dict] = None, @@ -214,7 +226,10 @@ async def publish_async( hdr[api.Header.EXPECTED_STREAM] = stream try: - await asyncio.wait_for(self._publish_async_pending_semaphore.acquire(), timeout=wait_stall) + await asyncio.wait_for( + self._publish_async_pending_semaphore.acquire(), + timeout=wait_stall + ) except (asyncio.TimeoutError, asyncio.CancelledError): raise nats.js.errors.TooManyStalledMsgsError @@ -453,8 +468,7 @@ async def subscribe_bind( pending_msgs_limit: int = DEFAULT_JS_SUB_PENDING_MSGS_LIMIT, pending_bytes_limit: int = DEFAULT_JS_SUB_PENDING_BYTES_LIMIT, ) -> PushSubscription: - """Push-subscribe to an existing consumer. - """ + """Push-subscribe to an existing consumer.""" # By default, async subscribers wrap the original callback and # auto ack the messages as they are delivered. # @@ -627,7 +641,7 @@ async def main(): sub = await self._nc.subscribe( deliver.decode(), pending_msgs_limit=pending_msgs_limit, - pending_bytes_limit=pending_bytes_limit + pending_bytes_limit=pending_bytes_limit, ) consumer_name = None # In nats.py v2.7.0 changing the first arg to be 'consumer' instead of 'durable', @@ -664,8 +678,9 @@ def _is_processable_msg(cls, status: Optional[str], msg: Msg) -> bool: @classmethod def _is_temporary_error(cls, status: Optional[str]) -> bool: - if status == api.StatusCode.NO_MESSAGES or status == api.StatusCode.CONFLICT \ - or status == api.StatusCode.REQUEST_TIMEOUT: + if (status == api.StatusCode.NO_MESSAGES + or status == api.StatusCode.CONFLICT + or status == api.StatusCode.REQUEST_TIMEOUT): return True else: return False @@ -684,7 +699,7 @@ def _time_until(cls, timeout: Optional[float], return None return timeout - (time.monotonic() - start_time) - class _JSI(): + class _JSI: def __init__( self, @@ -760,7 +775,8 @@ async def check_flow_control_response(self): if self._conn.is_closed: break - if (self._fciseq - self._psub._pending_queue.qsize()) >= self._fcd: + if (self._fciseq - + self._psub._pending_queue.qsize()) >= self._fcd: fc_reply = self._fcr try: if fc_reply: @@ -978,7 +994,7 @@ def __init__( self._stream = stream self._consumer = consumer prefix = self._js._prefix - self._nms = f'{prefix}.CONSUMER.MSG.NEXT.{stream}.{consumer}' + self._nms = f"{prefix}.CONSUMER.MSG.NEXT.{stream}.{consumer}" self._deliver = deliver.decode() @property @@ -1027,7 +1043,7 @@ async def fetch( self, batch: int = 1, timeout: Optional[float] = 5, - heartbeat: Optional[float] = None + heartbeat: Optional[float] = None, ) -> List[Msg]: """ fetch makes a request to JetStream to be delivered a set of messages. @@ -1079,7 +1095,7 @@ async def _fetch_one( self, expires: Optional[int], timeout: Optional[float], - heartbeat: Optional[float] = None + heartbeat: Optional[float] = None, ) -> Msg: queue = self._sub._pending_queue @@ -1100,11 +1116,11 @@ async def _fetch_one( # Make lingering request with expiration and wait for response. next_req = {} - next_req['batch'] = 1 + next_req["batch"] = 1 if expires: - next_req['expires'] = int(expires) + next_req["expires"] = int(expires) if heartbeat: - next_req['idle_heartbeat'] = int( + next_req["idle_heartbeat"] = int( heartbeat * 1_000_000_000 ) # to nanoseconds @@ -1157,7 +1173,7 @@ async def _fetch_n( batch: int, expires: Optional[int], timeout: Optional[float], - heartbeat: Optional[float] = None + heartbeat: Optional[float] = None, ) -> List[Msg]: msgs = [] queue = self._sub._pending_queue @@ -1185,14 +1201,14 @@ async def _fetch_n( # First request: Use no_wait to synchronously get as many available # based on the batch size until server sends 'No Messages' status msg. next_req = {} - next_req['batch'] = needed + next_req["batch"] = needed if expires: - next_req['expires'] = expires + next_req["expires"] = expires if heartbeat: - next_req['idle_heartbeat'] = int( + next_req["idle_heartbeat"] = int( heartbeat * 1_000_000_000 ) # to nanoseconds - next_req['no_wait'] = True + next_req["no_wait"] = True await self._nc.publish( self._nms, json.dumps(next_req).encode(), @@ -1251,11 +1267,11 @@ async def _fetch_n( # Second request: lingering request that will block until new messages # are made available and delivered to the client. next_req = {} - next_req['batch'] = needed + next_req["batch"] = needed if expires: - next_req['expires'] = expires + next_req["expires"] = expires if heartbeat: - next_req['idle_heartbeat'] = int( + next_req["idle_heartbeat"] = int( heartbeat * 1_000_000_000 ) # to nanoseconds @@ -1360,7 +1376,7 @@ async def key_value(self, bucket: str) -> KeyValue: stream=stream, pre=KV_PRE_TEMPLATE.format(bucket=bucket), js=self, - direct=bool(si.config.allow_direct) + direct=bool(si.config.allow_direct), ) async def create_key_value( diff --git a/nats/js/errors.py b/nats/js/errors.py index de22fc25..03951bfe 100644 --- a/nats/js/errors.py +++ b/nats/js/errors.py @@ -33,7 +33,7 @@ def __init__(self, description: Optional[str] = None) -> None: self.description = description def __str__(self) -> str: - desc = '' + desc = "" if self.description: desc = self.description return f"nats: JetStream.{self.__class__.__name__} {desc}" @@ -44,6 +44,7 @@ class APIError(Error): """ An Error that is the result of interacting with NATS JetStream. """ + code: Optional[int] err_code: Optional[int] description: Optional[str] @@ -77,7 +78,7 @@ def from_msg(cls, msg: Msg) -> NoReturn: @classmethod def from_error(cls, err: Dict[str, Any]): - code = err['code'] + code = err["code"] if code == 503: raise ServiceUnavailableError(**err) elif code == 500: @@ -100,6 +101,7 @@ class ServiceUnavailableError(APIError): """ A 503 error """ + pass @@ -107,6 +109,7 @@ class ServerError(APIError): """ A 500 error """ + pass @@ -114,6 +117,7 @@ class NotFoundError(APIError): """ A 404 error """ + pass @@ -121,6 +125,7 @@ class BadRequestError(APIError): """ A 400 error. """ + pass @@ -161,7 +166,7 @@ def __init__( self, stream_resume_sequence=None, consumer_sequence=None, - last_consumer_sequence=None + last_consumer_sequence=None, ) -> None: self.stream_resume_sequence = stream_resume_sequence self.consumer_sequence = consumer_sequence @@ -179,6 +184,7 @@ class BucketNotFoundError(NotFoundError): """ When attempted to bind to a JetStream KeyValue that does not exist. """ + pass @@ -190,6 +196,7 @@ class KeyValueError(APIError): """ Raised when there is an issue interacting with the KeyValue store. """ + pass @@ -251,6 +258,7 @@ class InvalidBucketNameError(Error): """ Raised when trying to create a KV or OBJ bucket with invalid name. """ + pass @@ -258,6 +266,7 @@ class InvalidObjectNameError(Error): """ Raised when trying to put an object in Object Store with invalid key. """ + pass @@ -265,6 +274,7 @@ class BadObjectMetaError(Error): """ Raised when trying to read corrupted metadata from Object Store. """ + pass @@ -272,6 +282,7 @@ class LinkIsABucketError(Error): """ Raised when trying to get object from Object Store that is a bucket. """ + pass @@ -279,6 +290,7 @@ class DigestMismatchError(Error): """ Raised when getting an object from Object Store that has a different digest than expected. """ + pass @@ -286,6 +298,7 @@ class ObjectNotFoundError(NotFoundError): """ When attempted to lookup an Object that does not exist. """ + pass @@ -293,6 +306,7 @@ class ObjectDeletedError(NotFoundError): """ When attempted to do an operation to an Object that does not exist. """ + pass @@ -300,4 +314,5 @@ class ObjectAlreadyExists(Error): """ When attempted to do an operation to an Object that already exist. """ + pass diff --git a/nats/js/kv.py b/nats/js/kv.py index 9ab077d1..071205aa 100644 --- a/nats/js/kv.py +++ b/nats/js/kv.py @@ -72,6 +72,7 @@ class Entry: """ An entry from a KeyValue store in JetStream. """ + bucket: str key: str value: Optional[bytes] @@ -85,6 +86,7 @@ class BucketStatus: """ BucketStatus is the status of a KeyValue bucket. """ + stream_info: api.StreamInfo bucket: str @@ -301,7 +303,8 @@ class KeyWatcher: def __init__(self, js): self._js = js - self._updates: asyncio.Queue[KeyValue.Entry | None] = asyncio.Queue(maxsize=256) + self._updates: asyncio.Queue[KeyValue.Entry + | None] = asyncio.Queue(maxsize=256) self._sub = None self._pending: Optional[int] = None @@ -357,7 +360,9 @@ async def keys(self, filters: List[str] = None, **kwargs) -> List[str]: if consumer_info and filters: # If NATS server < 2.10, filters might be ignored. if consumer_info.config.filter_subject != ">": - logger.warning("Server may ignore filters if version is < 2.10.") + logger.warning( + "Server may ignore filters if version is < 2.10." + ) except Exception as e: raise e @@ -432,7 +437,7 @@ async def watch_updates(msg): # keys() uses this if ignore_deletes: - if (op == KV_PURGE or op == KV_DEL): + if op == KV_PURGE or op == KV_DEL: if meta.num_pending == 0 and not watcher._init_done: await watcher._updates.put(None) watcher._init_done = True diff --git a/nats/js/manager.py b/nats/js/manager.py index 98648655..fa0946df 100644 --- a/nats/js/manager.py +++ b/nats/js/manager.py @@ -26,9 +26,9 @@ if TYPE_CHECKING: from nats import NATS -NATS_HDR_LINE = bytearray(b'NATS/1.0') +NATS_HDR_LINE = bytearray(b"NATS/1.0") NATS_HDR_LINE_SIZE = len(NATS_HDR_LINE) -_CRLF_ = b'\r\n' +_CRLF_ = b"\r\n" _CRLF_LEN_ = len(_CRLF_) @@ -50,7 +50,7 @@ def __init__( async def account_info(self) -> api.AccountInfo: resp = await self._api_request( - f"{self._prefix}.INFO", b'', timeout=self._timeout + f"{self._prefix}.INFO", b"", timeout=self._timeout ) return api.AccountInfo.from_response(resp) @@ -64,19 +64,25 @@ async def find_stream_name_by_subject(self, subject: str) -> str: info = await self._api_request( req_sub, req_data.encode(), timeout=self._timeout ) - if not info['streams']: + if not info["streams"]: raise NotFoundError - return info['streams'][0] + return info["streams"][0] - async def stream_info(self, name: str, subjects_filter: Optional[str] = None) -> api.StreamInfo: + async def stream_info( + self, + name: str, + subjects_filter: Optional[str] = None + ) -> api.StreamInfo: """ Get the latest StreamInfo by stream name. """ - req_data = '' + req_data = "" if subjects_filter: req_data = json.dumps({"subjects_filter": subjects_filter}) resp = await self._api_request( - f"{self._prefix}.STREAM.INFO.{name}", req_data.encode(), timeout=self._timeout + f"{self._prefix}.STREAM.INFO.{name}", + req_data.encode(), + timeout=self._timeout, ) return api.StreamInfo.from_response(resp) @@ -145,25 +151,25 @@ async def delete_stream(self, name: str) -> bool: resp = await self._api_request( f"{self._prefix}.STREAM.DELETE.{name}", timeout=self._timeout ) - return resp['success'] + return resp["success"] async def purge_stream( self, name: str, seq: Optional[int] = None, subject: Optional[str] = None, - keep: Optional[int] = None + keep: Optional[int] = None, ) -> bool: """ Purge a stream by name. """ stream_req: Dict[str, Any] = {} if seq: - stream_req['seq'] = seq + stream_req["seq"] = seq if subject: - stream_req['filter'] = subject + stream_req["filter"] = subject if keep: - stream_req['keep'] = keep + stream_req["keep"] = keep req = json.dumps(stream_req) resp = await self._api_request( @@ -171,7 +177,7 @@ async def purge_stream( req.encode(), timeout=self._timeout ) - return resp['success'] + return resp["success"] async def consumer_info( self, stream: str, consumer: str, timeout: Optional[float] = None @@ -181,7 +187,7 @@ async def consumer_info( timeout = self._timeout resp = await self._api_request( f"{self._prefix}.CONSUMER.INFO.{stream}.{consumer}", - b'', + b"", timeout=timeout ) return api.ConsumerInfo.from_response(resp) @@ -192,26 +198,33 @@ async def streams_info(self, offset=0) -> List[api.StreamInfo]: """ resp = await self._api_request( f"{self._prefix}.STREAM.LIST", - json.dumps({"offset": offset}).encode(), + json.dumps({ + "offset": offset + }).encode(), timeout=self._timeout, ) streams = [] - for stream in resp['streams']: + for stream in resp["streams"]: stream_info = api.StreamInfo.from_response(stream) streams.append(stream_info) return streams - async def streams_info_iterator(self, offset=0) -> Iterable[api.StreamInfo]: + async def streams_info_iterator(self, + offset=0) -> Iterable[api.StreamInfo]: """ streams_info retrieves a list of streams Iterator. """ resp = await self._api_request( f"{self._prefix}.STREAM.LIST", - json.dumps({"offset": offset}).encode(), + json.dumps({ + "offset": offset + }).encode(), timeout=self._timeout, ) - return api.StreamsListIterator(resp["offset"], resp["total"], resp["streams"]) + return api.StreamsListIterator( + resp["offset"], resp["total"], resp["streams"] + ) async def add_consumer( self, @@ -230,7 +243,7 @@ async def add_consumer( req_data = json.dumps(req).encode() resp = None - subject = '' + subject = "" version = self._nc.connected_server_version consumer_name_supported = version.major >= 2 and version.minor >= 9 if consumer_name_supported and config.name: @@ -252,10 +265,10 @@ async def add_consumer( async def delete_consumer(self, stream: str, consumer: str) -> bool: resp = await self._api_request( f"{self._prefix}.CONSUMER.DELETE.{stream}.{consumer}", - b'', - timeout=self._timeout + b"", + timeout=self._timeout, ) - return resp['success'] + return resp["success"] async def consumers_info( self, @@ -270,13 +283,13 @@ async def consumers_info( """ resp = await self._api_request( f"{self._prefix}.CONSUMER.LIST.{stream}", - b'' if offset is None else json.dumps({ + b"" if offset is None else json.dumps({ "offset": offset }).encode(), timeout=self._timeout, ) consumers = [] - for consumer in resp['consumers']: + for consumer in resp["consumers"]: consumer_info = api.ConsumerInfo.from_response(consumer) consumers.append(consumer_info) return consumers @@ -295,23 +308,23 @@ async def get_msg( req_subject = None req: Dict[str, Any] = {} if seq: - req['seq'] = seq + req["seq"] = seq if subject: - req['seq'] = None - req.pop('seq', None) - req['last_by_subj'] = subject + req["seq"] = None + req.pop("seq", None) + req["last_by_subj"] = subject if next: - req['seq'] = seq - req['last_by_subj'] = None - req.pop('last_by_subj', None) - req['next_by_subj'] = subject + req["seq"] = seq + req["last_by_subj"] = None + req.pop("last_by_subj", None) + req["next_by_subj"] = subject data = json.dumps(req) if direct: # $JS.API.DIRECT.GET.KV_{stream_name}.$KV.TEST.{key} if subject and (seq is None): # last_by_subject type request requires no payload. - data = '' + data = "" req_subject = f"{self._prefix}.DIRECT.GET.{stream_name}.{subject}" else: req_subject = f"{self._prefix}.DIRECT.GET.{stream_name}" @@ -328,7 +341,7 @@ async def get_msg( req_subject, data.encode(), timeout=self._timeout ) - raw_msg = api.RawStreamMsg.from_response(resp_data['message']) + raw_msg = api.RawStreamMsg.from_response(resp_data["message"]) if raw_msg.hdrs: hdrs = base64.b64decode(raw_msg.hdrs) raw_headers = hdrs[NATS_HDR_LINE_SIZE + _CRLF_LEN_:] @@ -351,18 +364,18 @@ async def get_msg( def _lift_msg_to_raw_msg(self, msg) -> api.RawStreamMsg: if not msg.data: msg.data = None - status = msg.headers.get('Status') + status = msg.headers.get("Status") if status: - if status == '404': + if status == "404": raise NotFoundError else: raise APIError.from_msg(msg) raw_msg = api.RawStreamMsg() - subject = msg.headers['Nats-Subject'] + subject = msg.headers["Nats-Subject"] raw_msg.subject = subject - seq = msg.headers.get('Nats-Sequence') + seq = msg.headers.get("Nats-Sequence") if seq: raw_msg.seq = int(seq) raw_msg.data = msg.data @@ -375,10 +388,10 @@ async def delete_msg(self, stream_name: str, seq: int) -> bool: delete_msg retrieves a message from a stream based on the sequence ID. """ req_subject = f"{self._prefix}.STREAM.MSG.DELETE.{stream_name}" - req = {'seq': seq} + req = {"seq": seq} data = json.dumps(req) resp = await self._api_request(req_subject, data.encode()) - return resp['success'] + return resp["success"] async def get_last_msg( self, @@ -394,7 +407,7 @@ async def get_last_msg( async def _api_request( self, req_subject: str, - req: bytes = b'', + req: bytes = b"", timeout: float = 5, ) -> Dict[str, Any]: try: @@ -404,7 +417,7 @@ async def _api_request( raise ServiceUnavailableError # Check for API errors. - if 'error' in resp: - raise APIError.from_error(resp['error']) + if "error" in resp: + raise APIError.from_error(resp["error"]) return resp diff --git a/nats/js/object_store.py b/nats/js/object_store.py index 5eb10612..2446cea4 100644 --- a/nats/js/object_store.py +++ b/nats/js/object_store.py @@ -25,8 +25,13 @@ import nats.errors from nats.js import api from nats.js.errors import ( - BadObjectMetaError, DigestMismatchError, ObjectAlreadyExists, - ObjectDeletedError, ObjectNotFoundError, NotFoundError, LinkIsABucketError + BadObjectMetaError, + DigestMismatchError, + ObjectAlreadyExists, + ObjectDeletedError, + ObjectNotFoundError, + NotFoundError, + LinkIsABucketError, ) from nats.js.kv import MSG_ROLLUP_SUBJECT @@ -71,6 +76,7 @@ class ObjectStoreStatus: """ ObjectStoreStatus is the status of a ObjectStore bucket. """ + stream_info: api.StreamInfo bucket: str @@ -140,7 +146,7 @@ async def get_info( meta = OBJ_META_PRE_TEMPLATE.format( bucket=self._name, - obj=base64.urlsafe_b64encode(bytes(obj, "utf-8")).decode() + obj=base64.urlsafe_b64encode(bytes(obj, "utf-8")).decode(), ) stream = OBJ_STREAM_TEMPLATE.format(bucket=self._name) @@ -209,7 +215,7 @@ async def get( executor_fn = None if writeinto: executor = asyncio.get_running_loop().run_in_executor - if hasattr(writeinto, 'buffer'): + if hasattr(writeinto, "buffer"): executor_fn = writeinto.buffer.write else: executor_fn = writeinto.write @@ -281,10 +287,10 @@ async def put( data = io.BytesIO(data.encode()) elif isinstance(data, bytes): data = io.BytesIO(data) - elif hasattr(data, 'readinto') or isinstance(data, io.BufferedIOBase): + elif hasattr(data, "readinto") or isinstance(data, io.BufferedIOBase): # Need to delegate to a threaded executor to avoid blocking. executor = asyncio.get_running_loop().run_in_executor - elif hasattr(data, 'buffer') or isinstance(data, io.TextIOWrapper): + elif hasattr(data, "buffer") or isinstance(data, io.TextIOWrapper): data = data.buffer else: raise TypeError("nats: invalid type for object store") @@ -298,7 +304,7 @@ async def put( nuid=newnuid.decode(), size=0, chunks=0, - mtime=datetime.now(timezone.utc).isoformat() + mtime=datetime.now(timezone.utc).isoformat(), ) h = sha256() chunk = bytearray(meta.options.max_chunk_size) @@ -334,14 +340,14 @@ async def put( # Prepare the meta message. meta_subj = OBJ_META_PRE_TEMPLATE.format( bucket=self._name, - obj=base64.urlsafe_b64encode(bytes(obj, "utf-8")).decode() + obj=base64.urlsafe_b64encode(bytes(obj, "utf-8")).decode(), ) # Publish the meta message. try: await self._js.publish( meta_subj, json.dumps(info.as_dict()).encode(), - headers={api.Header.ROLLUP: MSG_ROLLUP_SUBJECT} + headers={api.Header.ROLLUP: MSG_ROLLUP_SUBJECT}, ) except Exception as err: await self._js.purge_stream(self._stream, subject=chunk_subj) @@ -403,14 +409,14 @@ async def update_meta( # Prepare the meta message. meta_subj = OBJ_META_PRE_TEMPLATE.format( bucket=self._name, - obj=base64.urlsafe_b64encode(bytes(name, "utf-8")).decode() + obj=base64.urlsafe_b64encode(bytes(name, "utf-8")).decode(), ) # Publish the meta message. try: await self._js.publish( meta_subj, json.dumps(info.as_dict()).encode(), - headers={api.Header.ROLLUP: MSG_ROLLUP_SUBJECT} + headers={api.Header.ROLLUP: MSG_ROLLUP_SUBJECT}, ) except Exception as err: raise err @@ -523,20 +529,20 @@ async def delete(self, name: str) -> ObjectResult: info.deleted = True info.size = 0 info.chunks = 0 - info.digest = '' - info.mtime = '' + info.digest = "" + info.mtime = "" # Prepare the meta message. meta_subj = OBJ_META_PRE_TEMPLATE.format( bucket=self._name, - obj=base64.urlsafe_b64encode(bytes(obj, "utf-8")).decode() + obj=base64.urlsafe_b64encode(bytes(obj, "utf-8")).decode(), ) # Publish the meta message. try: await self._js.publish( meta_subj, json.dumps(info.as_dict()).encode(), - headers={api.Header.ROLLUP: MSG_ROLLUP_SUBJECT} + headers={api.Header.ROLLUP: MSG_ROLLUP_SUBJECT}, ) finally: await self._js.purge_stream(self._stream, subject=chunk_subj) diff --git a/nats/micro/request.py b/nats/micro/request.py index e22cfa5c..1f5a05dc 100644 --- a/nats/micro/request.py +++ b/nats/micro/request.py @@ -47,7 +47,9 @@ def data(self) -> bytes: return self._msg.data async def respond( - self, data: bytes = b"", headers: Optional[Dict[str, str]] = None + self, + data: bytes = b"", + headers: Optional[Dict[str, str]] = None ) -> None: """Send a response to the request. @@ -82,12 +84,10 @@ async def respond_error( else: headers = {} - headers.update( - { - ERROR_HEADER: description, - ERROR_CODE_HEADER: code, - } - ) + headers.update({ + ERROR_HEADER: description, + ERROR_CODE_HEADER: code, + }) await self.respond(data, headers=headers) @@ -109,4 +109,6 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> ServiceError: - return cls(code=data.get("code", ""), description=data.get("description", "")) + return cls( + code=data.get("code", ""), description=data.get("description", "") + ) diff --git a/nats/micro/service.py b/nats/micro/service.py index aa258a45..c6acfaa5 100644 --- a/nats/micro/service.py +++ b/nats/micro/service.py @@ -25,7 +25,6 @@ from .request import Request, Handler, ServiceError - DEFAULT_QUEUE_GROUP = "q" """Queue Group name used across all services.""" @@ -274,7 +273,9 @@ async def _handle_request(self, msg: Msg) -> None: elapsed_time = current_time - start_time self._processing_time += elapsed_time - self._average_processing_time = int(self._processing_time / self._num_requests) + self._average_processing_time = int( + self._processing_time / self._num_requests + ) @dataclass @@ -294,7 +295,8 @@ class EndpointManager(Protocol): """ @overload - async def add_endpoint(self, config: EndpointConfig) -> None: ... + async def add_endpoint(self, config: EndpointConfig) -> None: + ... @overload async def add_endpoint( @@ -305,11 +307,13 @@ async def add_endpoint( queue_group: Optional[str] = None, subject: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, - ) -> None: ... + ) -> None: + ... async def add_endpoint( self, config: Optional[EndpointConfig] = None, **kwargs - ) -> None: ... + ) -> None: + ... class GroupManager(Protocol): @@ -318,22 +322,31 @@ class GroupManager(Protocol): """ @overload - def add_group(self, *, name: str, queue_group: Optional[str] = None) -> Group: ... + def add_group( + self, *, name: str, queue_group: Optional[str] = None + ) -> Group: + ... @overload - def add_group(self, config: GroupConfig) -> Group: ... + def add_group(self, config: GroupConfig) -> Group: + ... - def add_group(self, config: Optional[GroupConfig] = None, **kwargs) -> Group: ... + def add_group( + self, config: Optional[GroupConfig] = None, **kwargs + ) -> Group: + ... class Group(GroupManager, EndpointManager): + def __init__(self, service: "Service", config: GroupConfig) -> None: self._service = service self._prefix = config.name self._queue_group = config.queue_group @overload - async def add_endpoint(self, config: EndpointConfig) -> None: ... + async def add_endpoint(self, config: EndpointConfig) -> None: + ... @overload async def add_endpoint( @@ -344,7 +357,8 @@ async def add_endpoint( queue_group: Optional[str] = None, subject: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, - ) -> None: ... + ) -> None: + ... async def add_endpoint( self, config: Optional[EndpointConfig] = None, **kwargs @@ -356,21 +370,26 @@ async def add_endpoint( config = replace( config, - subject=f"{self._prefix.strip('.')}.{config.subject or config.name}".strip( - "." - ), + subject=f"{self._prefix.strip('.')}.{config.subject or config.name}" + .strip("."), queue_group=config.queue_group or self._queue_group, ) await self._service.add_endpoint(config) @overload - def add_group(self, *, name: str, queue_group: Optional[str] = None) -> Group: ... + def add_group( + self, *, name: str, queue_group: Optional[str] = None + ) -> Group: + ... @overload - def add_group(self, config: GroupConfig) -> Group: ... + def add_group(self, config: GroupConfig) -> Group: + ... - def add_group(self, config: Optional[GroupConfig] = None, **kwargs) -> Group: + def add_group( + self, config: Optional[GroupConfig] = None, **kwargs + ) -> Group: if config: config = replace(config, **kwargs) else: @@ -524,9 +543,12 @@ def from_dict(cls, data: Dict[str, Any]) -> ServiceStats: id=data["id"], name=data["name"], version=data["version"], - started=datetime.strptime(data["started"], "%Y-%m-%dT%H:%M:%S.%fZ"), + started=datetime.strptime( + data["started"], "%Y-%m-%dT%H:%M:%S.%fZ" + ), endpoints=[ - EndpointStats.from_dict(endpoint) for endpoint in data["endpoints"] + EndpointStats.from_dict(endpoint) + for endpoint in data["endpoints"] ], metadata=data["metadata"], ) @@ -539,7 +561,8 @@ def to_dict(self) -> Dict[str, Any]: "name": self.name, "id": self.id, "version": self.version, - "started": self.started.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z", + "started": self.started.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + + "Z", "endpoints": [endpoint.to_dict() for endpoint in self.endpoints], "metadata": self.metadata, } @@ -599,7 +622,8 @@ def from_dict(cls, data: Dict[str, Any]) -> ServiceInfo: version=data["version"], description=data.get("description"), endpoints=[ - EndpointInfo.from_dict(endpoint) for endpoint in data["endpoints"] + EndpointInfo.from_dict(endpoint) + for endpoint in data["endpoints"] ], metadata=data["metadata"], type=data.get("type", "io.nats.micro.v1.info_response"), @@ -622,6 +646,7 @@ def to_dict(self) -> Dict[str, Any]: class Service(AsyncContextManager): + def __init__(self, client: Client, config: ServiceConfig) -> None: self._id = client._nuid.next().decode() self._name = config.name @@ -659,7 +684,9 @@ async def start(self) -> None: verb_subjects = [ ( f"{verb}-all", - control_subject(verb, name=None, id=None, prefix=self._prefix), + control_subject( + verb, name=None, id=None, prefix=self._prefix + ), ), ( f"{verb}-kind", @@ -670,7 +697,10 @@ async def start(self) -> None: ( verb, control_subject( - verb, name=self._name, id=self._id, prefix=self._prefix + verb, + name=self._name, + id=self._id, + prefix=self._prefix ), ), ] @@ -684,7 +714,8 @@ async def start(self) -> None: await self._client.flush() @overload - async def add_endpoint(self, config: EndpointConfig) -> None: ... + async def add_endpoint(self, config: EndpointConfig) -> None: + ... @overload async def add_endpoint( @@ -695,7 +726,8 @@ async def add_endpoint( queue_group: Optional[str] = None, subject: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, - ) -> None: ... + ) -> None: + ... async def add_endpoint( self, config: Optional[EndpointConfig] = None, **kwargs @@ -705,25 +737,35 @@ async def add_endpoint( else: config = replace(config, **kwargs) - config = replace(config, queue_group=config.queue_group or self._queue_group) + config = replace( + config, queue_group=config.queue_group or self._queue_group + ) endpoint = Endpoint(self, config) await endpoint._start() self._endpoints.append(endpoint) @overload - def add_group(self, *, name: str, queue_group: Optional[str] = None) -> Group: ... + def add_group( + self, *, name: str, queue_group: Optional[str] = None + ) -> Group: + ... @overload - def add_group(self, config: GroupConfig) -> Group: ... + def add_group(self, config: GroupConfig) -> Group: + ... - def add_group(self, config: Optional[GroupConfig] = None, **kwargs) -> Group: + def add_group( + self, config: Optional[GroupConfig] = None, **kwargs + ) -> Group: if config: config = replace(config, **kwargs) else: config = GroupConfig(**kwargs) - config = replace(config, queue_group=config.queue_group or self._queue_group) + config = replace( + config, queue_group=config.queue_group or self._queue_group + ) return Group(self, config) @@ -746,8 +788,7 @@ def stats(self) -> ServiceStats: last_error=endpoint._last_error, processing_time=endpoint._processing_time, average_processing_time=endpoint._average_processing_time, - ) - for endpoint in (self._endpoints or []) + ) for endpoint in (self._endpoints or []) ], started=self._started, ) @@ -771,8 +812,7 @@ def info(self) -> ServiceInfo: subject=endpoint._subject, queue_group=endpoint._queue_group, metadata=endpoint._metadata, - ) - for endpoint in self._endpoints + ) for endpoint in self._endpoints ], ) diff --git a/nats/nuid.py b/nats/nuid.py index e8f4470f..7896d642 100644 --- a/nats/nuid.py +++ b/nats/nuid.py @@ -18,7 +18,7 @@ from secrets import randbelow, token_bytes from sys import maxsize as MaxInt -DIGITS = b'0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' +DIGITS = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" BASE = 62 PREFIX_LENGTH = 12 SEQ_LENGTH = 10 diff --git a/nats/protocol/command.py b/nats/protocol/command.py index 80573685..7218f48d 100644 --- a/nats/protocol/command.py +++ b/nats/protocol/command.py @@ -2,31 +2,35 @@ from typing import Callable -PUB_OP = 'PUB' -HPUB_OP = 'HPUB' -SUB_OP = 'SUB' -UNSUB_OP = 'UNSUB' -_CRLF_ = '\r\n' +PUB_OP = "PUB" +HPUB_OP = "HPUB" +SUB_OP = "SUB" +UNSUB_OP = "UNSUB" +_CRLF_ = "\r\n" Command = Callable[..., bytes] def pub_cmd(subject, reply, payload) -> bytes: - return f'{PUB_OP} {subject} {reply} {len(payload)}{_CRLF_}'.encode( - ) + payload + _CRLF_.encode() + return ( + f"{PUB_OP} {subject} {reply} {len(payload)}{_CRLF_}".encode() + + payload + _CRLF_.encode() + ) def hpub_cmd(subject, reply, hdr, payload) -> bytes: hdr_len = len(hdr) total_size = len(payload) + hdr_len - return f'{HPUB_OP} {subject} {reply} {hdr_len} {total_size}{_CRLF_}'.encode( - ) + hdr + payload + _CRLF_.encode() + return ( + f"{HPUB_OP} {subject} {reply} {hdr_len} {total_size}{_CRLF_}".encode() + + hdr + payload + _CRLF_.encode() + ) def sub_cmd(subject, queue, sid) -> bytes: - return f'{SUB_OP} {subject} {queue} {sid}{_CRLF_}'.encode() + return f"{SUB_OP} {subject} {queue} {sid}{_CRLF_}".encode() def unsub_cmd(sid, limit) -> bytes: - limit_s = '' if limit == 0 else f'{limit}' - return f'{UNSUB_OP} {sid} {limit_s}{_CRLF_}'.encode() + limit_s = "" if limit == 0 else f"{limit}" + return f"{UNSUB_OP} {sid} {limit_s}{_CRLF_}".encode() diff --git a/nats/protocol/parser.py b/nats/protocol/parser.py index 20bf3948..5335a455 100644 --- a/nats/protocol/parser.py +++ b/nats/protocol/parser.py @@ -24,31 +24,31 @@ from nats.errors import ProtocolError MSG_RE = re.compile( - b'\\AMSG\\s+([^\\s]+)\\s+([^\\s]+)\\s+(([^\\s]+)[^\\S\r\n]+)?(\\d+)\r\n' + b"\\AMSG\\s+([^\\s]+)\\s+([^\\s]+)\\s+(([^\\s]+)[^\\S\r\n]+)?(\\d+)\r\n" ) HMSG_RE = re.compile( - b'\\AHMSG\\s+([^\\s]+)\\s+([^\\s]+)\\s+(([^\\s]+)[^\\S\r\n]+)?([\\d]+)\\s+(\\d+)\r\n' + b"\\AHMSG\\s+([^\\s]+)\\s+([^\\s]+)\\s+(([^\\s]+)[^\\S\r\n]+)?([\\d]+)\\s+(\\d+)\r\n" ) -OK_RE = re.compile(b'\\A\\+OK\\s*\r\n') -ERR_RE = re.compile(b'\\A-ERR\\s+(\'.+\')?\r\n') -PING_RE = re.compile(b'\\APING\\s*\r\n') -PONG_RE = re.compile(b'\\APONG\\s*\r\n') -INFO_RE = re.compile(b'\\AINFO\\s+([^\r\n]+)\r\n') - -INFO_OP = b'INFO' -CONNECT_OP = b'CONNECT' -PUB_OP = b'PUB' -MSG_OP = b'MSG' -HMSG_OP = b'HMSG' -SUB_OP = b'SUB' -UNSUB_OP = b'UNSUB' -PING_OP = b'PING' -PONG_OP = b'PONG' -OK_OP = b'+OK' -ERR_OP = b'-ERR' -MSG_END = b'\n' -_CRLF_ = b'\r\n' -_SPC_ = b' ' +OK_RE = re.compile(b"\\A\\+OK\\s*\r\n") +ERR_RE = re.compile(b"\\A-ERR\\s+('.+')?\r\n") +PING_RE = re.compile(b"\\APING\\s*\r\n") +PONG_RE = re.compile(b"\\APONG\\s*\r\n") +INFO_RE = re.compile(b"\\AINFO\\s+([^\r\n]+)\r\n") + +INFO_OP = b"INFO" +CONNECT_OP = b"CONNECT" +PUB_OP = b"PUB" +MSG_OP = b"MSG" +HMSG_OP = b"HMSG" +SUB_OP = b"SUB" +UNSUB_OP = b"UNSUB" +PING_OP = b"PING" +PONG_OP = b"PONG" +OK_OP = b"+OK" +ERR_OP = b"-ERR" +MSG_END = b"\n" +_CRLF_ = b"\r\n" +_SPC_ = b" " OK = OK_OP + _CRLF_ PING = PING_OP + _CRLF_ @@ -87,7 +87,7 @@ def reset(self) -> None: self.header_needed = 0 self.msg_arg: Dict[str, Any] = {} - async def parse(self, data: bytes = b''): + async def parse(self, data: bytes = b""): """ Parses the wire protocol from NATS for the client and dispatches the subscription callbacks. @@ -104,7 +104,7 @@ async def parse(self, data: bytes = b''): if reply: self.msg_arg["reply"] = reply else: - self.msg_arg["reply"] = b'' + self.msg_arg["reply"] = b"" self.needed = int(needed_bytes) del self.buf[:msg.end()] self.state = AWAITING_MSG_PAYLOAD @@ -122,7 +122,7 @@ async def parse(self, data: bytes = b''): if reply: self.msg_arg["reply"] = reply else: - self.msg_arg["reply"] = b'' + self.msg_arg["reply"] = b"" self.needed = int(needed_bytes) self.header_needed = int(header_size) del self.buf[:msg.end()] @@ -165,7 +165,8 @@ async def parse(self, data: bytes = b''): del self.buf[:info.end()] continue - if len(self.buf) < MAX_CONTROL_LINE_SIZE and _CRLF_ in self.buf: + if len(self.buf + ) < MAX_CONTROL_LINE_SIZE and _CRLF_ in self.buf: # FIXME: By default server uses a max protocol # line of 4096 bytes but it can be tuned in latest # releases, in that case we won't reach here but diff --git a/setup.cfg b/setup.cfg index 6deafc26..ceead6c3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,3 @@ [flake8] +ignore = W503, W504 max-line-length = 120 diff --git a/setup.py b/setup.py index 44a204a5..7e8cc758 100644 --- a/setup.py +++ b/setup.py @@ -4,14 +4,14 @@ # These are here for GitHub's dependency graph and help with setuptools support in some environments. setup( name="nats-py", - version='2.9.0', - license='Apache 2 License', + version="2.9.0", + license="Apache 2 License", extras_require={ - 'nkeys': ['nkeys'], - 'aiohttp': ['aiohttp'], - 'fast_parse': ['fast-mail-parser'] + "nkeys": ["nkeys"], + "aiohttp": ["aiohttp"], + "fast_parse": ["fast-mail-parser"], }, - packages=['nats', 'nats.aio', 'nats.micro', 'nats.protocol', 'nats.js'], + packages=["nats", "nats.aio", "nats.micro", "nats.protocol", "nats.js"], package_data={"nats": ["py.typed"]}, - zip_safe=True + zip_safe=True, ) diff --git a/tests/test_client.py b/tests/test_client.py index 5548aa07..dcd96849 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -58,10 +58,10 @@ class ClientTest(SingleServerTestCase): @async_test async def test_default_connect(self): nc = await nats.connect() - self.assertIn('server_id', nc._server_info) - self.assertIn('client_id', nc._server_info) - self.assertIn('max_payload', nc._server_info) - self.assertEqual(nc._server_info['max_payload'], nc.max_payload) + self.assertIn("server_id", nc._server_info) + self.assertIn("client_id", nc._server_info) + self.assertIn("max_payload", nc._server_info) + self.assertEqual(nc._server_info["max_payload"], nc.max_payload) self.assertTrue(nc.max_payload > 0) self.assertTrue(nc.is_connected) self.assertTrue(nc.client_id > 0) @@ -214,34 +214,34 @@ async def test_publish(self): nc = NATS() await nc.connect() for i in range(0, 100): - await nc.publish(f"hello.{i}", b'A') + await nc.publish(f"hello.{i}", b"A") with self.assertRaises(nats.errors.BadSubjectError): - await nc.publish("", b'') + await nc.publish("", b"") await nc.flush() await nc.close() await asyncio.sleep(1) - self.assertEqual(100, nc.stats['out_msgs']) - self.assertEqual(100, nc.stats['out_bytes']) + self.assertEqual(100, nc.stats["out_msgs"]) + self.assertEqual(100, nc.stats["out_bytes"]) - endpoint = f'127.0.0.1:{self.server_pool[0].http_port}' + endpoint = f"127.0.0.1:{self.server_pool[0].http_port}" httpclient = http.client.HTTPConnection(endpoint, timeout=5) - httpclient.request('GET', '/varz') + httpclient.request("GET", "/varz") response = httpclient.getresponse() varz = json.loads((response.read()).decode()) - self.assertEqual(100, varz['in_msgs']) - self.assertEqual(100, varz['in_bytes']) + self.assertEqual(100, varz["in_msgs"]) + self.assertEqual(100, varz["in_bytes"]) @async_test async def test_flush(self): nc = NATS() await nc.connect() for i in range(0, 10): - await nc.publish(f"flush.{i}", b'AA') + await nc.publish(f"flush.{i}", b"AA") await nc.flush() - self.assertEqual(10, nc.stats['out_msgs']) - self.assertEqual(20, nc.stats['out_bytes']) + self.assertEqual(10, nc.stats["out_msgs"]) + self.assertEqual(20, nc.stats["out_bytes"]) await nc.close() @async_test @@ -252,14 +252,14 @@ async def test_subscribe(self): async def subscription_handler(msg): msgs.append(msg) - payload = b'hello world' + payload = b"hello world" await nc.connect() sub = await nc.subscribe("foo", cb=subscription_handler) await nc.publish("foo", payload) await nc.publish("bar", payload) with self.assertRaises(nats.errors.BadSubjectError): - await nc.publish("", b'') + await nc.publish("", b"") # Validate some of the subjects with self.assertRaises(nats.errors.BadSubjectError): @@ -279,8 +279,8 @@ async def subscription_handler(msg): self.assertEqual(1, len(msgs)) msg = msgs[0] - self.assertEqual('foo', msg.subject) - self.assertEqual('', msg.reply) + self.assertEqual("foo", msg.subject) + self.assertEqual("", msg.reply) self.assertEqual(payload, msg.data) self.assertEqual(sub._id, msg.sid) self.assertEqual(1, sub._received) @@ -291,17 +291,17 @@ async def subscription_handler(msg): with self.assertRaises(KeyError): nc._subs[sub._id] - self.assertEqual(1, nc.stats['in_msgs']) - self.assertEqual(11, nc.stats['in_bytes']) - self.assertEqual(2, nc.stats['out_msgs']) - self.assertEqual(22, nc.stats['out_bytes']) + self.assertEqual(1, nc.stats["in_msgs"]) + self.assertEqual(11, nc.stats["in_bytes"]) + self.assertEqual(2, nc.stats["out_msgs"]) + self.assertEqual(22, nc.stats["out_bytes"]) - endpoint = f'127.0.0.1:{self.server_pool[0].http_port}' + endpoint = f"127.0.0.1:{self.server_pool[0].http_port}" httpclient = http.client.HTTPConnection(endpoint, timeout=5) - httpclient.request('GET', '/connz') + httpclient.request("GET", "/connz") response = httpclient.getresponse() connz = json.loads((response.read()).decode()) - self.assertEqual(0, len(connz['connections'])) + self.assertEqual(0, len(connz["connections"])) @async_test async def test_subscribe_functools_partial(self): @@ -320,7 +320,7 @@ async def subscription_handler(arg1, msg): subscription_handler, "example" ) - payload = b'hello world' + payload = b"hello world" await nc.connect() sid = await nc.subscribe("foo", cb=partial_sub_handler) @@ -331,8 +331,8 @@ async def subscription_handler(arg1, msg): self.assertEqual(1, len(msgs)) self.assertEqual(partial_arg, "example") msg = msgs[0] - self.assertEqual('foo', msg.subject) - self.assertEqual('', msg.reply) + self.assertEqual("foo", msg.subject) + self.assertEqual("", msg.reply) self.assertEqual(payload, msg.data) self.assertEqual(sid._id, msg.sid) @@ -361,7 +361,7 @@ async def subscription_handler2(msg): await nc.flush() await nc2.flush() - payload = b'hello world' + payload = b"hello world" for i in range(0, 10): await nc.publish("foo", payload) await asyncio.sleep(0.1) @@ -377,15 +377,15 @@ async def subscription_handler2(msg): await nc.close() await nc2.close() - self.assertEqual(0, nc.stats['in_msgs']) - self.assertEqual(0, nc.stats['in_bytes']) - self.assertEqual(10, nc.stats['out_msgs']) - self.assertEqual(110, nc.stats['out_bytes']) + self.assertEqual(0, nc.stats["in_msgs"]) + self.assertEqual(0, nc.stats["in_bytes"]) + self.assertEqual(10, nc.stats["out_msgs"]) + self.assertEqual(110, nc.stats["out_bytes"]) - self.assertEqual(10, nc2.stats['in_msgs']) - self.assertEqual(110, nc2.stats['in_bytes']) - self.assertEqual(0, nc2.stats['out_msgs']) - self.assertEqual(0, nc2.stats['out_bytes']) + self.assertEqual(10, nc2.stats["in_msgs"]) + self.assertEqual(110, nc2.stats["in_bytes"]) + self.assertEqual(0, nc2.stats["out_msgs"]) + self.assertEqual(0, nc2.stats["out_bytes"]) @async_test async def test_invalid_subscribe_error(self): @@ -426,7 +426,7 @@ async def subscription_handler(msg): await nc.subscribe("tests.>", cb=subscription_handler) for i in range(0, 5): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") # Wait a bit for messages to be received. await asyncio.sleep(1) @@ -447,14 +447,14 @@ async def iterator_func(sub): fut.set_result(None) await nc.connect() - sub = await nc.subscribe('tests.>') + sub = await nc.subscribe("tests.>") self.assertFalse(sub._message_iterator._unsubscribed_future.done()) asyncio.ensure_future(iterator_func(sub)) self.assertFalse(sub._message_iterator._unsubscribed_future.done()) for i in range(0, 5): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") await asyncio.sleep(0) await asyncio.wait_for(sub.drain(), 1) @@ -477,12 +477,12 @@ async def test_subscribe_iterate_unsub_comprehension(self): await nc.connect() # Make subscription that only expects a couple of messages. - sub = await nc.subscribe('tests.>') + sub = await nc.subscribe("tests.>") await sub.unsubscribe(limit=2) await nc.flush() for i in range(0, 5): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") # It should unblock once the auto unsubscribe limit is hit. msgs = [msg async for msg in sub.messages] @@ -500,12 +500,12 @@ async def test_subscribe_iterate_unsub(self): await nc.connect() # Make subscription that only expects a couple of messages. - sub = await nc.subscribe('tests.>') + sub = await nc.subscribe("tests.>") await sub.unsubscribe(limit=2) await nc.flush() for i in range(0, 5): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") # A couple of messages would be received then this will unblock. msgs = [] @@ -528,12 +528,12 @@ async def test_subscribe_auto_unsub(self): async def handler(msg): msgs.append(msg) - sub = await nc.subscribe('tests.>', cb=handler) + sub = await nc.subscribe("tests.>", cb=handler) await sub.unsubscribe(limit=2) await nc.flush() for i in range(0, 5): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") await asyncio.sleep(1) self.assertEqual(2, len(msgs)) @@ -547,7 +547,7 @@ async def test_subscribe_iterate_next_msg(self): await nc.connect() # Make subscription that only expects a couple of messages. - sub = await nc.subscribe('tests.>') + sub = await nc.subscribe("tests.>") await nc.flush() # Async generator to consume messages. @@ -561,7 +561,7 @@ async def next_msg(): return msg for i in range(0, 2): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") # A couple of messages would be received then this will unblock. msg = await next_msg() @@ -576,9 +576,9 @@ async def next_msg(): # FIXME: This message would be lost because cannot # reuse the future from the iterator that timed out. - await nc.publish(f"tests.2", b'bar') + await nc.publish(f"tests.2", b"bar") - await nc.publish(f"tests.3", b'bar') + await nc.publish(f"tests.3", b"bar") await nc.flush() # FIXME: this test is flaky @@ -596,11 +596,11 @@ async def test_subscribe_next_msg(self): nc = await nats.connect() # Make subscription that only expects a couple of messages. - sub = await nc.subscribe('tests.>') + sub = await nc.subscribe("tests.>") await nc.flush() for i in range(0, 2): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") await nc.flush() # A couple of messages would be received then this will unblock. @@ -624,8 +624,8 @@ async def test_subscribe_next_msg(self): await sub.next_msg(timeout=0.5) # Send again a couple of messages. - await nc.publish(f"tests.2", b'bar') - await nc.publish(f"tests.3", b'bar') + await nc.publish(f"tests.2", b"bar") + await nc.publish(f"tests.3", b"bar") await nc.flush() msg = await sub.next_msg() self.assertEqual("tests.2", msg.subject) @@ -655,14 +655,14 @@ async def error_cb(err): nc = await nats.connect(error_cb=error_cb) sub = await nc.subscribe( - 'tests.>', + "tests.>", pending_msgs_limit=5, pending_bytes_limit=-1, ) await nc.flush() for i in range(0, 6): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") await nc.flush() # A couple of messages would be received then this will unblock. @@ -690,13 +690,13 @@ async def test_subscribe_next_msg_with_cb_not_supported(self): nc = await nats.connect() async def handler(msg): - await msg.respond(b'OK') + await msg.respond(b"OK") - sub = await nc.subscribe('foo', cb=handler) + sub = await nc.subscribe("foo", cb=handler) await nc.flush() for i in range(0, 2): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") await nc.flush() with self.assertRaises(nats.errors.Error): @@ -732,42 +732,42 @@ async def subscription_handler(msg): await nc.connect() sub = await nc.subscribe("foo", cb=subscription_handler) - await nc.publish("foo", b'A') - await nc.publish("foo", b'B') + await nc.publish("foo", b"A") + await nc.publish("foo", b"B") # Wait a bit to receive the messages await asyncio.sleep(0.5) self.assertEqual(2, len(msgs)) await sub.unsubscribe() - await nc.publish("foo", b'C') - await nc.publish("foo", b'D') + await nc.publish("foo", b"C") + await nc.publish("foo", b"D") # Ordering should be preserved in these at least - self.assertEqual(b'A', msgs[0].data) - self.assertEqual(b'B', msgs[1].data) + self.assertEqual(b"A", msgs[0].data) + self.assertEqual(b"B", msgs[1].data) # Should not exist by now with self.assertRaises(KeyError): nc._subs[sub._id].received await asyncio.sleep(1) - endpoint = f'127.0.0.1:{self.server_pool[0].http_port}' + endpoint = f"127.0.0.1:{self.server_pool[0].http_port}" httpclient = http.client.HTTPConnection(endpoint, timeout=5) - httpclient.request('GET', '/connz') + httpclient.request("GET", "/connz") response = httpclient.getresponse() connz = json.loads((response.read()).decode()) - self.assertEqual(1, len(connz['connections'])) - self.assertEqual(0, connz['connections'][0]['subscriptions']) - self.assertEqual(4, connz['connections'][0]['in_msgs']) - self.assertEqual(4, connz['connections'][0]['in_bytes']) - self.assertEqual(2, connz['connections'][0]['out_msgs']) - self.assertEqual(2, connz['connections'][0]['out_bytes']) + self.assertEqual(1, len(connz["connections"])) + self.assertEqual(0, connz["connections"][0]["subscriptions"]) + self.assertEqual(4, connz["connections"][0]["in_msgs"]) + self.assertEqual(4, connz["connections"][0]["in_bytes"]) + self.assertEqual(2, connz["connections"][0]["out_msgs"]) + self.assertEqual(2, connz["connections"][0]["out_bytes"]) await nc.close() - self.assertEqual(2, nc.stats['in_msgs']) - self.assertEqual(2, nc.stats['in_bytes']) - self.assertEqual(4, nc.stats['out_msgs']) - self.assertEqual(4, nc.stats['out_bytes']) + self.assertEqual(2, nc.stats["in_msgs"]) + self.assertEqual(2, nc.stats["in_bytes"]) + self.assertEqual(4, nc.stats["out_msgs"]) + self.assertEqual(4, nc.stats["out_bytes"]) @async_test async def test_old_style_request(self): @@ -779,32 +779,32 @@ async def worker_handler(msg): nonlocal counter counter += 1 msgs.append(msg) - await nc.publish(msg.reply, f'Reply:{counter}'.encode()) + await nc.publish(msg.reply, f"Reply:{counter}".encode()) async def slow_worker_handler(msg): await asyncio.sleep(0.5) - await nc.publish(msg.reply, b'timeout by now...') + await nc.publish(msg.reply, b"timeout by now...") await nc.connect() await nc.subscribe("help", cb=worker_handler) await nc.subscribe("slow.help", cb=slow_worker_handler) response = await nc.request( - "help", b'please', timeout=1, old_style=True + "help", b"please", timeout=1, old_style=True ) - self.assertEqual(b'Reply:1', response.data) + self.assertEqual(b"Reply:1", response.data) response = await nc.request( - "help", b'please', timeout=1, old_style=True + "help", b"please", timeout=1, old_style=True ) - self.assertEqual(b'Reply:2', response.data) + self.assertEqual(b"Reply:2", response.data) with self.assertRaises(nats.errors.TimeoutError): msg = await nc.request( - "slow.help", b'please', timeout=0.1, old_style=True + "slow.help", b"please", timeout=0.1, old_style=True ) with self.assertRaises(nats.errors.NoRespondersError): - await nc.request("nowhere", b'please', timeout=0.1, old_style=True) + await nc.request("nowhere", b"please", timeout=0.1, old_style=True) await asyncio.sleep(1) await nc.close() @@ -819,11 +819,11 @@ async def worker_handler(msg): nonlocal counter counter += 1 msgs.append(msg) - await nc.publish(msg.reply, f'Reply:{counter}'.encode()) + await nc.publish(msg.reply, f"Reply:{counter}".encode()) async def slow_worker_handler(msg): await asyncio.sleep(0.5) - await nc.publish(msg.reply, b'timeout by now...') + await nc.publish(msg.reply, b"timeout by now...") errs = [] @@ -834,13 +834,13 @@ async def err_cb(err): await nc.subscribe("help", cb=worker_handler) await nc.subscribe("slow.help", cb=slow_worker_handler) - response = await nc.request("help", b'please', timeout=1) - self.assertEqual(b'Reply:1', response.data) - response = await nc.request("help", b'please', timeout=1) - self.assertEqual(b'Reply:2', response.data) + response = await nc.request("help", b"please", timeout=1) + self.assertEqual(b"Reply:1", response.data) + response = await nc.request("help", b"please", timeout=1) + self.assertEqual(b"Reply:2", response.data) with self.assertRaises(nats.errors.TimeoutError): - await nc.request("slow.help", b'please', timeout=0.1) + await nc.request("slow.help", b"please", timeout=0.1) await asyncio.sleep(1) assert len(errs) == 0 await nc.close() @@ -850,18 +850,18 @@ async def test_requests_gather(self): nc = await nats.connect() async def worker_handler(msg): - await msg.respond(b'OK') + await msg.respond(b"OK") await nc.subscribe("foo.*", cb=worker_handler) - await nc.request("foo.A", b'') + await nc.request("foo.A", b"") msgs = await asyncio.gather( - nc.request("foo.B", b''), - nc.request("foo.C", b''), - nc.request("foo.D", b''), - nc.request("foo.E", b''), - nc.request("foo.F", b''), + nc.request("foo.B", b""), + nc.request("foo.C", b""), + nc.request("foo.D", b""), + nc.request("foo.E", b""), + nc.request("foo.F", b""), ) self.assertEqual(len(msgs), 5) @@ -872,12 +872,12 @@ async def test_custom_inbox_prefix(self): nc = NATS() async def worker_handler(msg): - self.assertTrue(msg.reply.startswith('bar.')) + self.assertTrue(msg.reply.startswith("bar.")) await msg.respond(b"OK") await nc.connect(inbox_prefix="bar") await nc.subscribe("foo", cb=worker_handler) - await nc.request("foo", b'') + await nc.request("foo", b"") await nc.close() @async_test @@ -886,15 +886,15 @@ async def test_msg_respond(self): msgs = [] async def cb1(msg): - await msg.respond(b'bar') + await msg.respond(b"bar") async def cb2(msg): msgs.append(msg) await nc.connect() - await nc.subscribe('subj1', cb=cb1) - await nc.subscribe('subj2', cb=cb2) - await nc.publish('subj1', b'foo', reply='subj2') + await nc.subscribe("subj1", cb=cb1) + await nc.subscribe("subj2", cb=cb2) + await nc.publish("subj1", b"foo", reply="subj2") await asyncio.sleep(1) self.assertEqual(1, len(msgs)) @@ -906,7 +906,7 @@ async def test_pending_data_size_tracking(self): await nc.connect() largest_pending_data_size = 0 for i in range(0, 100): - await nc.publish("example", b'A' * 100000) + await nc.publish("example", b"A" * 100000) if nc.pending_data_size > 0: largest_pending_data_size = nc.pending_data_size self.assertTrue(largest_pending_data_size > 0) @@ -938,23 +938,23 @@ async def err_cb(e): err_count += 1 options = { - 'disconnected_cb': disconnected_cb, - 'closed_cb': closed_cb, - 'reconnected_cb': reconnected_cb, - 'error_cb': err_cb, + "disconnected_cb": disconnected_cb, + "closed_cb": closed_cb, + "reconnected_cb": reconnected_cb, + "error_cb": err_cb, } await nc.connect(**options) await nc.close() with self.assertRaises(nats.errors.ConnectionClosedError): - await nc.publish("foo", b'A') + await nc.publish("foo", b"A") with self.assertRaises(nats.errors.ConnectionClosedError): await nc.subscribe("bar", "workers") with self.assertRaises(nats.errors.ConnectionClosedError): - await nc.publish("bar", b'B', reply="inbox") + await nc.publish("bar", b"B", reply="inbox") with self.assertRaises(nats.errors.ConnectionClosedError): await nc.flush() @@ -998,11 +998,11 @@ async def closed_cb(): closed_count += 1 options = { - 'dont_randomize': True, - 'disconnected_cb': disconnected_cb, - 'closed_cb': closed_cb, - 'reconnected_cb': reconnected_cb, - 'reconnect_time_wait': 0.01 + "dont_randomize": True, + "disconnected_cb": disconnected_cb, + "closed_cb": closed_cb, + "reconnected_cb": reconnected_cb, + "reconnect_time_wait": 0.01, } await nc.connect(**options) @@ -1022,7 +1022,7 @@ async def receiver_cb(msg): await nc2.flush() for i in range(0, 200): - await nc.publish(f"example.{i}", b'A' * 20) + await nc.publish(f"example.{i}", b"A" * 20) # All pending messages should have been emitted to the server # by the first connection at this point. @@ -1041,15 +1041,15 @@ async def test_connect_with_auth(self): nc = NATS() options = { - 'servers': [ + "servers": [ "nats://foo:bar@127.0.0.1:4223", - "nats://hoge:fuga@127.0.0.1:4224" + "nats://hoge:fuga@127.0.0.1:4224", ] } await nc.connect(**options) - self.assertIn('auth_required', nc._server_info) - self.assertIn('max_payload', nc._server_info) - self.assertEqual(nc._server_info['max_payload'], nc._max_payload) + self.assertIn("auth_required", nc._server_info) + self.assertIn("max_payload", nc._server_info) + self.assertEqual(nc._server_info["max_payload"], nc._max_payload) self.assertTrue(nc.is_connected) await nc.close() self.assertTrue(nc.is_closed) @@ -1082,22 +1082,22 @@ async def err_cb(e): nc = NATS() options = { - 'reconnect_time_wait': 0.2, - 'servers': ["nats://hello:world@127.0.0.1:4223", ], - 'max_reconnect_attempts': 3, - 'error_cb': err_cb, + "reconnect_time_wait": 0.2, + "servers": ["nats://hello:world@127.0.0.1:4223", ], + "max_reconnect_attempts": 3, + "error_cb": err_cb, } try: await nc.connect(**options) except: pass - self.assertIn('auth_required', nc._server_info) - self.assertTrue(nc._server_info['auth_required']) + self.assertIn("auth_required", nc._server_info) + self.assertTrue(nc._server_info["auth_required"]) self.assertFalse(nc.is_connected) await nc.close() self.assertTrue(nc.is_closed) - self.assertEqual(0, nc.stats['reconnects']) + self.assertEqual(0, nc.stats["reconnects"]) self.assertTrue(len(errors) > 0) @async_test @@ -1109,10 +1109,10 @@ async def err_cb(e): errors.append(e) options = { - 'reconnect_time_wait': 0.2, - 'servers': ["nats://hello:world@127.0.0.1:4223", ], - 'max_reconnect_attempts': 3, - 'error_cb': err_cb, + "reconnect_time_wait": 0.2, + "servers": ["nats://hello:world@127.0.0.1:4223", ], + "max_reconnect_attempts": 3, + "error_cb": err_cb, } with self.assertRaises(nats.errors.NoServersError): await nats.connect(**options) @@ -1134,20 +1134,20 @@ async def err_cb(e): errors.append(e) options = { - 'dont_randomize': True, - 'reconnect_time_wait': 0.5, - 'disconnected_cb': disconnected_cb, - 'error_cb': err_cb, - 'servers': [ + "dont_randomize": True, + "reconnect_time_wait": 0.5, + "disconnected_cb": disconnected_cb, + "error_cb": err_cb, + "servers": [ "nats://foo:bar@127.0.0.1:4223", - "nats://hoge:fuga@127.0.0.1:4224" + "nats://hoge:fuga@127.0.0.1:4224", ], - 'max_reconnect_attempts': -1 + "max_reconnect_attempts": -1, } await nc.connect(**options) - self.assertIn('auth_required', nc._server_info) - self.assertTrue(nc._server_info['auth_required']) + self.assertIn("auth_required", nc._server_info) + self.assertTrue(nc._server_info["auth_required"]) self.assertTrue(nc.is_connected) # Stop all servers so that there aren't any available to reconnect @@ -1179,7 +1179,7 @@ async def err_cb(e): # Many attempts but only at most 2 reconnects would have occurred, # in case it was able to reconnect to another server while it was # shutting down. - self.assertTrue(nc.stats['reconnects'] >= 1) + self.assertTrue(nc.stats["reconnects"] >= 1) # Wrap off and disconnect await nc.close() @@ -1207,23 +1207,23 @@ async def err_cb(e): errors.append(e) options = { - 'dont_randomize': True, - 'reconnect_time_wait': 0.5, - 'disconnected_cb': disconnected_cb, - 'error_cb': err_cb, - 'closed_cb': closed_cb, - 'servers': [ + "dont_randomize": True, + "reconnect_time_wait": 0.5, + "disconnected_cb": disconnected_cb, + "error_cb": err_cb, + "closed_cb": closed_cb, + "servers": [ "nats://foo:bar@127.0.0.1:4223", "nats://hoge:fuga@127.0.0.1:4224", - "nats://hello:world@127.0.0.1:4225" + "nats://hello:world@127.0.0.1:4225", ], - 'max_reconnect_attempts': 3, - 'dont_randomize': True + "max_reconnect_attempts": 3, + "dont_randomize": True, } await nc.connect(**options) - self.assertIn('auth_required', nc._server_info) - self.assertTrue(nc._server_info['auth_required']) + self.assertIn("auth_required", nc._server_info) + self.assertTrue(nc._server_info["auth_required"]) self.assertTrue(nc.is_connected) # Check number of nodes in the server pool. @@ -1267,7 +1267,7 @@ async def err_cb(e): await asyncio.sleep(0) # Only reconnected successfully once to the same server. - self.assertTrue(nc.stats['reconnects'] == 1) + self.assertTrue(nc.stats["reconnects"] == 1) self.assertEqual(1, len(nc._server_pool)) # await nc.close() @@ -1297,13 +1297,13 @@ async def err_cb(e): errors.append(e) options = { - 'dont_randomize': True, - 'reconnect_time_wait': 0.5, - 'disconnected_cb': disconnected_cb, - 'error_cb': err_cb, - 'closed_cb': closed_cb, - 'max_reconnect_attempts': 3, - 'dont_randomize': True, + "dont_randomize": True, + "reconnect_time_wait": 0.5, + "disconnected_cb": disconnected_cb, + "error_cb": err_cb, + "closed_cb": closed_cb, + "max_reconnect_attempts": 3, + "dont_randomize": True, } nc = await nats.connect("nats://foo:bar@127.0.0.1:4223", **options) @@ -1343,15 +1343,15 @@ async def closed_cb(): closed_count += 1 options = { - 'servers': [ + "servers": [ "nats://foo:bar@127.0.0.1:4223", - "nats://hoge:fuga@127.0.0.1:4224" + "nats://hoge:fuga@127.0.0.1:4224", ], - 'dont_randomize': True, - 'disconnected_cb': disconnected_cb, - 'closed_cb': closed_cb, - 'reconnected_cb': reconnected_cb, - 'reconnect_time_wait': 0.01 + "dont_randomize": True, + "disconnected_cb": disconnected_cb, + "closed_cb": closed_cb, + "reconnected_cb": reconnected_cb, + "reconnect_time_wait": 0.01, } await nc.connect(**options) largest_pending_data_size = 0 @@ -1364,7 +1364,7 @@ async def cb(msg): await nc.subscribe("example.*", cb=cb) for i in range(0, 200): - await nc.publish(f"example.{i}", b'A' * 20) + await nc.publish(f"example.{i}", b"A" * 20) if nc.pending_data_size > 0: largest_pending_data_size = nc.pending_data_size if nc.pending_data_size > 100: @@ -1385,7 +1385,7 @@ async def cb(msg): await asyncio.sleep(0) await asyncio.sleep(0.2) await asyncio.sleep(0) - self.assertEqual(1, nc.stats['reconnects']) + self.assertEqual(1, nc.stats["reconnects"]) try: await nc.flush(2) except nats.errors.TimeoutError: @@ -1419,16 +1419,16 @@ async def closed_cb(): closed_count += 1 options = { - 'servers': [ + "servers": [ "nats://foo:bar@127.0.0.1:4223", - "nats://hoge:fuga@127.0.0.1:4224" + "nats://hoge:fuga@127.0.0.1:4224", ], - 'dont_randomize': True, - 'disconnected_cb': disconnected_cb, - 'closed_cb': closed_cb, - 'reconnected_cb': reconnected_cb, - 'flusher_queue_size': 100, - 'reconnect_time_wait': 0.01 + "dont_randomize": True, + "disconnected_cb": disconnected_cb, + "closed_cb": closed_cb, + "reconnected_cb": reconnected_cb, + "flusher_queue_size": 100, + "reconnect_time_wait": 0.01, } await nc.connect(**options) largest_pending_data_size = 0 @@ -1441,7 +1441,7 @@ async def cb(msg): await nc.subscribe("example.*", cb=cb) for i in range(0, 500): - await nc.publish(f"example.{i}", b'A' * 20) + await nc.publish(f"example.{i}", b"A" * 20) if nc.pending_data_size > 0: largest_pending_data_size = nc.pending_data_size if nc.pending_data_size > 100: @@ -1462,7 +1462,7 @@ async def cb(msg): await asyncio.sleep(0) await asyncio.sleep(0.2) await asyncio.sleep(0) - self.assertEqual(1, nc.stats['reconnects']) + self.assertEqual(1, nc.stats["reconnects"]) try: await nc.flush(2) except nats.errors.TimeoutError: @@ -1500,15 +1500,15 @@ async def err_cb(e): errors.append(e) options = { - 'servers': [ + "servers": [ "nats://foo:bar@127.0.0.1:4223", - "nats://hoge:fuga@127.0.0.1:4224" + "nats://hoge:fuga@127.0.0.1:4224", ], - 'disconnected_cb': disconnected_cb, - 'closed_cb': closed_cb, - 'reconnected_cb': reconnected_cb, - 'error_cb': err_cb, - 'dont_randomize': True, + "disconnected_cb": disconnected_cb, + "closed_cb": closed_cb, + "reconnected_cb": reconnected_cb, + "error_cb": err_cb, + "dont_randomize": True, } nc = await nats.connect(**options) self.assertTrue(nc.is_connected) @@ -1517,14 +1517,14 @@ async def worker_handler(msg): nonlocal counter counter += 1 if msg.reply != "": - await nc.publish(msg.reply, f'Reply:{counter}'.encode()) + await nc.publish(msg.reply, f"Reply:{counter}".encode()) await nc.subscribe("one", cb=worker_handler) await nc.subscribe("two", cb=worker_handler) await nc.subscribe("three", cb=worker_handler) - response = await nc.request("one", b'Help!', timeout=1) - self.assertEqual(b'Reply:1', response.data) + response = await nc.request("one", b"Help!", timeout=1) + self.assertEqual(b"Reply:1", response.data) # Stop the first server and connect to another one asap. asyncio.get_running_loop().run_in_executor( @@ -1534,11 +1534,11 @@ async def worker_handler(msg): # FIXME: Find better way to wait for the server to be stopped. await asyncio.sleep(0.5) - response = await nc.request("three", b'Help!', timeout=1) - self.assertEqual(b'Reply:2', response.data) + response = await nc.request("three", b"Help!", timeout=1) + self.assertEqual(b"Reply:2", response.data) await asyncio.sleep(0.5) await nc.close() - self.assertEqual(1, nc.stats['reconnects']) + self.assertEqual(1, nc.stats["reconnects"]) self.assertEqual(1, closed_count) self.assertEqual(2, disconnected_count) self.assertEqual(1, reconnected_count) @@ -1552,9 +1552,9 @@ class ClientAuthTokenTest(MultiServerAuthTokenTestCase): async def test_connect_with_auth_token(self): nc = NATS() - options = {'servers': ["nats://token@127.0.0.1:4223", ]} + options = {"servers": ["nats://token@127.0.0.1:4223", ]} await nc.connect(**options) - self.assertIn('auth_required', nc._server_info) + self.assertIn("auth_required", nc._server_info) self.assertTrue(nc.is_connected) await nc.close() self.assertTrue(nc.is_closed) @@ -1565,11 +1565,11 @@ async def test_connect_with_auth_token_option(self): nc = NATS() options = { - 'servers': ["nats://127.0.0.1:4223", ], - 'token': "token", + "servers": ["nats://127.0.0.1:4223", ], + "token": "token", } await nc.connect(**options) - self.assertIn('auth_required', nc._server_info) + self.assertIn("auth_required", nc._server_info) self.assertTrue(nc.is_connected) await nc.close() self.assertTrue(nc.is_closed) @@ -1580,16 +1580,16 @@ async def test_connect_with_bad_auth_token(self): nc = NATS() options = { - 'servers': ["nats://token@127.0.0.1:4225", ], - 'allow_reconnect': False, - 'reconnect_time_wait': 0.1, - 'max_reconnect_attempts': 1, + "servers": ["nats://token@127.0.0.1:4225", ], + "allow_reconnect": False, + "reconnect_time_wait": 0.1, + "max_reconnect_attempts": 1, } # Authorization Violation with self.assertRaises(nats.errors.Error): await nc.connect(**options) - self.assertIn('auth_required', nc._server_info) + self.assertIn("auth_required", nc._server_info) self.assertFalse(nc.is_connected) @async_test @@ -1619,21 +1619,21 @@ async def worker_handler(msg): nonlocal counter counter += 1 if msg.reply != "": - await nc.publish(msg.reply, f'Reply:{counter}'.encode()) + await nc.publish(msg.reply, f"Reply:{counter}".encode()) options = { - 'servers': [ + "servers": [ "nats://token@127.0.0.1:4223", "nats://token@127.0.0.1:4224", ], - 'disconnected_cb': disconnected_cb, - 'closed_cb': closed_cb, - 'reconnected_cb': reconnected_cb, - 'dont_randomize': True + "disconnected_cb": disconnected_cb, + "closed_cb": closed_cb, + "reconnected_cb": reconnected_cb, + "dont_randomize": True, } await nc.connect(**options) await nc.subscribe("test", cb=worker_handler) - self.assertIn('auth_required', nc._server_info) + self.assertIn("auth_required", nc._server_info) self.assertTrue(nc.is_connected) # Trigger a reconnect @@ -1643,8 +1643,8 @@ async def worker_handler(msg): await asyncio.sleep(1) await nc.subscribe("test", cb=worker_handler) - response = await nc.request("test", b'data', timeout=1) - self.assertEqual(b'Reply:1', response.data) + response = await nc.request("test", b"data", timeout=1) + self.assertEqual(b"Reply:1", response.data) await nc.close() self.assertTrue(nc.is_closed) @@ -1659,10 +1659,10 @@ class ClientTLSTest(TLSServerTestCase): @async_test async def test_connect(self): nc = NATS() - await nc.connect(servers=['nats://127.0.0.1:4224'], tls=self.ssl_ctx) - self.assertEqual(nc._server_info['max_payload'], nc.max_payload) - self.assertTrue(nc._server_info['tls_required']) - self.assertTrue(nc._server_info['tls_verify']) + await nc.connect(servers=["nats://127.0.0.1:4224"], tls=self.ssl_ctx) + self.assertEqual(nc._server_info["max_payload"], nc.max_payload) + self.assertTrue(nc._server_info["tls_required"]) + self.assertTrue(nc._server_info["tls_verify"]) self.assertTrue(nc.max_payload > 0) self.assertTrue(nc.is_connected) await nc.close() @@ -1676,7 +1676,7 @@ async def test_default_connect_using_tls_scheme(self): # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): await nc.connect( - servers=['tls://127.0.0.1:4224'], allow_reconnect=False + servers=["tls://127.0.0.1:4224"], allow_reconnect=False ) @async_test @@ -1685,7 +1685,7 @@ async def test_default_connect_using_tls_scheme_in_url(self): # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): - await nc.connect('tls://127.0.0.1:4224', allow_reconnect=False) + await nc.connect("tls://127.0.0.1:4224", allow_reconnect=False) @async_test async def test_connect_tls_with_custom_hostname(self): @@ -1694,7 +1694,7 @@ async def test_connect_tls_with_custom_hostname(self): # Will attempt to connect using TLS with an invalid hostname. with self.assertRaises(ssl.SSLError): await nc.connect( - servers=['nats://127.0.0.1:4224'], + servers=["nats://127.0.0.1:4224"], tls=self.ssl_ctx, tls_hostname="nats.example", allow_reconnect=False, @@ -1708,22 +1708,22 @@ async def test_subscribe(self): async def subscription_handler(msg): msgs.append(msg) - payload = b'hello world' - await nc.connect(servers=['nats://127.0.0.1:4224'], tls=self.ssl_ctx) + payload = b"hello world" + await nc.connect(servers=["nats://127.0.0.1:4224"], tls=self.ssl_ctx) sub = await nc.subscribe("foo", cb=subscription_handler) await nc.publish("foo", payload) await nc.publish("bar", payload) with self.assertRaises(nats.errors.BadSubjectError): - await nc.publish("", b'') + await nc.publish("", b"") # Wait a bit for message to be received. await asyncio.sleep(0.2) self.assertEqual(1, len(msgs)) msg = msgs[0] - self.assertEqual('foo', msg.subject) - self.assertEqual('', msg.reply) + self.assertEqual("foo", msg.subject) + self.assertEqual("", msg.reply) self.assertEqual(payload, msg.data) self.assertEqual(1, sub._received) await nc.close() @@ -1733,7 +1733,6 @@ class ClientTLSReconnectTest(MultiTLSServerAuthTestCase): @async_test async def test_tls_reconnect(self): - nc = NATS() disconnected_count = 0 reconnected_count = 0 @@ -1762,26 +1761,26 @@ async def worker_handler(msg): nonlocal counter counter += 1 if msg.reply != "": - await nc.publish(msg.reply, f'Reply:{counter}'.encode()) + await nc.publish(msg.reply, f"Reply:{counter}".encode()) options = { - 'servers': [ + "servers": [ "nats://foo:bar@127.0.0.1:4223", - "nats://hoge:fuga@127.0.0.1:4224" + "nats://hoge:fuga@127.0.0.1:4224", ], - 'disconnected_cb': disconnected_cb, - 'closed_cb': closed_cb, - 'reconnected_cb': reconnected_cb, - 'error_cb': err_cb, - 'dont_randomize': True, - 'tls': self.ssl_ctx + "disconnected_cb": disconnected_cb, + "closed_cb": closed_cb, + "reconnected_cb": reconnected_cb, + "error_cb": err_cb, + "dont_randomize": True, + "tls": self.ssl_ctx, } await nc.connect(**options) self.assertTrue(nc.is_connected) await nc.subscribe("example", cb=worker_handler) - response = await nc.request("example", b'Help!', timeout=1) - self.assertEqual(b'Reply:1', response.data) + response = await nc.request("example", b"Help!", timeout=1) + self.assertEqual(b"Reply:1", response.data) # Trigger a reconnect and should be fine await asyncio.get_running_loop().run_in_executor( @@ -1790,13 +1789,13 @@ async def worker_handler(msg): await asyncio.sleep(1) await nc.subscribe("example", cb=worker_handler) - response = await nc.request("example", b'Help!', timeout=1) - self.assertEqual(b'Reply:2', response.data) + response = await nc.request("example", b"Help!", timeout=1) + self.assertEqual(b"Reply:2", response.data) await nc.close() self.assertTrue(nc.is_closed) self.assertFalse(nc.is_connected) - self.assertEqual(1, nc.stats['reconnects']) + self.assertEqual(1, nc.stats["reconnects"]) self.assertEqual(1, closed_count) self.assertEqual(2, disconnected_count) self.assertEqual(1, reconnected_count) @@ -1807,17 +1806,17 @@ class ClientTLSHandshakeFirstTest(TLSServerHandshakeFirstTestCase): @async_test async def test_connect(self): - if os.environ.get('NATS_SERVER_VERSION') != 'main': + if os.environ.get("NATS_SERVER_VERSION") != "main": pytest.skip("test requires nats-server@main") nc = await nats.connect( - 'nats://127.0.0.1:4224', + "nats://127.0.0.1:4224", tls=self.ssl_ctx, tls_handshake_first=True ) - self.assertEqual(nc._server_info['max_payload'], nc.max_payload) - self.assertTrue(nc._server_info['tls_required']) - self.assertTrue(nc._server_info['tls_verify']) + self.assertEqual(nc._server_info["max_payload"], nc.max_payload) + self.assertTrue(nc._server_info["tls_required"]) + self.assertTrue(nc._server_info["tls_verify"]) self.assertTrue(nc.max_payload > 0) self.assertTrue(nc.is_connected) await nc.close() @@ -1826,7 +1825,7 @@ async def test_connect(self): @async_test async def test_default_connect_using_tls_scheme(self): - if os.environ.get('NATS_SERVER_VERSION') != 'main': + if os.environ.get("NATS_SERVER_VERSION") != "main": pytest.skip("test requires nats-server@main") nc = NATS() @@ -1834,14 +1833,14 @@ async def test_default_connect_using_tls_scheme(self): # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): await nc.connect( - servers=['tls://127.0.0.1:4224'], + servers=["tls://127.0.0.1:4224"], allow_reconnect=False, tls_handshake_first=True, ) @async_test async def test_default_connect_using_tls_scheme_in_url(self): - if os.environ.get('NATS_SERVER_VERSION') != 'main': + if os.environ.get("NATS_SERVER_VERSION") != "main": pytest.skip("test requires nats-server@main") nc = NATS() @@ -1849,14 +1848,14 @@ async def test_default_connect_using_tls_scheme_in_url(self): # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): await nc.connect( - 'tls://127.0.0.1:4224', + "tls://127.0.0.1:4224", allow_reconnect=False, tls_handshake_first=True ) @async_test async def test_connect_tls_with_custom_hostname(self): - if os.environ.get('NATS_SERVER_VERSION') != 'main': + if os.environ.get("NATS_SERVER_VERSION") != "main": pytest.skip("test requires nats-server@main") nc = NATS() @@ -1864,7 +1863,7 @@ async def test_connect_tls_with_custom_hostname(self): # Will attempt to connect using TLS with an invalid hostname. with self.assertRaises(ssl.SSLError): await nc.connect( - servers=['nats://127.0.0.1:4224'], + servers=["nats://127.0.0.1:4224"], tls=self.ssl_ctx, tls_hostname="nats.example", tls_handshake_first=True, @@ -1873,7 +1872,7 @@ async def test_connect_tls_with_custom_hostname(self): @async_test async def test_subscribe(self): - if os.environ.get('NATS_SERVER_VERSION') != 'main': + if os.environ.get("NATS_SERVER_VERSION") != "main": pytest.skip("test requires nats-server@main") nc = NATS() @@ -1882,26 +1881,26 @@ async def test_subscribe(self): async def subscription_handler(msg): msgs.append(msg) - payload = b'hello world' + payload = b"hello world" await nc.connect( - servers=['nats://127.0.0.1:4224'], + servers=["nats://127.0.0.1:4224"], tls=self.ssl_ctx, - tls_handshake_first=True + tls_handshake_first=True, ) sub = await nc.subscribe("foo", cb=subscription_handler) await nc.publish("foo", payload) await nc.publish("bar", payload) with self.assertRaises(nats.errors.BadSubjectError): - await nc.publish("", b'') + await nc.publish("", b"") # Wait a bit for message to be received. await asyncio.sleep(0.2) self.assertEqual(1, len(msgs)) msg = msgs[0] - self.assertEqual('foo', msg.subject) - self.assertEqual('', msg.reply) + self.assertEqual("foo", msg.subject) + self.assertEqual("", msg.reply) self.assertEqual(payload, msg.data) self.assertEqual(1, sub._received) await nc.close() @@ -1924,11 +1923,11 @@ async def test_discover_servers_on_first_connect(self): ) await asyncio.sleep(1) - options = {'servers': ["nats://127.0.0.1:4223", ]} + options = {"servers": ["nats://127.0.0.1:4223", ]} discovered_server_cb = mock.Mock() - with mock.patch('asyncio.iscoroutinefunction', return_value=True): + with mock.patch("asyncio.iscoroutinefunction", return_value=True): await nc.connect( **options, discovered_server_cb=discovered_server_cb ) @@ -1943,9 +1942,9 @@ async def test_discover_servers_on_first_connect(self): async def test_discover_servers_after_first_connect(self): nc = NATS() - options = {'servers': ["nats://127.0.0.1:4223", ]} + options = {"servers": ["nats://127.0.0.1:4223", ]} discovered_server_cb = mock.Mock() - with mock.patch('asyncio.iscoroutinefunction', return_value=True): + with mock.patch("asyncio.iscoroutinefunction", return_value=True): await nc.connect( **options, discovered_server_cb=discovered_server_cb ) @@ -1985,12 +1984,12 @@ async def err_cb(e): errors.append(e) options = { - 'servers': ["nats://127.0.0.1:4223", ], - 'reconnected_cb': reconnected_cb, - 'error_cb': err_cb, - 'reconnect_time_wait': 0.1, - 'user': "foo", - 'password': "bar", + "servers": ["nats://127.0.0.1:4223", ], + "reconnected_cb": reconnected_cb, + "error_cb": err_cb, + "reconnect_time_wait": 0.1, + "user": "foo", + "password": "bar", } await nc.connect(**options) @@ -1998,7 +1997,7 @@ async def err_cb(e): await asyncio.sleep(1) async def handler(msg): - await nc.publish(msg.reply, b'ok') + await nc.publish(msg.reply, b"ok") await nc.subscribe("foo", cb=handler) @@ -2008,8 +2007,8 @@ async def handler(msg): ) await asyncio.wait_for(reconnected, 2) - msg = await nc.request("foo", b'hi') - self.assertEqual(b'ok', msg.data) + msg = await nc.request("foo", b"hi") + self.assertEqual(b"ok", msg.data) await nc.close() self.assertTrue(nc.is_closed) @@ -2058,12 +2057,12 @@ async def err_cb(e): await asyncio.sleep(1) async def handler(msg): - await nc.publish(msg.reply, b'ok') + await nc.publish(msg.reply, b"ok") await nc.subscribe("foo", cb=handler) - msg = await nc.request("foo", b'hi') - self.assertEqual(b'ok', msg.data) + msg = await nc.request("foo", b"hi") + self.assertEqual(b"ok", msg.data) # Remove first member and try to reconnect await asyncio.get_running_loop().run_in_executor( @@ -2073,10 +2072,10 @@ async def handler(msg): # Publishing while disconnected is an error if pending size is disabled. with self.assertRaises(nats.errors.OutboundBufferLimitError): - await nc.request("foo", b'hi') + await nc.request("foo", b"hi") with self.assertRaises(nats.errors.OutboundBufferLimitError): - await nc.publish("foo", b'hi') + await nc.publish("foo", b"hi") await asyncio.wait_for(reconnected, 2) @@ -2128,12 +2127,12 @@ async def err_cb(e): await asyncio.sleep(1) async def handler(msg): - await nc.publish(msg.reply, b'ok') + await nc.publish(msg.reply, b"ok") await nc.subscribe("foo", cb=handler) - msg = await nc.request("foo", b'hi') - self.assertEqual(b'ok', msg.data) + msg = await nc.request("foo", b"hi") + self.assertEqual(b"ok", msg.data) # Remove first member and try to reconnect await asyncio.get_running_loop().run_in_executor( @@ -2142,7 +2141,7 @@ async def handler(msg): await asyncio.wait_for(disconnected, 2) # While reconnecting the pending data will accumulate. - await nc.publish("foo", b'bar') + await nc.publish("foo", b"bar") self.assertEqual(nc._pending_data_size, 17) # Publishing while disconnected is an error if it hits the pending size. @@ -2197,12 +2196,12 @@ async def err_cb(e): await asyncio.sleep(1) async def handler(msg): - await nc.publish(msg.reply, b'ok') + await nc.publish(msg.reply, b"ok") await nc.subscribe("foo", cb=handler) - msg = await nc.request("foo", b'hi') - self.assertEqual(b'ok', msg.data) + msg = await nc.request("foo", b"hi") + self.assertEqual(b"ok", msg.data) # Publishing while connected should trigger a force flush. payload = ("A" * 1025).encode() @@ -2255,12 +2254,12 @@ async def err_cb(e): async def handler(msg): # This becomes an async error if msg.reply: - await nc.publish(msg.reply, b'ok') + await nc.publish(msg.reply, b"ok") await nc.subscribe("foo", cb=handler) - msg = await nc.request("foo", b'hi') - self.assertEqual(b'ok', msg.data) + msg = await nc.request("foo", b"hi") + self.assertEqual(b"ok", msg.data) # Publishing while connected should trigger a force flush. payload = ("A" * 1025).encode() @@ -2286,14 +2285,14 @@ class ConnectFailuresTest(SingleServerTestCase): async def test_empty_info_op_uses_defaults(self): async def bad_server(reader, writer): - writer.write(b'INFO {}\r\n') + writer.write(b"INFO {}\r\n") await writer.drain() data = await reader.readline() await asyncio.sleep(0.2) writer.close() - await asyncio.start_server(bad_server, '127.0.0.1', 4555) + await asyncio.start_server(bad_server, "127.0.0.1", 4555) disconnected_count = 0 @@ -2303,8 +2302,8 @@ async def disconnected_cb(): nc = NATS() options = { - 'servers': ["nats://127.0.0.1:4555", ], - 'disconnected_cb': disconnected_cb + "servers": ["nats://127.0.0.1:4555", ], + "disconnected_cb": disconnected_cb, } await nc.connect(**options) self.assertEqual(nc.max_payload, 1048576) @@ -2316,11 +2315,11 @@ async def disconnected_cb(): async def test_empty_response_from_server(self): async def bad_server(reader, writer): - writer.write(b'') + writer.write(b"") await asyncio.sleep(0.2) writer.close() - await asyncio.start_server(bad_server, '127.0.0.1', 4555) + await asyncio.start_server(bad_server, "127.0.0.1", 4555) errors = [] @@ -2330,9 +2329,9 @@ async def error_cb(e): nc = NATS() options = { - 'servers': ["nats://127.0.0.1:4555", ], - 'error_cb': error_cb, - 'allow_reconnect': False, + "servers": ["nats://127.0.0.1:4555", ], + "error_cb": error_cb, + "allow_reconnect": False, } with self.assertRaises(nats.errors.Error): @@ -2344,11 +2343,11 @@ async def error_cb(e): async def test_malformed_info_response_from_server(self): async def bad_server(reader, writer): - writer.write(b'INF') + writer.write(b"INF") await asyncio.sleep(0.2) writer.close() - await asyncio.start_server(bad_server, '127.0.0.1', 4555) + await asyncio.start_server(bad_server, "127.0.0.1", 4555) errors = [] @@ -2358,9 +2357,9 @@ async def error_cb(e): nc = NATS() options = { - 'servers': ["nats://127.0.0.1:4555", ], - 'error_cb': error_cb, - 'allow_reconnect': False, + "servers": ["nats://127.0.0.1:4555", ], + "error_cb": error_cb, + "allow_reconnect": False, } with self.assertRaises(nats.errors.Error): @@ -2372,11 +2371,11 @@ async def error_cb(e): async def test_malformed_info_json_response_from_server(self): async def bad_server(reader, writer): - writer.write(b'INFO {\r\n') + writer.write(b"INFO {\r\n") await asyncio.sleep(0.2) writer.close() - await asyncio.start_server(bad_server, '127.0.0.1', 4555) + await asyncio.start_server(bad_server, "127.0.0.1", 4555) errors = [] @@ -2386,9 +2385,9 @@ async def error_cb(e): nc = NATS() options = { - 'servers': ["nats://127.0.0.1:4555", ], - 'error_cb': error_cb, - 'allow_reconnect': False, + "servers": ["nats://127.0.0.1:4555", ], + "error_cb": error_cb, + "allow_reconnect": False, } with self.assertRaises(nats.errors.Error): @@ -2404,7 +2403,7 @@ async def slow_server(reader, writer): await asyncio.sleep(1) writer.close() - await asyncio.start_server(slow_server, '127.0.0.1', 4555) + await asyncio.start_server(slow_server, "127.0.0.1", 4555) disconnected_count = 0 reconnected = asyncio.Future() @@ -2419,12 +2418,12 @@ async def reconnected_cb(): nc = NATS() options = { - 'servers': ["nats://127.0.0.1:4555", ], - 'disconnected_cb': disconnected_cb, - 'reconnected_cb': reconnected_cb, - 'connect_timeout': 0.5, - 'dont_randomize': True, - 'allow_reconnect': False, + "servers": ["nats://127.0.0.1:4555", ], + "disconnected_cb": disconnected_cb, + "reconnected_cb": reconnected_cb, + "connect_timeout": 0.5, + "dont_randomize": True, + "allow_reconnect": False, } with self.assertRaises(asyncio.TimeoutError): @@ -2441,7 +2440,7 @@ async def slow_server(reader, writer): await asyncio.sleep(1) writer.close() - await asyncio.start_server(slow_server, '127.0.0.1', 4555) + await asyncio.start_server(slow_server, "127.0.0.1", 4555) disconnected_count = 0 reconnected = asyncio.Future() @@ -2462,15 +2461,15 @@ async def error_cb(e): nc = NATS() options = { - 'servers': [ + "servers": [ "nats://127.0.0.1:4555", "nats://127.0.0.1:4222", ], - 'disconnected_cb': disconnected_cb, - 'reconnected_cb': reconnected_cb, - 'error_cb': error_cb, - 'connect_timeout': 0.5, - 'dont_randomize': True, + "disconnected_cb": disconnected_cb, + "reconnected_cb": reconnected_cb, + "error_cb": error_cb, + "connect_timeout": 0.5, + "dont_randomize": True, } await nc.connect(**options) @@ -2479,7 +2478,7 @@ async def error_cb(e): self.assertTrue(nc.is_connected) for i in range(0, 10): - await nc.publish("foo", b'ok ok') + await nc.publish("foo", b"ok ok") await nc.flush() await nc.close() @@ -2528,7 +2527,7 @@ async def handler(msg): sub = await nc.subscribe("foo", cb=handler) for i in range(0, 200): - await nc.publish("foo", b'hi') + await nc.publish("foo", b"hi") # Relinquish control so that messages are processed. await asyncio.sleep(0) @@ -2542,7 +2541,7 @@ async def handler(msg): await asyncio.wait_for(drain_task, 1) for i in range(0, 200): - await nc.publish("foo", b'hi') + await nc.publish("foo", b"hi") # Relinquish control so that messages are processed. await asyncio.sleep(0) @@ -2564,7 +2563,7 @@ async def test_drain_subscription_with_future(self): fut = asyncio.Future() sub = await nc.subscribe("foo", future=fut) - await nc.publish("foo", b'hi') + await nc.publish("foo", b"hi") await nc.flush() drain_task = sub.drain() @@ -2599,7 +2598,7 @@ async def closed_cb(): async def handler(msg): if len(msgs) % 20 == 1: await asyncio.sleep(0.2) - await nc.publish(msg.reply, b'OK!', reply=msg.subject) + await nc.publish(msg.reply, b"OK!", reply=msg.subject) await nc2.flush() sub_foo = await nc.subscribe("foo", cb=handler) @@ -2613,14 +2612,14 @@ async def replies(msg): await nc2.subscribe("my-replies.*", cb=replies) for i in range(0, 201): await nc2.publish( - "foo", b'help', reply=f"my-replies.{nc._nuid.next().decode()}" + "foo", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) await nc2.publish( - "bar", b'help', reply=f"my-replies.{nc._nuid.next().decode()}" + "bar", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) await nc2.publish( "quux", - b'help', + b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) @@ -2695,7 +2694,7 @@ async def handler(msg): await asyncio.sleep(0.2) if len(msgs) % 50 == 1: await asyncio.sleep(0.5) - await nc.publish(msg.reply, b'OK!', reply=msg.subject) + await nc.publish(msg.reply, b"OK!", reply=msg.subject) await nc2.flush() await nc.subscribe("foo", cb=handler) @@ -2709,14 +2708,14 @@ async def replies(msg): await nc2.subscribe("my-replies.*", cb=replies) for i in range(0, 201): await nc2.publish( - "foo", b'help', reply=f"my-replies.{nc._nuid.next().decode()}" + "foo", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) await nc2.publish( - "bar", b'help', reply=f"my-replies.{nc._nuid.next().decode()}" + "bar", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) await nc2.publish( "quux", - b'help', + b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) @@ -2744,15 +2743,19 @@ async def test_non_async_callbacks_raise_error(self): def f(): pass - for cb in ['error_cb', 'disconnected_cb', 'discovered_server_cb', - 'closed_cb', 'reconnected_cb']: - + for cb in [ + "error_cb", + "disconnected_cb", + "discovered_server_cb", + "closed_cb", + "reconnected_cb", + ]: with self.assertRaises(nats.errors.InvalidCallbackTypeError): await nc.connect( servers=["nats://127.0.0.1:4222"], max_reconnect_attempts=2, reconnect_time_wait=0.2, - **{cb: f} + **{cb: f}, ) @async_test @@ -2792,7 +2795,8 @@ async def cb(msg): with unittest.mock.patch( "asyncio.wait_for", unittest.mock.AsyncMock(side_effect=asyncio.CancelledError - )): + ), + ): await sub.drain() await nc.close() @@ -2816,7 +2820,7 @@ async def err_cb(e): sub = await nc.subscribe("foo") await nc.flush() await asyncio.sleep(0) - await nc.publish("foo", b'hello') + await nc.publish("foo", b"hello") await asyncio.wait_for(fut, 2) err = fut.result() assert str( @@ -2826,12 +2830,12 @@ async def err_cb(e): nc2 = await nats.connect("nats://127.0.0.1:4555", ) async def cb(msg): - await msg.respond(b'pong') + await msg.respond(b"pong") sub2 = await nc2.subscribe("foo", cb=cb) await nc2.flush() - resp = await nc2.request("foo", b'ping') - assert resp.data == b'pong' + resp = await nc2.request("foo", b"ping") + assert resp.data == b"pong" await nc2.close() await nc.close() @@ -2851,7 +2855,7 @@ async def err_cb(e): sub = await nc.subscribe("foo") await nc.flush() await asyncio.sleep(0) - await nc.publish("foo", b'hello') + await nc.publish("foo", b"hello") await asyncio.wait_for(fut, 2) err = fut.result() assert str( @@ -2861,12 +2865,12 @@ async def err_cb(e): nc2 = await nats.connect("nats://127.0.0.1:4555", ) async def cb(msg): - await msg.respond(b'pong') + await msg.respond(b"pong") sub2 = await nc2.subscribe("foo", cb=cb) await nc2.flush() - resp = await nc2.request("foo", b'ping') - assert resp.data == b'pong' + resp = await nc2.request("foo", b"ping") + assert resp.data == b"pong" await nc2.close() await nc.close() @@ -2900,10 +2904,10 @@ async def disconnected_cb(): error_cb=error_cb, ) sub = await nc.subscribe("foo") - await nc.publish("foo", b'First') + await nc.publish("foo", b"First") await nc.flush() msg = await sub.next_msg() - self.assertEqual(msg.data, b'First') + self.assertEqual(msg.data, b"First") await asyncio.get_running_loop().run_in_executor( None, self.server_pool[0].stop @@ -2915,7 +2919,8 @@ async def disconnected_cb(): disconnected_states[1] == NATS.CLOSED -if __name__ == '__main__': +if __name__ == "__main__": import sys + runner = unittest.TextTestRunner(stream=sys.stdout) unittest.main(verbosity=2, exit=False, testRunner=runner) diff --git a/tests/test_client_async_await.py b/tests/test_client_async_await.py index 05da8bd6..71ba7a45 100644 --- a/tests/test_client_async_await.py +++ b/tests/test_client_async_await.py @@ -27,7 +27,7 @@ async def subscription_handler(msg): self.assertIsNotNone(sub) for i in range(0, 5): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") # Wait a bit for messages to be received. await asyncio.sleep(1) @@ -60,34 +60,34 @@ async def handler_foo(msg): async def handler_bar(msg): msgs.append(msg) if msg.reply != "": - await nc.publish(msg.reply, b'') + await nc.publish(msg.reply, b"") await nc.subscribe("bar", cb=handler_bar) - await nc.publish("foo", b'1') - await nc.publish("foo", b'2') - await nc.publish("foo", b'3') + await nc.publish("foo", b"1") + await nc.publish("foo", b"2") + await nc.publish("foo", b"3") # Will be processed before the others since no head of line # blocking among the subscriptions. - await nc.publish("bar", b'4') + await nc.publish("bar", b"4") - response = await nc.request("foo", b'hello1', timeout=1) - self.assertEqual(response.data, b'hello1hello1') + response = await nc.request("foo", b"hello1", timeout=1) + self.assertEqual(response.data, b"hello1hello1") with self.assertRaises(TimeoutError): - await nc.request("foo", b'hello2', timeout=0.1) - - await nc.publish("bar", b'5') - response = await nc.request("foo", b'hello2', timeout=1) - self.assertEqual(response.data, b'hello2hello2') - - self.assertEqual(msgs[0].data, b'1') - self.assertEqual(msgs[1].data, b'4') - self.assertEqual(msgs[2].data, b'2') - self.assertEqual(msgs[3].data, b'3') - self.assertEqual(msgs[4].data, b'hello1') - self.assertEqual(msgs[5].data, b'hello2') + await nc.request("foo", b"hello2", timeout=0.1) + + await nc.publish("bar", b"5") + response = await nc.request("foo", b"hello2", timeout=1) + self.assertEqual(response.data, b"hello2hello2") + + self.assertEqual(msgs[0].data, b"1") + self.assertEqual(msgs[1].data, b"4") + self.assertEqual(msgs[2].data, b"2") + self.assertEqual(msgs[3].data, b"3") + self.assertEqual(msgs[4].data, b"hello1") + self.assertEqual(msgs[5].data, b"hello2") self.assertEqual(len(errors), 0) await nc.close() @@ -120,17 +120,17 @@ async def handler_bar(msg): await nc.subscribe("bar", cb=handler_bar) for i in range(10): - await nc.publish("foo", f'{i}'.encode()) + await nc.publish("foo", f"{i}".encode()) # Will be processed before the others since no head of line # blocking among the subscriptions. - await nc.publish("bar", b'14') - response = await nc.request("bar", b'hi1', 2) - self.assertEqual(response.data, b'hi1hi1hi1') + await nc.publish("bar", b"14") + response = await nc.request("bar", b"hi1", 2) + self.assertEqual(response.data, b"hi1hi1hi1") self.assertEqual(len(msgs), 2) - self.assertEqual(msgs[0].data, b'14') - self.assertEqual(msgs[1].data, b'hi1') + self.assertEqual(msgs[0].data, b"14") + self.assertEqual(msgs[1].data, b"hi1") # Consumed messages but the rest were slow consumers. self.assertTrue(4 <= len(errors) <= 5) @@ -174,13 +174,13 @@ async def handler_bar(msg): # Will be processed before the others since no head of line # blocking among the subscriptions. - await nc.publish("bar", b'14') + await nc.publish("bar", b"14") - response = await nc.request("bar", b'hi1', 2) - self.assertEqual(response.data, b'hi1hi1hi1') + response = await nc.request("bar", b"hi1", 2) + self.assertEqual(response.data, b"hi1hi1hi1") self.assertEqual(len(msgs), 2) - self.assertEqual(msgs[0].data, b'14') - self.assertEqual(msgs[1].data, b'hi1') + self.assertEqual(msgs[0].data, b"14") + self.assertEqual(msgs[1].data, b"hi1") # Consumed a few messages but the rest were slow consumers. self.assertTrue(7 <= len(errors) <= 8) @@ -191,6 +191,6 @@ async def handler_bar(msg): self.assertEqual(errors[0].reply, "") # Try again a few seconds later and it should have recovered await asyncio.sleep(3) - response = await nc.request("foo", b'B', 1) - self.assertEqual(response.data, b'BB') + response = await nc.request("foo", b"B", 1) + self.assertEqual(response.data, b"BB") await nc.close() diff --git a/tests/test_client_nkeys.py b/tests/test_client_nkeys.py index 6161bd0f..6fa785d5 100644 --- a/tests/test_client_nkeys.py +++ b/tests/test_client_nkeys.py @@ -6,6 +6,7 @@ try: import nkeys + nkeys_installed = True except ModuleNotFoundError: nkeys_installed = False @@ -30,15 +31,15 @@ async def test_nkeys_connect(self): config_file = get_config_file("nkeys/foo-user.nk") seed = None - with open(config_file, 'rb') as f: + with open(config_file, "rb") as f: seed = bytearray(os.fstat(f.fileno()).st_size) f.readinto(seed) args_list = [ { - 'nkeys_seed': config_file + "nkeys_seed": config_file }, { - 'nkeys_seed_str': seed.decode() + "nkeys_seed_str": seed.decode() }, ] for nkeys_args in args_list: @@ -53,27 +54,29 @@ async def error_cb(e): nonlocal future future.set_result(True) - await nc.connect(["tls://127.0.0.1:4222"], - error_cb=error_cb, - connect_timeout=10, - allow_reconnect=False, - **nkeys_args) + await nc.connect( + ["tls://127.0.0.1:4222"], + error_cb=error_cb, + connect_timeout=10, + allow_reconnect=False, + **nkeys_args, + ) async def help_handler(msg): - await nc.publish(msg.reply, b'OK!') + await nc.publish(msg.reply, b"OK!") await nc.subscribe("help", cb=help_handler) await nc.flush() - msg = await nc.request("help", b'I need help') - self.assertEqual(msg.data, b'OK!') + msg = await nc.request("help", b"I need help") + self.assertEqual(msg.data, b"OK!") await nc.subscribe("bar", cb=help_handler) await nc.flush() await asyncio.wait_for(future, 1) - msg = await nc.request("help", b'I need help') - self.assertEqual(msg.data, b'OK!') + msg = await nc.request("help", b"I need help") + self.assertEqual(msg.data, b"OK!") await nc.close() @@ -99,12 +102,12 @@ async def error_cb(e): ) async def help_handler(msg): - await nc.publish(msg.reply, b'OK!') + await nc.publish(msg.reply, b"OK!") await nc.subscribe("help", cb=help_handler) await nc.flush() - msg = await nc.request("help", b'I need help') - self.assertEqual(msg.data, b'OK!') + msg = await nc.request("help", b"I need help") + self.assertEqual(msg.data, b"OK!") await nc.close() @async_test @@ -123,18 +126,18 @@ async def error_cb(e): connect_timeout=5, user_credentials=( get_config_file("nkeys/foo-user.jwt"), - get_config_file("nkeys/foo-user.nk") + get_config_file("nkeys/foo-user.nk"), ), allow_reconnect=False, ) async def help_handler(msg): - await nc.publish(msg.reply, b'OK!') + await nc.publish(msg.reply, b"OK!") await nc.subscribe("help", cb=help_handler) await nc.flush() - msg = await nc.request("help", b'I need help') - self.assertEqual(msg.data, b'OK!') + msg = await nc.request("help", b"I need help") + self.assertEqual(msg.data, b"OK!") await nc.close() @async_test @@ -197,10 +200,10 @@ async def error_cb(e): ) async def help_handler(msg): - await nc.publish(msg.reply, b'OK!') + await nc.publish(msg.reply, b"OK!") await nc.subscribe("help", cb=help_handler) await nc.flush() - msg = await nc.request("help", b'I need help') - self.assertEqual(msg.data, b'OK!') + msg = await nc.request("help", b"I need help") + self.assertEqual(msg.data, b"OK!") await nc.close() diff --git a/tests/test_client_v2.py b/tests/test_client_v2.py index b78765ae..fad31dce 100644 --- a/tests/test_client_v2.py +++ b/tests/test_client_v2.py @@ -23,9 +23,9 @@ async def test_simple_headers(self): sub = await nc.subscribe("foo") await nc.flush() await nc.publish( - "foo", b'hello world', headers={ - 'foo': 'bar', - 'hello': 'world-1' + "foo", b"hello world", headers={ + "foo": "bar", + "hello": "world-1" } ) @@ -33,8 +33,8 @@ async def test_simple_headers(self): self.assertTrue(msg.headers != None) self.assertEqual(len(msg.headers), 2) - self.assertEqual(msg.headers['foo'], 'bar') - self.assertEqual(msg.headers['hello'], 'world-1') + self.assertEqual(msg.headers["foo"], "bar") + self.assertEqual(msg.headers["hello"], "world-1") await nc.close() @@ -44,24 +44,24 @@ async def test_request_with_headers(self): async def service(msg): # Add another header - msg.headers['quux'] = 'quuz' - await msg.respond(b'OK!') + msg.headers["quux"] = "quuz" + await msg.respond(b"OK!") await nc.subscribe("foo", cb=service) await nc.flush() msg = await nc.request( - "foo", b'hello world', headers={ - 'foo': 'bar', - 'hello': 'world' + "foo", b"hello world", headers={ + "foo": "bar", + "hello": "world" } ) self.assertTrue(msg.headers != None) self.assertEqual(len(msg.headers), 3) - self.assertEqual(msg.headers['foo'], 'bar') - self.assertEqual(msg.headers['hello'], 'world') - self.assertEqual(msg.headers['quux'], 'quuz') - self.assertEqual(msg.data, b'OK!') + self.assertEqual(msg.headers["foo"], "bar") + self.assertEqual(msg.headers["hello"], "world") + self.assertEqual(msg.headers["quux"], "quuz") + self.assertEqual(msg.data, b"OK!") await nc.close() @@ -71,38 +71,39 @@ async def test_empty_headers(self): sub = await nc.subscribe("foo") await nc.flush() - await nc.publish("foo", b'hello world', headers={'': ''}) + await nc.publish("foo", b"hello world", headers={"": ""}) msg = await sub.next_msg() self.assertTrue(msg.headers == None) # Empty long key - await nc.publish("foo", b'hello world', headers={' ': ''}) + await nc.publish("foo", b"hello world", headers={" ": ""}) msg = await sub.next_msg() self.assertTrue(msg.headers == None) # Empty long key await nc.publish( - "foo", b'hello world', headers={'': ' '} + "foo", b"hello world", headers={"": " "} ) msg = await sub.next_msg() self.assertTrue(msg.headers == None) hdrs = { - 'timestamp': '2022-06-15T19:08:14.639020', - 'type': 'rpc', - 'command': 'publish_state', - 'trace_id': '', - 'span_id': '' + "timestamp": "2022-06-15T19:08:14.639020", + "type": "rpc", + "command": "publish_state", + "trace_id": "", + "span_id": "", } - await nc.publish("foo", b'Hello from Python!', headers=hdrs) + await nc.publish("foo", b"Hello from Python!", headers=hdrs) msg = await sub.next_msg() self.assertEqual(msg.headers, hdrs) await nc.close() -if __name__ == '__main__': +if __name__ == "__main__": import sys + runner = unittest.TextTestRunner(stream=sys.stdout) unittest.main(verbosity=2, exit=False, testRunner=runner) diff --git a/tests/test_client_websocket.py b/tests/test_client_websocket.py index 097ce0f7..132b0db6 100644 --- a/tests/test_client_websocket.py +++ b/tests/test_client_websocket.py @@ -16,6 +16,7 @@ try: import aiohttp + aiohttp_installed = True except ModuleNotFoundError: aiohttp_installed = False @@ -33,9 +34,9 @@ async def test_simple_headers(self): sub = await nc.subscribe("foo") await nc.flush() await nc.publish( - "foo", b'hello world', headers={ - 'foo': 'bar', - 'hello': 'world-1' + "foo", b"hello world", headers={ + "foo": "bar", + "hello": "world-1" } ) @@ -43,8 +44,8 @@ async def test_simple_headers(self): self.assertTrue(msg.headers != None) self.assertEqual(len(msg.headers), 2) - self.assertEqual(msg.headers['foo'], 'bar') - self.assertEqual(msg.headers['hello'], 'world-1') + self.assertEqual(msg.headers["foo"], "bar") + self.assertEqual(msg.headers["hello"], "world-1") await nc.close() @@ -57,24 +58,24 @@ async def test_request_with_headers(self): async def service(msg): # Add another header - msg.headers['quux'] = 'quuz' - await msg.respond(b'OK!') + msg.headers["quux"] = "quuz" + await msg.respond(b"OK!") await nc.subscribe("foo", cb=service) await nc.flush() msg = await nc.request( - "foo", b'hello world', headers={ - 'foo': 'bar', - 'hello': 'world' + "foo", b"hello world", headers={ + "foo": "bar", + "hello": "world" } ) self.assertTrue(msg.headers != None) self.assertEqual(len(msg.headers), 3) - self.assertEqual(msg.headers['foo'], 'bar') - self.assertEqual(msg.headers['hello'], 'world') - self.assertEqual(msg.headers['quux'], 'quuz') - self.assertEqual(msg.data, b'OK!') + self.assertEqual(msg.headers["foo"], "bar") + self.assertEqual(msg.headers["hello"], "world") + self.assertEqual(msg.headers["quux"], "quuz") + self.assertEqual(msg.data, b"OK!") await nc.close() @@ -87,31 +88,31 @@ async def test_empty_headers(self): sub = await nc.subscribe("foo") await nc.flush() - await nc.publish("foo", b'hello world', headers={'': ''}) + await nc.publish("foo", b"hello world", headers={"": ""}) msg = await sub.next_msg() self.assertTrue(msg.headers == None) # Empty long key - await nc.publish("foo", b'hello world', headers={' ': ''}) + await nc.publish("foo", b"hello world", headers={" ": ""}) msg = await sub.next_msg() self.assertTrue(msg.headers == None) # Empty long key await nc.publish( - "foo", b'hello world', headers={'': ' '} + "foo", b"hello world", headers={"": " "} ) msg = await sub.next_msg() self.assertTrue(msg.headers == None) hdrs = { - 'timestamp': '2022-06-15T19:08:14.639020', - 'type': 'rpc', - 'command': 'publish_state', - 'trace_id': '', - 'span_id': '' + "timestamp": "2022-06-15T19:08:14.639020", + "type": "rpc", + "command": "publish_state", + "trace_id": "", + "span_id": "", } - await nc.publish("foo", b'Hello from Python!', headers=hdrs) + await nc.publish("foo", b"Hello from Python!", headers=hdrs) msg = await sub.next_msg() self.assertEqual(msg.headers, hdrs) @@ -136,16 +137,16 @@ async def reconnected_cb(): sub = await nc.subscribe("foo") async def bar_cb(msg): - await msg.respond(b'OK!') + await msg.respond(b"OK!") rsub = await nc.subscribe("bar", cb=bar_cb) - await nc.publish("foo", b'First') + await nc.publish("foo", b"First") await nc.flush() msg = await sub.next_msg() - self.assertEqual(msg.data, b'First') + self.assertEqual(msg.data, b"First") - rmsg = await nc.request("bar", b'hi') - self.assertEqual(rmsg.data, b'OK!') + rmsg = await nc.request("bar", b"hi") + self.assertEqual(rmsg.data, b"OK!") # Restart the server and wait for reconnect. await asyncio.get_running_loop().run_in_executor( @@ -158,12 +159,12 @@ async def bar_cb(msg): await asyncio.wait_for(reconnected, 2) # Get another message. - await nc.publish("foo", b'Second') + await nc.publish("foo", b"Second") await nc.flush() msg = await sub.next_msg() - self.assertEqual(msg.data, b'Second') - rmsg = await nc.request("bar", b'hi') - self.assertEqual(rmsg.data, b'OK!') + self.assertEqual(msg.data, b"Second") + rmsg = await nc.request("bar", b"hi") + self.assertEqual(rmsg.data, b"OK!") await nc.close() @@ -187,15 +188,15 @@ async def reconnected_cb(): sub = await nc.subscribe("foo") async def bar_cb(msg): - await msg.respond(b'OK!') + await msg.respond(b"OK!") rsub = await nc.subscribe("bar", cb=bar_cb) - await nc.publish("foo", b'First') + await nc.publish("foo", b"First") await nc.flush() msg = await sub.next_msg() - self.assertEqual(msg.data, b'First') - rmsg = await nc.request("bar", b'hi') - self.assertEqual(rmsg.data, b'OK!') + self.assertEqual(msg.data, b"First") + rmsg = await nc.request("bar", b"hi") + self.assertEqual(rmsg.data, b"OK!") # Restart the server and wait for reconnect. await asyncio.get_running_loop().run_in_executor( @@ -218,13 +219,13 @@ async def test_pub_sub(self): sub = await nc.subscribe("foo") await nc.flush() - await nc.publish("foo", b'hello world', headers={'foo': 'bar'}) + await nc.publish("foo", b"hello world", headers={"foo": "bar"}) msg = await sub.next_msg() self.assertTrue(msg.headers != None) self.assertEqual(len(msg.headers), 1) - self.assertEqual(msg.headers['foo'], 'bar') + self.assertEqual(msg.headers["foo"], "bar") await nc.close() @@ -248,16 +249,16 @@ async def reconnected_cb(): sub = await nc.subscribe("foo") async def bar_cb(msg): - await msg.respond(b'OK!') + await msg.respond(b"OK!") rsub = await nc.subscribe("bar", cb=bar_cb) - await nc.publish("foo", b'First') + await nc.publish("foo", b"First") await nc.flush() msg = await sub.next_msg() - self.assertEqual(msg.data, b'First') + self.assertEqual(msg.data, b"First") - rmsg = await nc.request("bar", b'hi') - self.assertEqual(rmsg.data, b'OK!') + rmsg = await nc.request("bar", b"hi") + self.assertEqual(rmsg.data, b"OK!") # Restart the server and wait for reconnect. await asyncio.get_running_loop().run_in_executor( @@ -270,12 +271,12 @@ async def bar_cb(msg): await asyncio.wait_for(reconnected, 2) # Get another message. - await nc.publish("foo", b'Second') + await nc.publish("foo", b"Second") await nc.flush() msg = await sub.next_msg() - self.assertEqual(msg.data, b'Second') - rmsg = await nc.request("bar", b'hi') - self.assertEqual(rmsg.data, b'OK!') + self.assertEqual(msg.data, b"Second") + rmsg = await nc.request("bar", b"hi") + self.assertEqual(rmsg.data, b"OK!") await nc.close() @@ -300,15 +301,15 @@ async def reconnected_cb(): sub = await nc.subscribe("foo") async def bar_cb(msg): - await msg.respond(b'OK!') + await msg.respond(b"OK!") rsub = await nc.subscribe("bar", cb=bar_cb) - await nc.publish("foo", b'First') + await nc.publish("foo", b"First") await nc.flush() msg = await sub.next_msg() - self.assertEqual(msg.data, b'First') - rmsg = await nc.request("bar", b'hi') - self.assertEqual(rmsg.data, b'OK!') + self.assertEqual(msg.data, b"First") + rmsg = await nc.request("bar", b"hi") + self.assertEqual(rmsg.data, b"OK!") # Restart the server and wait for reconnect. await asyncio.get_running_loop().run_in_executor( @@ -320,7 +321,8 @@ async def bar_cb(msg): await nc.close() -if __name__ == '__main__': +if __name__ == "__main__": import sys + runner = unittest.TextTestRunner(stream=sys.stdout) unittest.main(verbosity=2, exit=False, testRunner=runner) diff --git a/tests/test_compatibility.py b/tests/test_compatibility.py index ef6b1229..dfe715e3 100644 --- a/tests/test_compatibility.py +++ b/tests/test_compatibility.py @@ -33,6 +33,7 @@ @skipIf("NATS_URL" not in os.environ, "NATS_URL not set in environment") class CompatibilityTest(TestCase): + def setUp(self): self.loop = asyncio.new_event_loop() @@ -42,12 +43,15 @@ def tearDown(self): async def validate_test_result(self, sub: Subscription): try: msg = await sub.next_msg(timeout=5) - self.assertNotIn("fail", msg.subject, f"Test step failed: {msg.subject}") + self.assertNotIn( + "fail", msg.subject, f"Test step failed: {msg.subject}" + ) except asyncio.TimeoutError: self.fail("Timeout waiting for test result") @async_long_test async def test_service_compatibility(self): + @dataclass class TestGroupConfig: name: str @@ -55,7 +59,9 @@ class TestGroupConfig: @classmethod def from_dict(cls, data: Dict[str, Any]) -> TestGroupConfig: - return cls(name=data["name"], queue_group=data.get("queue_group")) + return cls( + name=data["name"], queue_group=data.get("queue_group") + ) @dataclass class TestEndpointConfig: diff --git a/tests/test_js.py b/tests/test_js.py index f6d632be..720c9819 100644 --- a/tests/test_js.py +++ b/tests/test_js.py @@ -39,20 +39,20 @@ async def test_publish(self): js = nc.jetstream() with pytest.raises(NoStreamResponseError): - await js.publish("foo", b'bar') + await js.publish("foo", b"bar") await js.add_stream(name="QUUX", subjects=["quux"]) - ack = await js.publish("quux", b'bar:1', stream="QUUX") + ack = await js.publish("quux", b"bar:1", stream="QUUX") assert ack.stream == "QUUX" assert ack.seq == 1 - ack = await js.publish("quux", b'bar:2') + ack = await js.publish("quux", b"bar:2") assert ack.stream == "QUUX" assert ack.seq == 2 with pytest.raises(BadRequestError) as err: - await js.publish("quux", b'bar', stream="BAR") + await js.publish("quux", b"bar", stream="BAR") assert err.value.err_code == 10060 await nc.close() @@ -64,20 +64,20 @@ async def test_publish_verbose(self): js = nc.jetstream() with pytest.raises(NoStreamResponseError): - await js.publish("foo", b'bar') + await js.publish("foo", b"bar") await js.add_stream(name="QUUX", subjects=["quux"]) - ack = await js.publish("quux", b'bar:1', stream="QUUX") + ack = await js.publish("quux", b"bar:1", stream="QUUX") assert ack.stream == "QUUX" assert ack.seq == 1 - ack = await js.publish("quux", b'bar:2') + ack = await js.publish("quux", b"bar:2") assert ack.stream == "QUUX" assert ack.seq == 2 with pytest.raises(BadRequestError) as err: - await js.publish("quux", b'bar', stream="BAR") + await js.publish("quux", b"bar", stream="BAR") assert err.value.err_code == 10060 await nc.close() @@ -92,7 +92,7 @@ async def test_publish_async(self): await js.publish_async_completed() self.assertEqual(js.publish_async_pending(), 0) - future = await js.publish_async("foo", b'bar') + future = await js.publish_async("foo", b"bar") self.assertEqual(js.publish_async_pending(), 1) with pytest.raises(NoStreamResponseError): await future @@ -103,7 +103,7 @@ async def test_publish_async(self): await js.add_stream(name="QUUX", subjects=["quux"]) futures = [ - await js.publish_async("quux", b'bar:1') for i in range(0, 100) + await js.publish_async("quux", b"bar:1") for i in range(0, 100) ] await js.publish_async_completed() @@ -116,18 +116,25 @@ async def test_publish_async(self): self.assertEqual(result.seq, seq) with pytest.raises(TooManyStalledMsgsError): - publishes = [js.publish_async("quux", b'bar:1', wait_stall=0.0) for i in range(0, 100)] + publishes = [ + js.publish_async("quux", b"bar:1", wait_stall=0.0) + for i in range(0, 100) + ] futures = await asyncio.gather(*publishes) results = await asyncio.gather(*futures) self.assertEqual(len(results), 100) - publishes = [js.publish_async("quux", b'bar:1', wait_stall=1.0) for i in range(0, 1000)] + publishes = [ + js.publish_async("quux", b"bar:1", wait_stall=1.0) + for i in range(0, 1000) + ] futures = await asyncio.gather(*publishes) results = await asyncio.gather(*futures) self.assertEqual(len(results), 1000) await nc.close() + class PullSubscribeTest(SingleJetStreamServerTestCase): @async_test @@ -139,19 +146,19 @@ async def test_auto_create_consumer(self): await js.add_stream(name="TEST2", subjects=["a1", "a2", "a3", "a4"]) for i in range(1, 10): - await js.publish("a1", f'a1:{i}'.encode()) + await js.publish("a1", f"a1:{i}".encode()) # Should use a2 as the filter subject and receive a subset. sub = await js.pull_subscribe("a2", "auto") - await js.publish("a2", b'one') + await js.publish("a2", b"one") for i in range(10, 20): - await js.publish("a3", f'a3:{i}'.encode()) + await js.publish("a3", f"a3:{i}".encode()) msgs = await sub.fetch(1) msg = msgs[0] await msg.ack() - assert msg.data == b'one' + assert msg.data == b"one" # Getting another message should timeout for the a2 subject. with pytest.raises(TimeoutError): @@ -164,7 +171,7 @@ async def test_auto_create_consumer(self): msgs = await sub.fetch(1) msg = msgs[0] await msg.ack() - assert msg.data == b'one' + assert msg.data == b"one" info = await js.consumer_info("TEST2", "auto2") assert info.config.max_waiting == 10 @@ -173,14 +180,14 @@ async def test_auto_create_consumer(self): msgs = await sub.fetch(1) msg = msgs[0] await msg.ack() - assert msg.data == b'a3:10' + assert msg.data == b"a3:10" # Getting all messages from stream. sub = await js.pull_subscribe("", "all") msgs = await sub.fetch(1) msg = msgs[0] await msg.ack() - assert msg.data == b'a1:1' + assert msg.data == b"a1:1" for _ in range(2, 10): msgs = await sub.fetch(1) @@ -191,13 +198,13 @@ async def test_auto_create_consumer(self): msgs = await sub.fetch(1) msg = msgs[0] await msg.ack() - assert msg.data == b'one' + assert msg.data == b"one" # subject a3 msgs = await sub.fetch(1) msg = msgs[0] await msg.ack() - assert msg.data == b'a3:10' + assert msg.data == b"a3:10" await nc.close() await asyncio.sleep(1) @@ -211,7 +218,7 @@ async def test_fetch_one(self): await js.add_stream(name="TEST1", subjects=["foo.1", "bar"]) - ack = await js.publish("foo.1", f'Hello from NATS!'.encode()) + ack = await js.publish("foo.1", f"Hello from NATS!".encode()) assert ack.stream == "TEST1" assert ack.seq == 1 @@ -233,7 +240,7 @@ async def test_fetch_one(self): for i in range(0, 10): await js.publish( - "foo.1", f"i:{i}".encode(), headers={'hello': 'world'} + "foo.1", f"i:{i}".encode(), headers={"hello": "world"} ) # nak @@ -241,7 +248,7 @@ async def test_fetch_one(self): msg = msgs[0] info = await js.consumer_info("TEST1", "dur", timeout=1) - assert msg.header == {'hello': 'world'} + assert msg.header == {"hello": "world"} await msg.nak() @@ -276,7 +283,7 @@ async def test_fetch_one_wait_forever(self): await js.add_stream(name="TEST111", subjects=["foo.111"]) - ack = await js.publish("foo.111", f'Hello from NATS!'.encode()) + ack = await js.publish("foo.111", f"Hello from NATS!".encode()) assert ack.stream == "TEST111" assert ack.seq == 1 @@ -305,7 +312,7 @@ async def f(): await asyncio.sleep(1) assert received is False await asyncio.sleep(1) - await js.publish("foo.111", b'Hello from NATS!') + await js.publish("foo.111", b"Hello from NATS!") await task assert received @@ -324,9 +331,9 @@ async def test_add_pull_consumer_via_jsm(self): max_deliver=20, max_waiting=512, max_ack_pending=1024, - filter_subject="events.a" + filter_subject="events.a", ) - await js.publish("events.a", b'hello world') + await js.publish("events.a", b"hello world") sub = await js.pull_subscribe_bind("a", stream="events") msgs = await sub.fetch(1) for msg in msgs: @@ -341,13 +348,13 @@ async def test_fetch_n(self): server_version = nc.connected_server_version if server_version.major == 2 and server_version.minor < 9: - pytest.skip('needs to run at least on v2.9.0') + pytest.skip("needs to run at least on v2.9.0") js = nc.jetstream() await js.add_stream(name="TESTN", subjects=["a", "b", "c"]) for i in range(0, 10): - await js.publish("a", f'i:{i}'.encode()) + await js.publish("a", f"i:{i}".encode()) sub = await js.pull_subscribe( "a", @@ -366,7 +373,7 @@ async def test_fetch_n(self): i = 0 for msg in msgs: - assert msg.data == f'i:{i}'.encode() + assert msg.data == f"i:{i}".encode() await msg.ack() i += 1 info = await sub.consumer_info() @@ -381,7 +388,7 @@ async def test_fetch_n(self): i = 5 for msg in msgs: - assert msg.data == f'i:{i}'.encode() + assert msg.data == f"i:{i}".encode() await msg.ack_sync() i += 1 @@ -399,14 +406,14 @@ async def test_fetch_n(self): # ---------- # 0 pending # 1 ack pending - await js.publish("a", b'i:11') + await js.publish("a", b"i:11") msgs = await sub.fetch(2, timeout=0.5) # Leave this message unacked. msg = msgs[0] unacked_msg = msg - assert msg.data == b'i:11' + assert msg.data == b"i:11" info = await sub.consumer_info() assert info.num_waiting < 2 assert info.num_pending == 0 @@ -419,12 +426,12 @@ async def test_fetch_n(self): # 1 extra from before but request has expired so does not count. # +1 ack pending since previous message not acked. # +1 pending to be consumed. - await js.publish("a", b'i:12') + await js.publish("a", b"i:12") # Inspect the internal buffer which should be a 408 at this point. try: msg = await sub._sub.next_msg(timeout=0.5) - assert msg.headers['Status'] == '408' + assert msg.headers["Status"] == "408" except (nats.errors.TimeoutError, TypeError): pass @@ -436,8 +443,8 @@ async def test_fetch_n(self): # Start background task that gathers messages. fut = asyncio.create_task(sub.fetch(3, timeout=2)) await asyncio.sleep(0.5) - await js.publish("a", b'i:13') - await js.publish("a", b'i:14') + await js.publish("a", b"i:13") + await js.publish("a", b"i:14") # It should receive the first one that is available already only. msgs = await fut @@ -467,10 +474,10 @@ async def test_fetch_n(self): # No messages at this point. msgs = await sub.fetch(1, timeout=0.5) - self.assertEqual(msgs[0].data, b'i:13') + self.assertEqual(msgs[0].data, b"i:13") msgs = await sub.fetch(1, timeout=0.5) - self.assertEqual(msgs[0].data, b'i:14') + self.assertEqual(msgs[0].data, b"i:14") with pytest.raises(TimeoutError): await sub.fetch(1, timeout=0.5) @@ -517,11 +524,11 @@ async def test_fetch_max_waiting_fetch_one(self): # assert info.num_waiting == 0 for i in range(0, 10): - await js.publish("max", b'foo') + await js.publish("max", b"foo") async def pub(): while True: - await js.publish("max", b'foo') + await js.publish("max", b"foo") await asyncio.sleep(0) producer = asyncio.create_task(pub()) @@ -597,7 +604,7 @@ async def test_fetch_concurrent(self): async def go_publish(): i = 0 while True: - payload = f'{i}'.encode() + payload = f"{i}".encode() await js.publish("a", payload) i += 1 await asyncio.sleep(0.01) @@ -653,7 +660,7 @@ async def test_fetch_headers(self): js = nc.jetstream() await js.add_stream(name="test-nats", subjects=["test.nats.1"]) - await js.publish("test.nats.1", b'first_msg', headers={'': ''}) + await js.publish("test.nats.1", b"first_msg", headers={"": ""}) sub = await js.pull_subscribe("test.nats.1", "durable") msgs = await sub.fetch(1) assert msgs[0].header == None @@ -664,11 +671,11 @@ async def test_fetch_headers(self): # NOTE: Headers with empty spaces are ignored. await js.publish( "test.nats.1", - b'second_msg', + b"second_msg", headers={ - ' AAA AAA AAA ': ' ', - ' B B B ': ' ' - } + " AAA AAA AAA ": " ", + " B B B ": " ", + }, ) msgs = await sub.fetch(1) assert msgs[0].header == None @@ -679,20 +686,20 @@ async def test_fetch_headers(self): # NOTE: As soon as there is a message with empty spaces are ignored. await js.publish( "test.nats.1", - b'third_msg', + b"third_msg", headers={ - ' AAA-AAA-AAA ': ' a ', - ' AAA-BBB-AAA ': ' ', - ' B B B ': ' a ' - } + " AAA-AAA-AAA ": " a ", + " AAA-BBB-AAA ": " ", + " B B B ": " a ", + }, ) msgs = await sub.fetch(1) - assert msgs[0].header['AAA-AAA-AAA'] == 'a' - assert msgs[0].header['AAA-BBB-AAA'] == '' + assert msgs[0].header["AAA-AAA-AAA"] == "a" + assert msgs[0].header["AAA-BBB-AAA"] == "" msg = await js.get_msg("test-nats", 3) - assert msg.header['AAA-AAA-AAA'] == 'a' - assert msg.header['AAA-BBB-AAA'] == '' + assert msg.header["AAA-AAA-AAA"] == "a" + assert msg.header["AAA-BBB-AAA"] == "" if not parse_email: await nc.close() @@ -700,15 +707,15 @@ async def test_fetch_headers(self): await js.publish( "test.nats.1", - b'third_msg', + b"third_msg", headers={ - ' AAA AAA AAA ': ' a ', - ' AAA-BBB-AAA ': ' b ', - ' B B B ': ' a ' - } + " AAA AAA AAA ": " a ", + " AAA-BBB-AAA ": " b ", + " B B B ": " a ", + }, ) msgs = await sub.fetch(1) - assert msgs[0].header == {'AAA-BBB-AAA': 'b'} + assert msgs[0].header == {"AAA-BBB-AAA": "b"} msg = await js.get_msg("test-nats", 4) assert msg.header == None @@ -730,7 +737,7 @@ async def error_cb(err): await js.add_stream(name="TEST2", subjects=["a1", "a2", "a3", "a4"]) for i in range(1, 10): - await js.publish("a1", f'a1:{i}'.encode()) + await js.publish("a1", f"a1:{i}".encode()) # Shorter msgs limit, and disable bytes limit to not get slow consumers. sub = await js.pull_subscribe( @@ -740,7 +747,7 @@ async def error_cb(err): pending_bytes_limit=-1, ) for i in range(0, 100): - await js.publish("a3", b'test') + await js.publish("a3", b"test") # Internal buffer will drop some of the messages due to reaching limit. msgs = await sub.fetch(100, timeout=1) @@ -768,7 +775,7 @@ async def error_cb(err): assert sub.pending_bytes == 0 # Consumer has a single message pending but none in buffer. - await js.publish("a3", b'last message') + await js.publish("a3", b"last message") info = await sub.consumer_info() assert info.num_pending == 1 assert sub.pending_msgs == 0 @@ -792,6 +799,7 @@ async def test_fetch_cancelled_errors_raised(self): pytest.skip("skip since cannot use AsyncMock") import tracemalloc + tracemalloc.start() nc = NATS() @@ -807,7 +815,7 @@ async def test_fetch_cancelled_errors_raised(self): max_deliver=20, max_waiting=512, max_ack_pending=1024, - filter_subject="test.a" + filter_subject="test.a", ) sub = await js.pull_subscribe("test.a", "test", stream="test") @@ -818,7 +826,8 @@ async def test_fetch_cancelled_errors_raised(self): with unittest.mock.patch( "asyncio.wait_for", unittest.mock.AsyncMock(side_effect=asyncio.CancelledError - )): + ), + ): await sub.fetch(batch=1, timeout=0.1) await nc.close() @@ -833,7 +842,7 @@ async def test_ephemeral_pull_subscribe(self): await js.add_stream(name=subject, subjects=["a1", "a2", "a3", "a4"]) for i in range(1, 10): - await js.publish("a1", f'a1:{i}'.encode()) + await js.publish("a1", f"a1:{i}".encode()) with pytest.raises(NotFoundError): await js.pull_subscribe("a0") @@ -856,10 +865,10 @@ async def test_consumer_with_multiple_filters(self): # Create stream. await jsm.add_stream(name="ctests", subjects=["a", "b", "c.>"]) - await js.publish("a", b'A') - await js.publish("b", b'B') - await js.publish("c.d", b'CD') - await js.publish("c.d.e", b'CDE') + await js.publish("a", b"A") + await js.publish("b", b"B") + await js.publish("c.d", b"CD") + await js.publish("c.d.e", b"CDE") # Create ephemeral pull consumer with a name. stream_name = "ctests" @@ -875,17 +884,17 @@ async def test_consumer_with_multiple_filters(self): sub = await js.pull_subscribe_bind(consumer_name, stream_name) msgs = await sub.fetch(1) - assert msgs[0].data == b'A' + assert msgs[0].data == b"A" ok = await msgs[0].ack_sync() assert ok msgs = await sub.fetch(1) - assert msgs[0].data == b'B' + assert msgs[0].data == b"B" ok = await msgs[0].ack_sync() assert ok msgs = await sub.fetch(1) - assert msgs[0].data == b'CDE' + assert msgs[0].data == b"CDE" ok = await msgs[0].ack_sync() assert ok @@ -908,7 +917,7 @@ async def test_add_consumer_with_backoff(self): filter_subject="events.>", ) for i in range(0, 3): - await js.publish("events.%d" % i, b'i:%d' % i) + await js.publish("events.%d" % i, b"i:%d" % i) sub = await js.pull_subscribe_bind("a", stream="events") events_prefix = "$JS.EVENT.ADVISORY.CONSUMER.MAX_DELIVERIES." @@ -927,7 +936,7 @@ async def cb(msg): try: msgs = await sub.fetch(1, timeout=5) for msg in msgs: - if msg.subject == 'events.0': + if msg.subject == "events.0": received.append((time.monotonic(), msg)) except TimeoutError as err: # There should be no timeout as redeliveries should happen faster. @@ -977,7 +986,7 @@ async def test_fetch_heartbeats(self): await sub.fetch(1, timeout=1, heartbeat=0.1) for i in range(0, 15): - await js.publish("events.%d" % i, b'i:%d' % i) + await js.publish("events.%d" % i, b"i:%d" % i) # Fetch(n) msgs = await sub.fetch(5, timeout=5, heartbeat=0.1) @@ -1017,7 +1026,7 @@ async def test_fetch_heartbeats(self): with pytest.raises(nats.js.errors.APIError) as err: await sub.fetch(1, timeout=1, heartbeat=0.5) - assert err.value.description == 'Bad Request - heartbeat value too large' + assert err.value.description == "Bad Request - heartbeat value too large" # Example of catching fetch timeout instead first. got_fetch_timeout = False @@ -1090,7 +1099,7 @@ async def test_stream_management(self): # Send messages producer = nc.jetstream() - ack = await producer.publish('world', b'Hello world!') + ack = await producer.publish("world", b"Hello world!") assert ack.stream == "hello" assert ack.seq == 1 @@ -1102,7 +1111,7 @@ async def test_stream_management(self): stream_config.subjects.append("extra") updated_stream = await jsm.update_stream(stream_config) assert updated_stream.config.subjects == [ - 'hello', 'world', 'hello.>', 'extra' + "hello", "world", "hello.>", "extra" ] # Purge Stream @@ -1186,24 +1195,24 @@ async def test_jsm_get_delete_msg(self): # Create stream stream = await jsm.add_stream(name="foo", subjects=["foo.>"]) - await js.publish("foo.a.1", b'Hello', headers={'foo': 'bar'}) - await js.publish("foo.b.1", b'World') - await js.publish("foo.c.1", b'!!!') + await js.publish("foo.a.1", b"Hello", headers={"foo": "bar"}) + await js.publish("foo.b.1", b"World") + await js.publish("foo.c.1", b"!!!") # GetMsg msg = await jsm.get_msg("foo", 2) - assert msg.subject == 'foo.b.1' - assert msg.data == b'World' + assert msg.subject == "foo.b.1" + assert msg.data == b"World" msg = await jsm.get_msg("foo", 3) - assert msg.subject == 'foo.c.1' - assert msg.data == b'!!!' + assert msg.subject == "foo.c.1" + assert msg.data == b"!!!" msg = await jsm.get_msg("foo", 1) - assert msg.subject == 'foo.a.1' - assert msg.data == b'Hello' + assert msg.subject == "foo.a.1" + assert msg.data == b"Hello" assert msg.headers["foo"] == "bar" - assert msg.hdrs == 'TkFUUy8xLjANCmZvbzogYmFyDQoNCg==' + assert msg.hdrs == "TkFUUy8xLjANCmZvbzogYmFyDQoNCg==" with pytest.raises(BadRequestError): await jsm.get_msg("foo", 0) @@ -1307,7 +1316,7 @@ async def test_number_of_consumer_replicas(self): js = nc.jetstream() await js.add_stream(name="TESTREPLICAS", subjects=["test.replicas"]) for i in range(0, 10): - await js.publish("test.replicas", f'{i}'.encode()) + await js.publish("test.replicas", f"{i}".encode()) # Create consumer config = nats.js.api.ConsumerConfig( @@ -1330,10 +1339,10 @@ async def test_consumer_with_name(self): # Create stream. await jsm.add_stream(name="ctests", subjects=["a", "b", "c.>"]) - await js.publish("a", b'hello world!') - await js.publish("b", b'hello world!!') - await js.publish("c.d", b'hello world!!!') - await js.publish("c.d.e", b'hello world!!!!') + await js.publish("a", b"hello world!") + await js.publish("b", b"hello world!!") + await js.publish("c.d", b"hello world!!!") + await js.publish("c.d.e", b"hello world!!!!") # Create ephemeral pull consumer with a name. stream_name = "ctests" @@ -1346,16 +1355,16 @@ async def test_consumer_with_name(self): assert cinfo.config.name == consumer_name msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.CREATE.ctests.ephemeral' + assert msg.subject == "$JS.API.CONSUMER.CREATE.ctests.ephemeral" sub = await js.pull_subscribe_bind(consumer_name, stream_name) msgs = await sub.fetch(1) - assert msgs[0].data == b'hello world!' + assert msgs[0].data == b"hello world!" ok = await msgs[0].ack_sync() assert ok msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.MSG.NEXT.ctests.ephemeral' + assert msg.subject == "$JS.API.CONSUMER.MSG.NEXT.ctests.ephemeral" # Create durable pull consumer with a name. consumer_name = "durable" @@ -1367,15 +1376,15 @@ async def test_consumer_with_name(self): ) assert cinfo.config.name == consumer_name msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.CREATE.ctests.durable' + assert msg.subject == "$JS.API.CONSUMER.CREATE.ctests.durable" sub = await js.pull_subscribe_bind(consumer_name, stream_name) msgs = await sub.fetch(1) - assert msgs[0].data == b'hello world!' + assert msgs[0].data == b"hello world!" ok = await msgs[0].ack_sync() assert ok msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.MSG.NEXT.ctests.durable' + assert msg.subject == "$JS.API.CONSUMER.MSG.NEXT.ctests.durable" # Create durable pull consumer with a name and a filter_subject consumer_name = "durable2" @@ -1388,15 +1397,15 @@ async def test_consumer_with_name(self): ) assert cinfo.config.name == consumer_name msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.CREATE.ctests.durable2.b' + assert msg.subject == "$JS.API.CONSUMER.CREATE.ctests.durable2.b" sub = await js.pull_subscribe_bind(consumer_name, stream_name) msgs = await sub.fetch(1) - assert msgs[0].data == b'hello world!!' + assert msgs[0].data == b"hello world!!" ok = await msgs[0].ack_sync() assert ok msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.MSG.NEXT.ctests.durable2' + assert msg.subject == "$JS.API.CONSUMER.MSG.NEXT.ctests.durable2" # Create durable pull consumer with a name and a filter_subject consumer_name = "durable3" @@ -1409,15 +1418,15 @@ async def test_consumer_with_name(self): ) assert cinfo.config.name == consumer_name msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.CREATE.ctests.durable3' + assert msg.subject == "$JS.API.CONSUMER.CREATE.ctests.durable3" sub = await js.pull_subscribe_bind(consumer_name, stream_name) msgs = await sub.fetch(1) - assert msgs[0].data == b'hello world!' + assert msgs[0].data == b"hello world!" ok = await msgs[0].ack_sync() assert ok msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.MSG.NEXT.ctests.durable3' + assert msg.subject == "$JS.API.CONSUMER.MSG.NEXT.ctests.durable3" # name and durable must match if both present. with pytest.raises(BadRequestError) as err: @@ -1428,7 +1437,10 @@ async def test_consumer_with_name(self): ack_policy="explicit", ) assert err.value.err_code == 10017 - assert err.value.description == 'consumer name in subject does not match durable name in request' + assert ( + err.value.description == + "consumer name in subject does not match durable name in request" + ) # Create ephemeral pull consumer with a name and inactive threshold. stream_name = "ctests" @@ -1443,7 +1455,7 @@ async def test_consumer_with_name(self): sub = await js.pull_subscribe_bind(consumer_name, stream_name) msgs = await sub.fetch(1) - assert msgs[0].data == b'hello world!' + assert msgs[0].data == b"hello world!" ok = await msgs[0].ack_sync() assert ok @@ -1463,16 +1475,16 @@ async def test_jsm_stream_info_options(self): stream = await jsm.add_stream(name="foo", subjects=["foo.>"]) for i in range(0, 5): - await js.publish("foo.%d" % i, b'A') + await js.publish("foo.%d" % i, b"A") si = await jsm.stream_info("foo", subjects_filter=">") assert si.state.messages == 5 assert si.state.subjects == { - 'foo.0': 1, - 'foo.1': 1, - 'foo.2': 1, - 'foo.3': 1, - 'foo.4': 1 + "foo.0": 1, + "foo.1": 1, + "foo.2": 1, + "foo.3": 1, + "foo.4": 1, } # When nothing matches streams subject will be empty. @@ -1524,7 +1536,7 @@ async def cb3(msg): subs.append(sub3) for i in range(100): - await js.publish("quux", f'Hello World {i}'.encode()) + await js.publish("quux", f"Hello World {i}".encode()) delivered = [a, b, c] for _ in range(5): @@ -1545,7 +1557,7 @@ async def cb3(msg): # Confirm that no more messages are received. for _ in range(200, 210): - await js.publish("quux", f'Hello World {i}'.encode()) + await js.publish("quux", f"Hello World {i}".encode()) with pytest.raises(BadSubscriptionError): await sub1.unsubscribe() @@ -1571,7 +1583,7 @@ async def cb2(msg): b.append(msg) for i in range(10): - await js.publish("pbound", f'Hello World {i}'.encode()) + await js.publish("pbound", f"Hello World {i}".encode()) # First subscriber will create. await js.subscribe("pbound", cb=cb1, durable="singleton") @@ -1602,7 +1614,7 @@ async def cb2(msg): # Create a sync subscriber now. sub2 = await js.subscribe("pbound", durable="two") msg = await sub2.next_msg() - assert msg.data == b'Hello World 0' + assert msg.data == b"Hello World 0" assert msg.metadata.sequence.stream == 1 assert msg.metadata.sequence.consumer == 1 assert sub2.pending_msgs == 9 @@ -1618,7 +1630,7 @@ async def test_ephemeral_subscribe(self): await js.add_stream(name=subject, subjects=[subject]) for i in range(10): - await js.publish(subject, f'Hello World {i}'.encode()) + await js.publish(subject, f"Hello World {i}".encode()) # First subscriber will create. sub1 = await js.subscribe(subject) @@ -1677,7 +1689,7 @@ async def test_subscribe_bind(self): config=consumer_info.config, ) for i in range(10): - await js.publish(subject_name, f'Hello World {i}'.encode()) + await js.publish(subject_name, f"Hello World {i}".encode()) msgs = [] for i in range(0, 10): @@ -1739,7 +1751,7 @@ async def cb3(msg): subs.append(sub3) for i in range(100): - await js.publish("quux", f'Hello World {i}'.encode()) + await js.publish("quux", f"Hello World {i}".encode()) # Only 3rd group should have ran into slow consumer errors. await asyncio.sleep(1) @@ -1761,11 +1773,18 @@ async def test_unsubscribe_coroutines(self): coroutines_before = len(asyncio.all_tasks()) a = [] + async def cb(msg): a.append(msg) # Create a PushSubscription - sub = await js.subscribe("quux", "wg", cb=cb, ordered_consumer=True, config=nats.js.api.ConsumerConfig(idle_heartbeat=5)) + sub = await js.subscribe( + "quux", + "wg", + cb=cb, + ordered_consumer=True, + config=nats.js.api.ConsumerConfig(idle_heartbeat=5), + ) # breakpoint() @@ -1788,6 +1807,7 @@ async def cb(msg): self.assertEqual(coroutines_before, coroutines_after_unsubscribe) self.assertNotEqual(coroutines_before, coroutines_after_subscribe) + class AckPolicyTest(SingleJetStreamServerTestCase): @async_test @@ -1849,7 +1869,7 @@ async def test_double_acking_pull_subscribe(self): js = nc.jetstream() await js.add_stream(name="TESTACKS", subjects=["test"]) for i in range(0, 10): - await js.publish("test", f'{i}'.encode()) + await js.publish("test", f"{i}".encode()) # Pull Subscriber psub = await js.pull_subscribe("test", "durable") @@ -1904,7 +1924,7 @@ async def error_handler(e): js = nc.jetstream() await js.add_stream(name="TESTACKS", subjects=["test"]) for i in range(0, 10): - await js.publish("test", f'{i}'.encode()) + await js.publish("test", f"{i}".encode()) async def ocb(msg): # Ack the first one only. @@ -1941,7 +1961,7 @@ async def test_nak_delay(self): js = nc.jetstream() await js.add_stream(name=stream, subjects=[subject]) - await js.publish(subject, b'message') + await js.publish(subject, b"message") psub = await js.pull_subscribe(subject, "durable") msgs = await psub.fetch(1) @@ -1969,106 +1989,111 @@ async def f(): assert task.done() assert received + class DiscardPolicyTest(SingleJetStreamServerTestCase): - @async_test - async def test_with_discard_new_and_discard_new_per_subject_set(self): - # Connect to NATS and create JetStream context - nc = await nats.connect() - js = nc.jetstream() - - stream_name = "FOO0" - config = nats.js.api.StreamConfig( - name=stream_name, - discard=nats.js.api.DiscardPolicy.NEW, - discard_new_per_subject=True, - max_msgs_per_subject=100 - ) - - await js.add_stream(config) - info = await js.stream_info(stream_name) - self.assertEqual(info.config.discard, nats.js.api.DiscardPolicy.NEW) - self.assertEqual(info.config.discard_new_per_subject, True) - - # Close the NATS connection after the test - await nc.close() - - @async_test - async def test_with_discard_new_and_discard_new_per_subject_not_set(self): - # Connect to NATS and create JetStream context - nc = await nats.connect() - js = nc.jetstream() - - stream_name = "FOO1" - config = nats.js.api.StreamConfig( - name=stream_name, - discard=nats.js.api.DiscardPolicy.NEW, - discard_new_per_subject=False, - max_msgs_per_subject=100 - ) - - await js.add_stream(config) - info = await js.stream_info(stream_name) - self.assertEqual(info.config.discard, nats.js.api.DiscardPolicy.NEW) - self.assertEqual(info.config.discard_new_per_subject, False) - - await nc.close() - - @async_test - async def test_with_discard_old_and_discard_new_per_subject_set(self): - # Connect to NATS and create JetStream context - nc = await nats.connect() - js = nc.jetstream() - - stream_name = "FOO2" - config = nats.js.api.StreamConfig( - name=stream_name, - discard='DiscardOld', - discard_new_per_subject=True, - max_msgs_per_subject=100 - ) - - with self.assertRaises(APIError): - await js.add_stream(config) - - # Close the NATS connection after the test - await nc.close() - - @async_test - async def test_with_discard_old_and_discard_new_per_subject_not_set(self): - # Connect to NATS and create JetStream context - nc = await nats.connect() - js = nc.jetstream() - - stream_name = "FOO3" - config = nats.js.api.StreamConfig( - name=stream_name, - discard='DiscardOld', - discard_new_per_subject=True, - max_msgs_per_subject=100 - ) - - with self.assertRaises(APIError): - await js.add_stream(config) - - await nc.close() - - @async_test - async def test_with_discard_new_and_discard_new_per_subject_set_no_max_msgs(self): - # Connect to NATS and create JetStream context - nc = await nats.connect() - js = nc.jetstream() - - stream_name = "FOO4" - config = nats.js.api.StreamConfig( - name=stream_name, - discard=nats.js.api.DiscardPolicy.NEW, - discard_new_per_subject=True - ) - - with self.assertRaises(APIError): - await js.add_stream(config) - - await nc.close() + + @async_test + async def test_with_discard_new_and_discard_new_per_subject_set(self): + # Connect to NATS and create JetStream context + nc = await nats.connect() + js = nc.jetstream() + + stream_name = "FOO0" + config = nats.js.api.StreamConfig( + name=stream_name, + discard=nats.js.api.DiscardPolicy.NEW, + discard_new_per_subject=True, + max_msgs_per_subject=100, + ) + + await js.add_stream(config) + info = await js.stream_info(stream_name) + self.assertEqual(info.config.discard, nats.js.api.DiscardPolicy.NEW) + self.assertEqual(info.config.discard_new_per_subject, True) + + # Close the NATS connection after the test + await nc.close() + + @async_test + async def test_with_discard_new_and_discard_new_per_subject_not_set(self): + # Connect to NATS and create JetStream context + nc = await nats.connect() + js = nc.jetstream() + + stream_name = "FOO1" + config = nats.js.api.StreamConfig( + name=stream_name, + discard=nats.js.api.DiscardPolicy.NEW, + discard_new_per_subject=False, + max_msgs_per_subject=100, + ) + + await js.add_stream(config) + info = await js.stream_info(stream_name) + self.assertEqual(info.config.discard, nats.js.api.DiscardPolicy.NEW) + self.assertEqual(info.config.discard_new_per_subject, False) + + await nc.close() + + @async_test + async def test_with_discard_old_and_discard_new_per_subject_set(self): + # Connect to NATS and create JetStream context + nc = await nats.connect() + js = nc.jetstream() + + stream_name = "FOO2" + config = nats.js.api.StreamConfig( + name=stream_name, + discard="DiscardOld", + discard_new_per_subject=True, + max_msgs_per_subject=100, + ) + + with self.assertRaises(APIError): + await js.add_stream(config) + + # Close the NATS connection after the test + await nc.close() + + @async_test + async def test_with_discard_old_and_discard_new_per_subject_not_set(self): + # Connect to NATS and create JetStream context + nc = await nats.connect() + js = nc.jetstream() + + stream_name = "FOO3" + config = nats.js.api.StreamConfig( + name=stream_name, + discard="DiscardOld", + discard_new_per_subject=True, + max_msgs_per_subject=100, + ) + + with self.assertRaises(APIError): + await js.add_stream(config) + + await nc.close() + + @async_test + async def test_with_discard_new_and_discard_new_per_subject_set_no_max_msgs( + self + ): + # Connect to NATS and create JetStream context + nc = await nats.connect() + js = nc.jetstream() + + stream_name = "FOO4" + config = nats.js.api.StreamConfig( + name=stream_name, + discard=nats.js.api.DiscardPolicy.NEW, + discard_new_per_subject=True, + ) + + with self.assertRaises(APIError): + await js.add_stream(config) + + await nc.close() + class OrderedConsumerTest(SingleJetStreamServerTestCase): @@ -2092,7 +2117,10 @@ async def cb(msg): with pytest.raises(nats.js.errors.APIError) as err: sub = await js.subscribe(subject, cb=cb, flow_control=True) - assert err.value.description == "consumer with flow control also needs heartbeats" + assert ( + err.value.description == + "consumer with flow control also needs heartbeats" + ) sub = await js.subscribe( subject, cb=cb, flow_control=True, idle_heartbeat=0.5 @@ -2102,7 +2130,7 @@ async def cb(msg): async def producer(): mlen = 128 * 1024 - msg = b'A' * mlen + msg = b"A" * mlen # Send it in chunks i = 0 @@ -2159,7 +2187,7 @@ async def cb(msg): async def producer(): mlen = 1024 * 1024 - msg = b'A' * mlen + msg = b"A" * mlen # Send it in chunks i = 0 @@ -2172,7 +2200,7 @@ async def producer(): chunk = msg[i:i + chunksize] i += chunksize task = asyncio.create_task( - js.publish(subject, chunk, headers={'data': "true"}) + js.publish(subject, chunk, headers={"data": "true"}) ) await asyncio.sleep(0) tasks.append(task) @@ -2182,7 +2210,7 @@ async def producer(): await asyncio.gather(*tasks) assert len(msgs) == 1024 - received_payload = bytearray(b'') + received_payload = bytearray(b"") for msg in msgs: received_payload.extend(msg.data) assert len(received_payload) == expected_size @@ -2263,7 +2291,7 @@ async def cb(msg): async def producer(): mlen = 1024 * 1024 - msg = b'A' * mlen + msg = b"A" * mlen # Send it in chunks i = 0 @@ -2276,7 +2304,7 @@ async def producer(): chunk = msg[i:i + chunksize] i += chunksize task = asyncio.create_task( - nc2.publish(subject, chunk, headers={'data': "true"}) + nc2.publish(subject, chunk, headers={"data": "true"}) ) tasks.append(task) @@ -2292,7 +2320,7 @@ async def producer(): await asyncio.wait_for(future, timeout=5) assert len(msgs) == 1024 - received_payload = bytearray(b'') + received_payload = bytearray(b"") for msg in msgs: received_payload.extend(msg.data) assert len(received_payload) == expected_size @@ -2330,14 +2358,14 @@ async def error_handler(e): js = nc.jetstream() js2 = nc2.jetstream() - subject = 'ORDERS' + subject = "ORDERS" await js2.add_stream(name=subject, subjects=[subject], storage="file") tasks = [] async def producer(): mlen = 10 * 1024 * 1024 - msg = b'A' * mlen + msg = b"A" * mlen # Send it in chunks i = 0 @@ -2350,7 +2378,7 @@ async def producer(): chunk = msg[i:i + chunksize] i += chunksize task = asyncio.create_task( - nc2.publish(subject, chunk, headers={'data': "true"}) + nc2.publish(subject, chunk, headers={"data": "true"}) ) tasks.append(task) @@ -2367,7 +2395,7 @@ async def producer(): async def cb(msg): nonlocal i nonlocal done - data = msg.data.decode('utf-8') + data = msg.data.decode("utf-8") i += 1 if i == stream.state.messages: if not done.done(): @@ -2386,7 +2414,7 @@ async def cb(msg): while i < stream.state.messages: try: msg = await sub.next_msg() - data = msg.data.decode('utf-8') + data = msg.data.decode("utf-8") if i % 1000 == 0: await asyncio.sleep(0) i += 1 @@ -2411,7 +2439,7 @@ async def cb(msg): ) try: msg = await sub.next_msg() - data = msg.data.decode('utf-8') + data = msg.data.decode("utf-8") if i % 1000 == 0: await asyncio.sleep(0) i += 1 @@ -2434,7 +2462,7 @@ async def cb(msg): None, self.server_pool[0].start ) - data = msg.data.decode('utf-8') + data = msg.data.decode("utf-8") i += 1 if i == stream.state.messages: if not done.done(): @@ -2461,7 +2489,7 @@ async def error_handler(e): name="MY_STREAM", subjects=["test.*"], storage="memory" ) subject = "test.1" - for m in ['1', '2', '3']: + for m in ["1", "2", "3"]: await js.publish(subject=subject, payload=m.encode()) sub = await js.subscribe( @@ -2499,18 +2527,18 @@ async def error_handler(e): assert status.history == 5 assert status.ttl == 3600 - seq = await kv.put("hello", b'world') + seq = await kv.put("hello", b"world") assert seq == 1 entry = await kv.get("hello") assert "hello" == entry.key - assert b'world' == entry.value + assert b"world" == entry.value status = await kv.status() assert status.values == 1 for i in range(0, 100): - await kv.put(f"hello.{i}", b'Hello JS KV!') + await kv.put(f"hello.{i}", b"Hello JS KV!") status = await kv.status() assert status.values == 101 @@ -2524,10 +2552,10 @@ async def error_handler(e): with pytest.raises(KeyNotFoundError) as err: await kv.get("hello.1") - assert err.value.entry.key == 'hello.1' + assert err.value.entry.key == "hello.1" assert err.value.entry.revision == 102 assert err.value.entry.value == None - assert err.value.op == 'DEL' + assert err.value.op == "DEL" await kv.purge("hello.5") with pytest.raises(KeyNotFoundError) as err: @@ -2549,21 +2577,21 @@ async def error_handler(e): entry = await kv.get("hello") assert "hello" == entry.key - assert b'world' == entry.value + assert b"world" == entry.value assert 1 == entry.revision # Now get the same KV via lookup. kv = await js.key_value("TEST") entry = await kv.get("hello") assert "hello" == entry.key - assert b'world' == entry.value + assert b"world" == entry.value assert 1 == entry.revision status = await kv.status() assert status.values == 1 for i in range(100, 200): - await kv.put(f"hello.{i}", b'Hello JS KV!') + await kv.put(f"hello.{i}", b"Hello JS KV!") status = await kv.status() assert status.values == 101 @@ -2573,7 +2601,7 @@ async def error_handler(e): entry = await kv.get("hello.102") assert "hello.102" == entry.key - assert b'Hello JS KV!' == entry.value + assert b"Hello JS KV!" == entry.value assert 107 == entry.revision await nc.close() @@ -2621,13 +2649,13 @@ async def error_handler(e): si = await js.stream_info("KV_TEST") config = si.config assert config.description == "Basic KV" - assert config.subjects == ['$KV.TEST.>'] + assert config.subjects == ["$KV.TEST.>"] # Check server version for some of these. assert config.allow_rollup_hdrs == True assert config.deny_delete == True assert config.deny_purge == False - assert config.discard == 'new' + assert config.discard == "new" assert config.duplicate_window == 120.0 assert config.max_age == 3600.0 assert config.max_bytes == -1 @@ -2639,10 +2667,10 @@ async def error_handler(e): assert config.no_ack == False assert config.num_replicas == 1 assert config.placement == None - assert config.retention == 'limits' + assert config.retention == "limits" assert config.sealed == False assert config.sources == None - assert config.storage == 'file' + assert config.storage == "file" assert config.template_owner == None version = nc.connected_server_version @@ -2656,13 +2684,13 @@ async def error_handler(e): await kv.get(f"name") # Simple Put - revision = await kv.put(f"name", b'alice') + revision = await kv.put(f"name", b"alice") assert revision == 1 # Simple Get result = await kv.get(f"name") assert result.revision == 1 - assert result.value == b'alice' + assert result.value == b"alice" # Delete ok = await kv.delete(f"name") @@ -2674,7 +2702,7 @@ async def error_handler(e): await kv.get(f"name") # Recreate with different name. - revision = await kv.create("name", b'bob') + revision = await kv.create("name", b"bob") assert revision == 3 # Expect last revision to be 4 @@ -2698,36 +2726,37 @@ async def error_handler(e): assert revision == 6 # Create a different key. - revision = await kv.create("age", b'2038') + revision = await kv.create("age", b"2038") assert revision == 7 # Get current. entry = await kv.get("age") - assert entry.value == b'2038' + assert entry.value == b"2038" assert entry.revision == 7 # Update the new key. - revision = await kv.update("age", b'2039', last=revision) + revision = await kv.update("age", b"2039", last=revision) assert revision == 8 # Get latest. entry = await kv.get("age") - assert entry.value == b'2039' + assert entry.value == b"2039" assert entry.revision == 8 # Internally uses get msg API instead of get last msg. entry = await kv.get("age", revision=7) - assert entry.value == b'2038' + assert entry.value == b"2038" assert entry.revision == 7 # Getting past keys with the wrong expected subject is an error. with pytest.raises(KeyNotFoundError) as err: entry = await kv.get("age", revision=6) - assert entry.value == b'fuga' + assert entry.value == b"fuga" assert entry.revision == 6 - assert str( - err.value - ) == "nats: key not found: expected '$KV.TEST.age', but got '$KV.TEST.name'" + assert ( + str(err.value) == + "nats: key not found: expected '$KV.TEST.age', but got '$KV.TEST.name'" + ) with pytest.raises(KeyNotFoundError) as err: await kv.get("age", revision=5) @@ -2736,19 +2765,19 @@ async def error_handler(e): await kv.get("age", revision=4) entry = await kv.get("name", revision=3) - assert entry.value == b'bob' + assert entry.value == b"bob" with pytest.raises(KeyWrongLastSequenceError, match="nats: wrong last sequence: 8"): - await kv.create("age", b'1') + await kv.create("age", b"1") # Now let's delete and recreate. await kv.delete("age", last=8) - await kv.create("age", b'final') + await kv.create("age", b"final") with pytest.raises(KeyWrongLastSequenceError, match="nats: wrong last sequence: 10"): - await kv.create("age", b'1') + await kv.create("age", b"1") entry = await kv.get("age") assert entry.revision == 10 @@ -2781,32 +2810,32 @@ async def error_handler(e): si = await js.stream_info("KV_TEST") config = si.config assert config.description == "Direct KV" - assert config.subjects == ['$KV.TEST.>'] - await kv.create("A", b'1') - await kv.create("B", b'2') - await kv.create("C", b'3') - await kv.create("D", b'4') - await kv.create("E", b'5') - await kv.create("F", b'6') - - await kv.put("C", b'33') - await kv.put("D", b'44') - await kv.put("C", b'333') + assert config.subjects == ["$KV.TEST.>"] + await kv.create("A", b"1") + await kv.create("B", b"2") + await kv.create("C", b"3") + await kv.create("D", b"4") + await kv.create("E", b"5") + await kv.create("F", b"6") + + await kv.put("C", b"33") + await kv.put("D", b"44") + await kv.put("C", b"333") # Check with low level msg APIs. msg = await js.get_msg("KV_TEST", seq=1, direct=True) - assert msg.data == b'1' + assert msg.data == b"1" # last by subject msg = await js.get_msg("KV_TEST", subject="$KV.TEST.C", direct=True) - assert msg.data == b'333' + assert msg.data == b"333" # next by subject msg = await js.get_msg( "KV_TEST", seq=4, next=True, subject="$KV.TEST.C", direct=True ) - assert msg.data == b'33' + assert msg.data == b"33" @async_test async def test_kv_direct(self): @@ -2829,7 +2858,7 @@ async def error_handler(e): history=5, ttl=3600, description="Explicit Direct KV", - direct=True + direct=True, ) kv = await js.key_value(bucket=bucket) status = await kv.status() @@ -2837,14 +2866,14 @@ async def error_handler(e): si = await js.stream_info("KV_TEST") config = si.config assert config.description == "Explicit Direct KV" - assert config.subjects == ['$KV.TEST.>'] + assert config.subjects == ["$KV.TEST.>"] # Check server version for some of these. assert config.allow_rollup_hdrs == True assert config.allow_direct == True assert config.deny_delete == True assert config.deny_purge == False - assert config.discard == 'new' + assert config.discard == "new" assert config.duplicate_window == 120.0 assert config.max_age == 3600.0 assert config.max_bytes == -1 @@ -2856,10 +2885,10 @@ async def error_handler(e): assert config.no_ack == False assert config.num_replicas == 1 assert config.placement == None - assert config.retention == 'limits' + assert config.retention == "limits" assert config.sealed == False assert config.sources == None - assert config.storage == 'file' + assert config.storage == "file" assert config.template_owner == None # Nothing from start @@ -2867,13 +2896,13 @@ async def error_handler(e): await kv.get(f"name") # Simple Put - revision = await kv.put(f"name", b'alice') + revision = await kv.put(f"name", b"alice") assert revision == 1 # Simple Get result = await kv.get(f"name") assert result.revision == 1 - assert result.value == b'alice' + assert result.value == b"alice" # Delete ok = await kv.delete(f"name") @@ -2885,7 +2914,7 @@ async def error_handler(e): await kv.get(f"name") # Recreate with different name. - revision = await kv.create("name", b'bob') + revision = await kv.create("name", b"bob") assert revision == 3 # Expect last revision to be 4 @@ -2909,36 +2938,37 @@ async def error_handler(e): assert revision == 6 # Create a different key. - revision = await kv.create("age", b'2038') + revision = await kv.create("age", b"2038") assert revision == 7 # Get current. entry = await kv.get("age") - assert entry.value == b'2038' + assert entry.value == b"2038" assert entry.revision == 7 # Update the new key. - revision = await kv.update("age", b'2039', last=revision) + revision = await kv.update("age", b"2039", last=revision) assert revision == 8 # Get latest. entry = await kv.get("age") - assert entry.value == b'2039' + assert entry.value == b"2039" assert entry.revision == 8 # Internally uses get msg API instead of get last msg. entry = await kv.get("age", revision=7) - assert entry.value == b'2038' + assert entry.value == b"2038" assert entry.revision == 7 # Getting past keys with the wrong expected subject is an error. with pytest.raises(KeyNotFoundError) as err: entry = await kv.get("age", revision=6) - assert entry.value == b'fuga' + assert entry.value == b"fuga" assert entry.revision == 6 - assert str( - err.value - ) == "nats: key not found: expected '$KV.TEST.age', but got '$KV.TEST.name'" + assert ( + str(err.value) == + "nats: key not found: expected '$KV.TEST.age', but got '$KV.TEST.name'" + ) with pytest.raises(KeyNotFoundError) as err: await kv.get("age", revision=5) @@ -2947,19 +2977,19 @@ async def error_handler(e): await kv.get("age", revision=4) entry = await kv.get("name", revision=3) - assert entry.value == b'bob' + assert entry.value == b"bob" with pytest.raises(KeyWrongLastSequenceError, match="nats: wrong last sequence: 8"): - await kv.create("age", b'1') + await kv.create("age", b"1") # Now let's delete and recreate. await kv.delete("age", last=8) - await kv.create("age", b'final') + await kv.create("age", b"final") with pytest.raises(KeyWrongLastSequenceError, match="nats: wrong last sequence: 10"): - await kv.create("age", b'1') + await kv.create("age", b"1") entry = await kv.get("age") assert entry.revision == 10 @@ -2967,7 +2997,10 @@ async def error_handler(e): with pytest.raises(Error) as err: await js.add_stream(name="mirror", mirror_direct=True) assert err.value.err_code == 10052 - assert err.value.description == 'stream has no mirror but does have mirror direct' + assert ( + err.value.description == + "stream has no mirror but does have mirror direct" + ) await nc.close() @@ -2993,51 +3026,51 @@ async def error_handler(e): e = await w.updates(timeout=1) assert e is None - await kv.create("name", b'alice:1') + await kv.create("name", b"alice:1") e = await w.updates() assert e.delta == 0 - assert e.key == 'name' - assert e.value == b'alice:1' + assert e.key == "name" + assert e.value == b"alice:1" assert e.revision == 1 - await kv.put("name", b'alice:2') + await kv.put("name", b"alice:2") e = await w.updates() - assert e.key == 'name' - assert e.value == b'alice:2' + assert e.key == "name" + assert e.value == b"alice:2" assert e.revision == 2 - await kv.put("name", b'alice:3') + await kv.put("name", b"alice:3") e = await w.updates() - assert e.key == 'name' - assert e.value == b'alice:3' + assert e.key == "name" + assert e.value == b"alice:3" assert e.revision == 3 - await kv.put("age", b'22') + await kv.put("age", b"22") e = await w.updates() - assert e.key == 'age' - assert e.value == b'22' + assert e.key == "age" + assert e.value == b"22" assert e.revision == 4 - await kv.put("age", b'33') + await kv.put("age", b"33") e = await w.updates() assert e.bucket == "WATCH" - assert e.key == 'age' - assert e.value == b'33' + assert e.key == "age" + assert e.value == b"33" assert e.revision == 5 await kv.delete("age") e = await w.updates() assert e.bucket == "WATCH" - assert e.key == 'age' - assert e.value == b'' + assert e.key == "age" + assert e.value == b"" assert e.revision == 6 assert e.operation == "DEL" await kv.purge("name") e = await w.updates() assert e.bucket == "WATCH" - assert e.key == 'name' - assert e.value == b'' + assert e.key == "name" + assert e.value == b"" assert e.revision == 7 assert e.operation == "PURGE" @@ -3049,13 +3082,13 @@ async def error_handler(e): await w.stop() # Now try wildcard matching and make sure we only get last value when starting. - await kv.create("new", b'hello world') - await kv.put("t.name", b'a') - await kv.put("t.name", b'b') - await kv.put("t.age", b'c') - await kv.put("t.age", b'd') - await kv.put("t.a", b'a') - await kv.put("t.b", b'b') + await kv.create("new", b"hello world") + await kv.put("t.name", b"a") + await kv.put("t.name", b"b") + await kv.put("t.age", b"c") + await kv.put("t.age", b"d") + await kv.put("t.a", b"a") + await kv.put("t.b", b"b") # Will only get last values of the matching keys. w = await kv.watch("t.*") @@ -3065,7 +3098,7 @@ async def error_handler(e): assert e.bucket == "WATCH" assert e.delta == 3 assert e.key == "t.name" - assert e.value == b'b' + assert e.value == b"b" assert e.revision == 10 assert e.operation == None @@ -3073,7 +3106,7 @@ async def error_handler(e): assert e.bucket == "WATCH" assert e.delta == 2 assert e.key == "t.age" - assert e.value == b'd' + assert e.value == b"d" assert e.revision == 12 assert e.operation == None @@ -3081,7 +3114,7 @@ async def error_handler(e): assert e.bucket == "WATCH" assert e.delta == 1 assert e.key == "t.a" - assert e.value == b'a' + assert e.value == b"a" assert e.revision == 13 assert e.operation == None @@ -3090,7 +3123,7 @@ async def error_handler(e): assert e.bucket == "WATCH" assert e.delta == 0 assert e.key == "t.b" - assert e.value == b'b' + assert e.value == b"b" assert e.revision == 14 assert e.operation == None @@ -3103,10 +3136,10 @@ async def error_handler(e): with pytest.raises(TimeoutError): await w.updates(timeout=1) - await kv.put("t.hello", b'hello world') + await kv.put("t.hello", b"hello world") e = await w.updates() assert e.delta == 0 - assert e.key == 't.hello' + assert e.key == "t.hello" assert e.revision == 15 # Default watch timeout should 5 minutes @@ -3135,14 +3168,14 @@ async def error_handler(e): status = await kv.status() for i in range(0, 50): - await kv.put(f"age", f'{i}'.encode()) + await kv.put(f"age", f"{i}".encode()) vl = await kv.history("age") assert len(vl) == 10 i = 0 for entry in vl: - assert entry.key == 'age' + assert entry.key == "age" assert entry.revision == i + 41 assert int(entry.value) == i + 40 i += 1 @@ -3166,12 +3199,12 @@ async def error_handler(e): with pytest.raises(NoKeysError): await kv.keys() - await kv.put("a", b'1') - await kv.put("b", b'2') - await kv.put("a", b'11') - await kv.put("b", b'22') - await kv.put("a", b'111') - await kv.put("b", b'222') + await kv.put("a", b"1") + await kv.put("b", b"2") + await kv.put("a", b"11") + await kv.put("b", b"22") + await kv.put("a", b"111") + await kv.put("b", b"222") keys = await kv.keys() assert len(keys) == 2 @@ -3190,10 +3223,10 @@ async def error_handler(e): with pytest.raises(NoKeysError): await kv.keys() - await kv.create("c", b'3') + await kv.create("c", b"3") keys = await kv.keys() assert len(keys) == 1 - assert 'c' in keys + assert "c" in keys await nc.close() @@ -3232,7 +3265,7 @@ async def error_handler(e): for i in range(0, 10): await kv.delete(f"key-{i}") - await kv.put(f"key-last", b'101') + await kv.put(f"key-last", b"101") await kv.purge_deletes(olderthan=-1) await asyncio.sleep(0.5) @@ -3281,9 +3314,9 @@ async def error_handler(e): history = await kv.history("bar") assert len(history) == 1 entry = history[0] - assert entry.key == 'bar' + assert entry.key == "bar" assert entry.revision == 6 - assert entry.operation == 'DEL' + assert entry.operation == "DEL" await nc.close() @@ -3299,10 +3332,10 @@ async def error_handler(e): js = nc.jetstream() async def pub(): - await js.publish("foo.A", b'1') - await js.publish("foo.C", b'1') - await js.publish("foo.B", b'1') - await js.publish("foo.C", b'2') + await js.publish("foo.A", b"1") + await js.publish("foo.C", b"1") + await js.publish("foo.B", b"1") + await js.publish("foo.C", b"2") await js.add_stream(name="foo", subjects=["foo.A", "foo.B", "foo.C"]) await pub() @@ -3313,12 +3346,12 @@ async def pub(): assert info.state.messages == 2 msgs = await sub.fetch(5, timeout=1) assert len(msgs) == 2 - assert msgs[0].subject == 'foo.B' - assert msgs[1].subject == 'foo.C' + assert msgs[0].subject == "foo.B" + assert msgs[1].subject == "foo.C" await js.publish( "foo.C", - b'3', - headers={nats.js.api.Header.EXPECTED_LAST_SUBJECT_SEQUENCE: "4"} + b"3", + headers={nats.js.api.Header.EXPECTED_LAST_SUBJECT_SEQUENCE: "4"}, ) await nc.close() @@ -3343,10 +3376,10 @@ async def error_handler(e): assert sinfo.config.republish is not None sub = await nc.subscribe("bar.>") - await kv.put("hello.world", b'Hello World!') + await kv.put("hello.world", b"Hello World!") msg = await sub.next_msg() - assert msg.data == b'Hello World!' - assert msg.headers.get('Nats-Msg-Size', None) == None + assert msg.data == b"Hello World!" + assert msg.headers.get("Nats-Msg-Size", None) == None await sub.unsubscribe() kv = await js.create_key_value( @@ -3355,15 +3388,15 @@ async def error_handler(e): src=">", dest="quux.>", headers_only=True, - ) + ), ) sub = await nc.subscribe("quux.>") - await kv.put("hello.world", b'Hello World!') + await kv.put("hello.world", b"Hello World!") msg = await sub.next_msg() - assert msg.data == b'' - assert msg.headers['Nats-Sequence'] == '1' - assert msg.headers['Nats-Msg-Size'] == '12' - assert msg.headers['Nats-Stream'] == 'KV_TEST_UPDATE_HEADERS' + assert msg.data == b"" + assert msg.headers["Nats-Sequence"] == "1" + assert msg.headers["Nats-Msg-Size"] == "12" + assert msg.headers["Nats-Stream"] == "KV_TEST_UPDATE_HEADERS" await sub.unsubscribe() await nc.close() @@ -3379,14 +3412,16 @@ async def error_handler(e): js = nc.jetstream() # Create a KV bucket for testing - kv = await js.create_key_value(bucket="TEST_LOGGING", history=5, ttl=3600) + kv = await js.create_key_value( + bucket="TEST_LOGGING", history=5, ttl=3600 + ) # Add keys to the bucket - await kv.put("hello", b'world') - await kv.put("greeting", b'hi') + await kv.put("hello", b"world") + await kv.put("greeting", b"hi") # Test with filters (fetch keys with "hello" or "greet") - filtered_keys = await kv.keys(filters=['hello', 'greet']) + filtered_keys = await kv.keys(filters=["hello", "greet"]) assert "hello" in filtered_keys assert "greeting" in filtered_keys @@ -3421,7 +3456,7 @@ async def error_handler(e): sinfo = status.stream_info assert sinfo.config.name == "OBJ_OBJS" assert sinfo.config.description == "testing" - assert sinfo.config.subjects == ['$O.OBJS.C.>', '$O.OBJS.M.>'] + assert sinfo.config.subjects == ["$O.OBJS.C.>", "$O.OBJS.M.>"] assert sinfo.config.retention == "limits" assert sinfo.config.max_consumers == -1 assert sinfo.config.max_msgs == -1 @@ -3436,7 +3471,7 @@ async def error_handler(e): assert sinfo.config.allow_direct == True assert sinfo.config.mirror_direct == False - bucketname = ''.join( + bucketname = "".join( random.SystemRandom().choice(string.ascii_letters) for _ in range(10) ) @@ -3449,13 +3484,13 @@ async def error_handler(e): assert obs._stream == f"OBJ_{bucketname}" # Simple example using bytes. - info = await obs.put("foo", b'bar') + info = await obs.put("foo", b"bar") assert info.name == "foo" assert info.size == 3 assert info.bucket == bucketname # Simple example using strings. - info = await obs.put("plaintext", 'lorem ipsum') + info = await obs.put("plaintext", "lorem ipsum") assert info.name == "plaintext" assert info.size == 11 assert info.bucket == bucketname @@ -3464,16 +3499,16 @@ async def error_handler(e): opts = api.ObjectMetaOptions(max_chunk_size=5) info = await obs.put( "filename.txt", - b'filevalue', + b"filevalue", meta=nats.js.api.ObjectMeta( description="filedescription", headers={"headername": "headerval"}, options=opts, - ) + ), ) filename = "filename.txt" - filevalue = b'filevalue' + filevalue = b"filevalue" filedesc = "filedescription" fileheaders = {"headername": "headerval"} assert info.name == filename @@ -3486,7 +3521,9 @@ async def error_handler(e): h = sha256() h.update(filevalue) h.digest() - expected_digest = f"SHA-256={base64.urlsafe_b64encode(h.digest()).decode('utf-8')}" + expected_digest = ( + f"SHA-256={base64.urlsafe_b64encode(h.digest()).decode('utf-8')}" + ) assert info.digest == expected_digest assert info.deleted == False assert info.description == filedesc @@ -3542,13 +3579,13 @@ async def error_handler(e): # Create an 8MB object. obs = await js.create_object_store(bucket="big") - ls = ''.join("A" for _ in range(0, 1 * 1024 * 1024 + 33)) + ls = "".join("A" for _ in range(0, 1 * 1024 * 1024 + 33)) w = io.BytesIO(ls.encode()) info = await obs.put("big", w) assert info.name == "big" assert info.size == 1048609 assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" # Create actual file and put it in a bucket. tmp = tempfile.NamedTemporaryFile(delete=False) @@ -3560,14 +3597,14 @@ async def error_handler(e): assert info.name == "tmp" assert info.size == 1048609 assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" with open(tmp.name) as f: info = await obs.put("tmp2", f) assert info.name == "tmp2" assert info.size == 1048609 assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" # By default this reads the complete data. obr = await obs.get("tmp") @@ -3575,7 +3612,7 @@ async def error_handler(e): assert info.name == "tmp" assert info.size == 1048609 assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" # Using a local file. with open("pyproject.toml") as f: @@ -3592,19 +3629,25 @@ async def error_handler(e): # Write into file without getting complete data. w = tempfile.NamedTemporaryFile(delete=False) w.close() - with open(w.name, 'w') as f: + with open(w.name, "w") as f: obr = await obs.get("tmp", writeinto=f) - assert obr.data == b'' + assert obr.data == b"" assert obr.info.size == 1048609 - assert obr.info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert ( + obr.info.digest == + "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" + ) w2 = tempfile.NamedTemporaryFile(delete=False) w2.close() - with open(w2.name, 'w') as f: + with open(w2.name, "w") as f: obr = await obs.get("tmp", writeinto=f.buffer) - assert obr.data == b'' + assert obr.data == b"" assert obr.info.size == 1048609 - assert obr.info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert ( + obr.info.digest == + "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" + ) with open(w2.name) as f: result = f.read(-1) @@ -3625,7 +3668,7 @@ async def error_handler(e): # Create an 8MB object. obs = await js.create_object_store(bucket="sample") - ls = ''.join("A" for _ in range(0, 2 * 1024 * 1024 + 33)) + ls = "".join("A" for _ in range(0, 2 * 1024 * 1024 + 33)) w = io.BytesIO(ls.encode()) info = await obs.put("sample", w) assert info.name == "sample" @@ -3680,22 +3723,22 @@ async def error_handler(e): obs = await js.create_object_store( "TEST_FILES", - config=nats.js.api.ObjectStoreConfig(description="multi_files", ) + config=nats.js.api.ObjectStoreConfig(description="multi_files", ), ) - await obs.put("A", b'A') - await obs.put("B", b'B') - await obs.put("C", b'C') + await obs.put("A", b"A") + await obs.put("B", b"B") + await obs.put("C", b"C") res = await obs.get("A") - assert res.data == b'A' + assert res.data == b"A" assert res.info.digest == "SHA-256=VZrq0IJk1XldOQlxjN0Fq9SVcuhP5VWQ7vMaiKCP3_0=" res = await obs.get("B") - assert res.data == b'B' + assert res.data == b"B" assert res.info.digest == "SHA-256=335w5QIVRPSDS77mSp43if68S-gUcN9inK1t2wMyClw=" res = await obs.get("C") - assert res.data == b'C' + assert res.data == b"C" assert res.info.digest == "SHA-256=ayPA1fNdGxH5toPwsKYXNV3rESd9ka4JHTmcZVuHlA0=" with open("README.md") as fp: @@ -3724,49 +3767,49 @@ async def error_handler(e): obs = await js.create_object_store( "TEST_FILES", - config=nats.js.api.ObjectStoreConfig(description="multi_files", ) + config=nats.js.api.ObjectStoreConfig(description="multi_files", ), ) watcher = await obs.watch() e = await watcher.updates() assert e == None - await obs.put("A", b'A') - await obs.put("B", b'B') - await obs.put("C", b'C') + await obs.put("A", b"A") + await obs.put("B", b"B") + await obs.put("C", b"C") res = await obs.get("A") - assert res.data == b'A' + assert res.data == b"A" assert res.info.digest == "SHA-256=VZrq0IJk1XldOQlxjN0Fq9SVcuhP5VWQ7vMaiKCP3_0=" res = await obs.get("B") - assert res.data == b'B' + assert res.data == b"B" assert res.info.digest == "SHA-256=335w5QIVRPSDS77mSp43if68S-gUcN9inK1t2wMyClw=" res = await obs.get("C") - assert res.data == b'C' + assert res.data == b"C" assert res.info.digest == "SHA-256=ayPA1fNdGxH5toPwsKYXNV3rESd9ka4JHTmcZVuHlA0=" e = await watcher.updates() - assert e.name == 'A' - assert e.bucket == 'TEST_FILES' + assert e.name == "A" + assert e.bucket == "TEST_FILES" assert e.size == 1 assert e.chunks == 1 - assert e.digest == 'SHA-256=VZrq0IJk1XldOQlxjN0Fq9SVcuhP5VWQ7vMaiKCP3_0=' + assert e.digest == "SHA-256=VZrq0IJk1XldOQlxjN0Fq9SVcuhP5VWQ7vMaiKCP3_0=" e = await watcher.updates() - assert e.name == 'B' - assert e.bucket == 'TEST_FILES' + assert e.name == "B" + assert e.bucket == "TEST_FILES" assert e.size == 1 assert e.chunks == 1 - assert e.digest == 'SHA-256=335w5QIVRPSDS77mSp43if68S-gUcN9inK1t2wMyClw=' + assert e.digest == "SHA-256=335w5QIVRPSDS77mSp43if68S-gUcN9inK1t2wMyClw=" e = await watcher.updates() - assert e.name == 'C' - assert e.bucket == 'TEST_FILES' + assert e.name == "C" + assert e.bucket == "TEST_FILES" assert e.size == 1 assert e.chunks == 1 - assert e.digest == 'SHA-256=ayPA1fNdGxH5toPwsKYXNV3rESd9ka4JHTmcZVuHlA0=' + assert e.digest == "SHA-256=ayPA1fNdGxH5toPwsKYXNV3rESd9ka4JHTmcZVuHlA0=" # Expect no more updates. with pytest.raises(asyncio.TimeoutError): @@ -3786,16 +3829,16 @@ async def error_handler(e): assert deleted.info.name == "B" assert deleted.info.deleted == True - info = await obs.put("C", b'CCC') - assert info.name == 'C' + info = await obs.put("C", b"CCC") + assert info.name == "C" assert info.deleted == False e = await watcher.updates() - assert e.name == 'C' + assert e.name == "C" assert e.deleted == False res = await obs.get("C") - assert res.data == b'CCC' + assert res.data == b"CCC" # Update meta to_update_meta = res.info.meta @@ -3803,8 +3846,8 @@ async def error_handler(e): await obs.update_meta("C", to_update_meta) e = await watcher.updates() - assert e.name == 'C' - assert e.description == 'changed' + assert e.name == "C" + assert e.description == "changed" # Try to update meta when it has already been deleted. deleted_meta = deleted.info.meta @@ -3818,7 +3861,7 @@ async def error_handler(e): # Update meta res = await obs.get("A") - assert res.data == b'A' + assert res.data == b"A" to_update_meta = res.info.meta to_update_meta.name = "Z" to_update_meta.description = "changed" @@ -3840,20 +3883,20 @@ async def error_handler(e): obs = await js.create_object_store( "TEST_LIST", - config=nats.js.api.ObjectStoreConfig(description="listing", ) + config=nats.js.api.ObjectStoreConfig(description="listing", ), ) await asyncio.gather( - obs.put("A", b'AAA'), - obs.put("B", b'BBB'), - obs.put("C", b'CCC'), - obs.put("D", b'DDD'), + obs.put("A", b"AAA"), + obs.put("B", b"BBB"), + obs.put("C", b"CCC"), + obs.put("D", b"DDD"), ) entries = await obs.list() assert len(entries) == 4 - assert entries[0].name == 'A' - assert entries[1].name == 'B' - assert entries[2].name == 'C' - assert entries[3].name == 'D' + assert entries[0].name == "A" + assert entries[1].name == "B" + assert entries[2].name == "C" + assert entries[3].name == "D" await nc.close() @@ -3875,13 +3918,13 @@ async def error_handler(e): # Create an 8MB object. obs = await js.create_object_store(bucket="big") - ls = ''.join("A" for _ in range(0, 1 * 1024 * 1024 + 33)) + ls = "".join("A" for _ in range(0, 1 * 1024 * 1024 + 33)) w = io.BytesIO(ls.encode()) info = await obs.put("big", w) assert info.name == "big" assert info.size == 1048609 assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" # Create actual file and put it in a bucket. tmp = tempfile.NamedTemporaryFile(delete=False) @@ -3893,14 +3936,14 @@ async def error_handler(e): assert info.name == "tmp" assert info.size == 1048609 assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" async with aiofiles.open(tmp.name) as f: info = await obs.put("tmp2", f) assert info.name == "tmp2" assert info.size == 1048609 assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" obr = await obs.get("tmp") info = obr.info @@ -3908,7 +3951,7 @@ async def error_handler(e): assert info.size == 1048609 assert len(obr.data) == info.size # no reader reads whole file. assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" # Using a local file. async with aiofiles.open("pyproject.toml") as f: @@ -3933,7 +3976,7 @@ async def test_number_of_consumer_replicas(self): js = nc.jetstream() await js.add_stream(name="TESTREPLICAS", subjects=["test.replicas"]) for i in range(0, 10): - await js.publish("test.replicas", f'{i}'.encode()) + await js.publish("test.replicas", f"{i}".encode()) # Create consumer config = nats.js.api.ConsumerConfig( @@ -3957,14 +4000,20 @@ async def test_account_limits(self): with pytest.raises(BadRequestError) as err: await js.add_stream(name="limits", subjects=["limits"]) assert err.value.err_code == 10113 - assert err.value.description == "account requires a stream config to have max bytes set" + assert ( + err.value.description == + "account requires a stream config to have max bytes set" + ) with pytest.raises(BadRequestError) as err: await js.add_stream( name="limits", subjects=["limits"], max_bytes=65536 ) assert err.value.err_code == 10122 - assert err.value.description == "stream max bytes exceeds account limit max stream bytes" + assert ( + err.value.description == + "stream max bytes exceeds account limit max stream bytes" + ) si = await js.add_stream( name="limits", subjects=["limits"], max_bytes=128 @@ -3973,13 +4022,13 @@ async def test_account_limits(self): await js.publish( "limits", - b'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA' + b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" ) si = await js.stream_info("limits") assert si.state.messages == 1 for i in range(0, 5): - await js.publish("limits", b'A') + await js.publish("limits", b"A") expected = nats.js.api.AccountInfo( memory=0, @@ -3994,11 +4043,11 @@ async def test_account_limits(self): max_ack_pending=100, memory_max_stream_bytes=2048, storage_max_stream_bytes=4096, - max_bytes_required=True + max_bytes_required=True, ), api=nats.js.api.APIStats(total=4, errors=2), - domain='test-domain', - tiers=None + domain="test-domain", + tiers=None, ) info = await js.account_info() assert expected == info @@ -4079,12 +4128,12 @@ async def test_account_limits(self): max_ack_pending=0, memory_max_stream_bytes=0, storage_max_stream_bytes=0, - max_bytes_required=False + max_bytes_required=False, ), api=nats.js.api.APIStats(total=6, errors=0), - domain='ngs', + domain="ngs", tiers={ - 'R1': + "R1": nats.js.api.Tier( memory=0, storage=6829550, @@ -4098,10 +4147,10 @@ async def test_account_limits(self): max_ack_pending=-1, memory_max_stream_bytes=-1, storage_max_stream_bytes=-1, - max_bytes_required=True - ) + max_bytes_required=True, + ), ), - 'R3': + "R3": nats.js.api.Tier( memory=0, storage=0, @@ -4115,10 +4164,10 @@ async def test_account_limits(self): max_ack_pending=-1, memory_max_stream_bytes=-1, storage_max_stream_bytes=-1, - max_bytes_required=True - ) - ) - } + max_bytes_required=True, + ), + ), + }, ) info = nats.js.api.AccountInfo.from_response(json.loads(blob)) assert expected == info @@ -4142,7 +4191,7 @@ async def test_subject_transforms(self): ), ) for i in range(0, 10): - await js.publish("test", f'{i}'.encode()) + await js.publish("test", f"{i}".encode()) # Creating a filtered consumer will be successful but will not be able to match # so num_pending would remain at zeroes. @@ -4164,10 +4213,10 @@ async def test_subject_transforms(self): assert cinfo.num_pending == 10 msgs = await psub.fetch(10, timeout=1) assert len(msgs) == 10 - assert msgs[0].subject == 'transformed.test' - assert msgs[0].data == b'0' - assert msgs[5].subject == 'transformed.test' - assert msgs[5].data == b'5' + assert msgs[0].subject == "transformed.test" + assert msgs[0].data == b"0" + assert msgs[5].subject == "transformed.test" + assert msgs[5].data == b"5" # Create stream that fetches from the original stream that # already has the transformed subjects. @@ -4198,10 +4247,10 @@ async def test_subject_transforms(self): assert cinfo.num_pending == 10 msgs = await psub.fetch(10, timeout=1) assert len(msgs) == 10 - assert msgs[0].subject == 'fromtest.transformed.test' - assert msgs[0].data == b'0' - assert msgs[5].subject == 'fromtest.transformed.test' - assert msgs[5].data == b'5' + assert msgs[0].subject == "fromtest.transformed.test" + assert msgs[0].data == b"0" + assert msgs[5].subject == "fromtest.transformed.test" + assert msgs[5].data == b"5" # Source is overlapping so should fail. transformed_source = nats.js.api.StreamSource( @@ -4241,7 +4290,7 @@ async def test_stream_compression(self): subjects=["test", "foo"], compression="s3", ) - assert str(err.value) == 'nats: invalid store compression type: s3' + assert str(err.value) == "nats: invalid store compression type: s3" # An empty string means not setting compression, but to be explicit # can also use @@ -4272,10 +4321,10 @@ async def test_stream_consumer_metadata(self): await js.add_stream( name="META", subjects=["test", "foo"], - metadata={'foo': 'bar'}, + metadata={"foo": "bar"}, ) sinfo = await js.stream_info("META") - assert sinfo.config.metadata['foo'] == 'bar' + assert sinfo.config.metadata["foo"] == "bar" with pytest.raises(ValueError) as err: await js.add_stream( @@ -4283,16 +4332,16 @@ async def test_stream_consumer_metadata(self): subjects=["test", "foo"], metadata=["hello", "world"], ) - assert str(err.value) == 'nats: invalid metadata format' + assert str(err.value) == "nats: invalid metadata format" await js.add_consumer( "META", config=nats.js.api.ConsumerConfig( - durable_name="b", metadata={'hello': 'world'} - ) + durable_name="b", metadata={"hello": "world"} + ), ) cinfo = await js.consumer_info("META", "b") - assert cinfo.config.metadata['hello'] == 'world' + assert cinfo.config.metadata["hello"] == "world" await nc.close() @@ -4307,7 +4356,7 @@ async def test_fetch_pull_subscribe_bind(self): await js.add_stream(name=stream_name, subjects=["foo", "bar"]) for i in range(0, 5): - await js.publish("foo", b'A') + await js.publish("foo", b"A") # Fetch with multiple filters on an ephemeral consumer. cinfo = await js.add_consumer( @@ -4361,16 +4410,17 @@ async def test_fetch_pull_subscribe_bind(self): with pytest.raises(ValueError) as err: await js.pull_subscribe_bind(durable=cinfo.name) - assert str(err.value) == 'nats: stream name is required' + assert str(err.value) == "nats: stream name is required" with pytest.raises(ValueError) as err: await js.pull_subscribe_bind(cinfo.name) - assert str(err.value) == 'nats: stream name is required' + assert str(err.value) == "nats: stream name is required" await nc.close() class BadStreamNamesTest(SingleJetStreamServerTestCase): + @async_test async def test_add_stream_invalid_names(self): nc = NATS() @@ -4391,10 +4441,10 @@ async def test_add_stream_invalid_names(self): for name in invalid_names: with pytest.raises( - ValueError, - match=( - f"nats: stream name \\({re.escape(name)}\\) is invalid. Names cannot contain whitespace, '\\.', " - "'\\*', '>', path separators \\(forward or backward slash\\), or non-printable characters." - ), + ValueError, + match= + (f"nats: stream name \\({re.escape(name)}\\) is invalid. Names cannot contain whitespace, '\\.', " + "'\\*', '>', path separators \\(forward or backward slash\\), or non-printable characters." + ), ): await js.add_stream(name=name) diff --git a/tests/test_micro_service.py b/tests/test_micro_service.py index c1227222..88d21e42 100644 --- a/tests/test_micro_service.py +++ b/tests/test_micro_service.py @@ -12,23 +12,29 @@ class MicroServiceTest(SingleServerTestCase): + def test_invalid_service_name(self): with self.assertRaises(ValueError) as context: ServiceConfig(name="", version="0.1.0") self.assertEqual(str(context.exception), "Name cannot be empty.") - with self.assertRaises(ValueError) as context: ServiceConfig(name="test.service@!", version="0.1.0") - self.assertEqual(str(context.exception), "Invalid name. It must contain only alphanumeric characters, dashes, and underscores.") - + self.assertEqual( + str(context.exception), + "Invalid name. It must contain only alphanumeric characters, dashes, and underscores.", + ) def test_invalid_service_version(self): with self.assertRaises(ValueError) as context: ServiceConfig(name="test_service", version="abc") - self.assertEqual(str(context.exception), "Invalid version. It must follow semantic versioning (e.g., 1.0.0, 2.1.3-alpha.1).") + self.assertEqual( + str(context.exception), + "Invalid version. It must follow semantic versioning (e.g., 1.0.0, 2.1.3-alpha.1).", + ) def test_invalid_endpoint_subject(self): + async def noop_handler(request: Request) -> None: pass @@ -38,7 +44,10 @@ async def noop_handler(request: Request) -> None: subject="endpoint subject", handler=noop_handler, ) - self.assertEqual(str(context.exception), "Invalid subject. Subject must not contain spaces, and can only have '>' at the end.") + self.assertEqual( + str(context.exception), + "Invalid subject. Subject must not contain spaces, and can only have '>' at the end.", + ) @async_test async def test_service_basics(self): @@ -70,7 +79,13 @@ async def add_handler(request: Request): svcs.append(svc) for _ in range(50): - await nc.request("svc.add", json.dumps({"x": 22, "y": 11}).encode("utf-8")) + await nc.request( + "svc.add", + json.dumps({ + "x": 22, + "y": 11 + }).encode("utf-8") + ) for svc in svcs: info = svc.info() @@ -92,7 +107,9 @@ async def add_handler(request: Request): ping_responses = [] while True: try: - ping_responses.append(await ping_subscription.next_msg(timeout=0.25)) + ping_responses.append( + await ping_subscription.next_msg(timeout=0.25) + ) except: break @@ -107,7 +124,9 @@ async def add_handler(request: Request): stats_responses = [] while True: try: - stats_responses.append(await stats_subscription.next_msg(timeout=0.25)) + stats_responses.append( + await stats_subscription.next_msg(timeout=0.25) + ) except: break @@ -116,34 +135,40 @@ async def add_handler(request: Request): ServiceStats.from_dict(json.loads(response.data.decode())) for response in stats_responses ] - total_requests = sum([stat.endpoints[0].num_requests for stat in stats]) + total_requests = sum([ + stat.endpoints[0].num_requests for stat in stats + ]) assert total_requests == 50 @async_test async def test_add_service(self): + async def noop_handler(request: Request): pass sub_tests = { "no_endpoint": { - "service_config": ServiceConfig( - name="test_service", - version="0.1.0", - metadata={"basic": "metadata"}, - ), - "expected_ping": ServicePing( - id="*", - type="io.nats.micro.v1.ping_response", - name="test_service", - version="0.1.0", - metadata={"basic": "metadata"}, - ), + "service_config": + ServiceConfig( + name="test_service", + version="0.1.0", + metadata={"basic": "metadata"}, + ), + "expected_ping": + ServicePing( + id="*", + type="io.nats.micro.v1.ping_response", + name="test_service", + version="0.1.0", + metadata={"basic": "metadata"}, + ), }, "with_single_endpoint": { - "service_config": ServiceConfig( - name="test_service", - version="0.1.0", - ), + "service_config": + ServiceConfig( + name="test_service", + version="0.1.0", + ), "endpoint_configs": [ EndpointConfig( name="test", @@ -152,18 +177,20 @@ async def noop_handler(request: Request): metadata={"basic": "endpoint_metadata"}, ), ], - "expected_ping": ServicePing( - id="*", - name="test_service", - version="0.1.0", - metadata={}, - ), + "expected_ping": + ServicePing( + id="*", + name="test_service", + version="0.1.0", + metadata={}, + ), }, "with_multiple_endpoints": { - "service_config": ServiceConfig( - name="test_service", - version="0.1.0", - ), + "service_config": + ServiceConfig( + name="test_service", + version="0.1.0", + ), "endpoint_configs": [ EndpointConfig( name="foo", @@ -178,12 +205,13 @@ async def noop_handler(request: Request): handler=noop_handler, ), ], - "expected_ping": ServicePing( - id="*", - name="test_service", - version="0.1.0", - metadata={}, - ), + "expected_ping": + ServicePing( + id="*", + name="test_service", + version="0.1.0", + metadata={}, + ), }, } @@ -217,37 +245,51 @@ async def noop_handler(request: Request): await svc.stop() assert svc.stopped - @async_test async def test_groups(self): sub_tests = { "no_groups": { "name": "no groups", "endpoint_name": "foo", - "expected_endpoint": {"name": "foo", "subject": "foo"}, + "expected_endpoint": { + "name": "foo", + "subject": "foo" + }, }, "single_group": { "name": "single group", "endpoint_name": "foo", "group_names": ["g1"], - "expected_endpoint": {"name": "foo", "subject": "g1.foo"}, + "expected_endpoint": { + "name": "foo", + "subject": "g1.foo" + }, }, "single_empty_group": { "name": "single empty group", "endpoint_name": "foo", "group_names": [""], - "expected_endpoint": {"name": "foo", "subject": "foo"}, + "expected_endpoint": { + "name": "foo", + "subject": "foo" + }, }, "empty_groups": { "name": "empty groups", "endpoint_name": "foo", "group_names": ["", "g1", ""], - "expected_endpoint": {"name": "foo", "subject": "g1.foo"}, + "expected_endpoint": { + "name": "foo", + "subject": "g1.foo" + }, }, "multiple_groups": { "endpoint_name": "foo", "group_names": ["g1", "g2", "g3"], - "expected_endpoint": {"name": "foo", "subject": "g1.g2.g3.foo"}, + "expected_endpoint": { + "name": "foo", + "subject": "g1.g2.g3.foo" + }, }, } @@ -267,7 +309,9 @@ async def noop_handler(_): for group_name in data.get("group_names", []): group = group.add_group(name=group_name) - await group.add_endpoint(name=data["endpoint_name"], handler=noop_handler) + await group.add_endpoint( + name=data["endpoint_name"], handler=noop_handler + ) info = svc.info() assert info.endpoints @@ -282,6 +326,7 @@ async def noop_handler(_): @async_test async def test_monitoring_handlers(self): + async def noop_handler(request: Request): pass @@ -342,14 +387,14 @@ async def noop_handler(request: Request): "description": None, "version": "0.1.0", "id": svc.id, - "endpoints": [ - { - "name": "default", - "subject": "test.func", - "queue_group": "q", - "metadata": {"basic": "schema"}, - } - ], + "endpoints": [{ + "name": "default", + "subject": "test.func", + "queue_group": "q", + "metadata": { + "basic": "schema" + }, + }], "metadata": {}, }, }, @@ -361,14 +406,14 @@ async def noop_handler(request: Request): "description": None, "version": "0.1.0", "id": svc.id, - "endpoints": [ - { - "name": "default", - "subject": "test.func", - "queue_group": "q", - "metadata": {"basic": "schema"}, - } - ], + "endpoints": [{ + "name": "default", + "subject": "test.func", + "queue_group": "q", + "metadata": { + "basic": "schema" + }, + }], "metadata": {}, }, }, @@ -380,14 +425,14 @@ async def noop_handler(request: Request): "description": None, "version": "0.1.0", "id": svc.id, - "endpoints": [ - { - "name": "default", - "subject": "test.func", - "queue_group": "q", - "metadata": {"basic": "schema"}, - } - ], + "endpoints": [{ + "name": "default", + "subject": "test.func", + "queue_group": "q", + "metadata": { + "basic": "schema" + }, + }], "metadata": {}, }, }, @@ -404,15 +449,17 @@ async def noop_handler(request: Request): @async_test async def test_service_stats(self): + async def handler(request: Request): await request.respond(b"ok") sub_tests = { "stats_handler": { - "service_config": ServiceConfig( - name="test_service", - version="0.1.0", - ), + "service_config": + ServiceConfig( + name="test_service", + version="0.1.0", + ), "endpoint_configs": [ EndpointConfig( name="default", @@ -423,11 +470,12 @@ async def handler(request: Request): ], }, "with_stats_handler": { - "service_config": ServiceConfig( - name="test_service", - version="0.1.0", - stats_handler=lambda endpoint: {"key": "val"}, - ), + "service_config": + ServiceConfig( + name="test_service", + version="0.1.0", + stats_handler=lambda endpoint: {"key": "val"}, + ), "endpoint_configs": [ EndpointConfig( name="default", @@ -436,13 +484,16 @@ async def handler(request: Request): metadata={"test": "value"}, ) ], - "expected_stats": {"key": "val"}, + "expected_stats": { + "key": "val" + }, }, "with_endpoint": { - "service_config": ServiceConfig( - name="test_service", - version="0.1.0", - ), + "service_config": + ServiceConfig( + name="test_service", + version="0.1.0", + ), "endpoint_configs": [ EndpointConfig( name="default", @@ -472,8 +523,12 @@ async def handler(request: Request): info = svc.info() - stats_subject = control_subject(ServiceVerb.STATS, "test_service") - stats_response = await nc.request(stats_subject, b"", timeout=1) + stats_subject = control_subject( + ServiceVerb.STATS, "test_service" + ) + stats_response = await nc.request( + stats_subject, b"", timeout=1 + ) stats = ServiceStats.from_dict(json.loads(stats_response.data)) assert len(stats.endpoints) == len(info.endpoints) @@ -501,10 +556,14 @@ async def test_request_respond(self): "expected_response": b"OK", }, "byte_response_with_headers": { - "respond_headers": {"key": "value"}, + "respond_headers": { + "key": "value" + }, "respond_data": b"OK", "expected_response": b"OK", - "expected_headers": {"key": "value"}, + "expected_headers": { + "key": "value" + }, }, } @@ -527,7 +586,9 @@ async def handler(request: Request): ), ) await svc.add_endpoint( - EndpointConfig(name="default", subject="test.func", handler=handler) + EndpointConfig( + name="default", subject="test.func", handler=handler + ) ) response = await nc.request( @@ -571,15 +632,17 @@ def test_control_subject(self): @async_test async def test_custom_queue_group(self): + async def noop_handler(request: Request): pass sub_tests = { "default_queue_group": { - "service_config": ServiceConfig( - name="test_service", - version="0.0.1", - ), + "service_config": + ServiceConfig( + name="test_service", + version="0.0.1", + ), "endpoint_configs": [ EndpointConfig( name="foo", @@ -591,11 +654,12 @@ async def noop_handler(request: Request): }, }, "custom_queue_group_on_service_config": { - "service_config": ServiceConfig( - name="test_service", - version="0.0.1", - queue_group="custom", - ), + "service_config": + ServiceConfig( + name="test_service", + version="0.0.1", + queue_group="custom", + ), "endpoint_configs": [ EndpointConfig( name="foo", @@ -608,11 +672,12 @@ async def noop_handler(request: Request): }, }, "endpoint_config_overriding_queue_groups": { - "service_config": ServiceConfig( - name="test_service", - version="0.0.1", - queue_group="q-config", - ), + "service_config": + ServiceConfig( + name="test_service", + version="0.0.1", + queue_group="q-config", + ), "endpoint_configs": [ EndpointConfig( name="foo", @@ -626,11 +691,12 @@ async def noop_handler(request: Request): }, "empty_queue_group_in_option_inherit_from_parent": { "name": "empty queue group in option, inherit from parent", - "service_config": ServiceConfig( - name="test_service", - version="0.0.1", - queue_group="q-service", - ), + "service_config": + ServiceConfig( + name="test_service", + version="0.0.1", + queue_group="q-service", + ), "endpoint_configs": [ EndpointConfig( name="foo", @@ -661,11 +727,12 @@ async def noop_handler(request: Request): info = svc.info() - assert len(info.endpoints) == len(data["expected_queue_groups"]) + assert len(info.endpoints + ) == len(data["expected_queue_groups"]) for endpoint in info.endpoints: assert ( - endpoint.queue_group - == data["expected_queue_groups"][endpoint.name] + endpoint.queue_group == data["expected_queue_groups"][ + endpoint.name] ) await svc.stop() diff --git a/tests/test_parser.py b/tests/test_parser.py index d5fda2d5..6f032bb9 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -46,7 +46,7 @@ def setUp(self): @async_test async def test_parse_ping(self): ps = Parser(MockNatsClient()) - data = b'PING\r\n' + data = b"PING\r\n" await ps.parse(data) self.assertEqual(len(ps.buf), 0) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) @@ -54,7 +54,7 @@ async def test_parse_ping(self): @async_test async def test_parse_pong(self): ps = Parser(MockNatsClient()) - data = b'PONG\r\n' + data = b"PONG\r\n" await ps.parse(data) self.assertEqual(len(ps.buf), 0) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) @@ -62,7 +62,7 @@ async def test_parse_pong(self): @async_test async def test_parse_ok(self): ps = Parser() - data = b'+OK\r\n' + data = b"+OK\r\n" await ps.parse(data) self.assertEqual(len(ps.buf), 0) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) @@ -70,7 +70,7 @@ async def test_parse_ok(self): @async_test async def test_parse_msg(self): nc = MockNatsClient() - expected = b'hello world!' + expected = b"hello world!" async def payload_test(sid, subject, reply, payload): self.assertEqual(payload, expected) @@ -84,22 +84,22 @@ async def payload_test(sid, subject, reply, payload): sub = Subscription(nc, 1, **params) nc._subs[1] = sub ps = Parser(nc) - data = b'MSG hello 1 world 12\r\n' + data = b"MSG hello 1 world 12\r\n" await ps.parse(data) self.assertEqual(len(ps.buf), 0) self.assertEqual(len(ps.msg_arg.keys()), 3) - self.assertEqual(ps.msg_arg["subject"], b'hello') - self.assertEqual(ps.msg_arg["reply"], b'world') + self.assertEqual(ps.msg_arg["subject"], b"hello") + self.assertEqual(ps.msg_arg["reply"], b"world") self.assertEqual(ps.msg_arg["sid"], 1) self.assertEqual(ps.needed, 12) self.assertEqual(ps.state, AWAITING_MSG_PAYLOAD) - await ps.parse(b'hello world!') + await ps.parse(b"hello world!") self.assertEqual(len(ps.buf), 12) self.assertEqual(ps.state, AWAITING_MSG_PAYLOAD) - data = b'\r\n' + data = b"\r\n" await ps.parse(data) self.assertEqual(len(ps.buf), 0) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) @@ -107,7 +107,7 @@ async def payload_test(sid, subject, reply, payload): @async_test async def test_parse_msg_op(self): ps = Parser() - data = b'MSG hello' + data = b"MSG hello" await ps.parse(data) self.assertEqual(len(ps.buf), 9) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) @@ -115,7 +115,7 @@ async def test_parse_msg_op(self): @async_test async def test_parse_split_msg_op(self): ps = Parser() - data = b'MSG' + data = b"MSG" await ps.parse(data) self.assertEqual(len(ps.buf), 3) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) @@ -123,7 +123,7 @@ async def test_parse_split_msg_op(self): @async_test async def test_parse_split_msg_op_space(self): ps = Parser() - data = b'MSG ' + data = b"MSG " await ps.parse(data) self.assertEqual(len(ps.buf), 4) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) @@ -131,7 +131,7 @@ async def test_parse_split_msg_op_space(self): @async_test async def test_parse_split_msg_op_wrong_args(self): ps = Parser(MockNatsClient()) - data = b'MSG PONG\r\n' + data = b"MSG PONG\r\n" with self.assertRaises(ProtocolError): await ps.parse(data) @@ -160,12 +160,12 @@ async def test_parse_err(self): async def test_parse_info(self): nc = MockNatsClient() ps = Parser(nc) - server_id = 'A' * 2048 + server_id = "A" * 2048 data = f'INFO {{"server_id": "{server_id}", "max_payload": 100, "auth_required": false, "connect_urls":["127.0.0.0.1:4223"]}}\r\n' await ps.parse(data.encode()) self.assertEqual(len(ps.buf), 0) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) - self.assertEqual(len(nc._server_info['server_id']), 2048) + self.assertEqual(len(nc._server_info["server_id"]), 2048) @async_test async def test_parse_msg_long_subject_reply(self): @@ -187,14 +187,14 @@ async def payload_test(sid, subject, reply, payload): nc._subs[1] = sub ps = Parser(nc) - reply = 'A' * 2043 - data = f'PING\r\nMSG hello 1 {reply}' + reply = "A" * 2043 + data = f"PING\r\nMSG hello 1 {reply}" await ps.parse(data.encode()) - await ps.parse(b'AAAAA 0\r\n\r\nMSG hello 1 world 0') + await ps.parse(b"AAAAA 0\r\n\r\nMSG hello 1 world 0") self.assertEqual(msgs, 1) self.assertEqual(len(ps.buf), 19) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) - await ps.parse(b'\r\n\r\n') + await ps.parse(b"\r\n\r\n") self.assertEqual(msgs, 2) @async_test @@ -217,24 +217,24 @@ async def payload_test(sid, subject, reply, payload): nc._subs[1] = sub ps = Parser(nc) - reply = 'A' * 2043 * 4 + reply = "A" * 2043 * 4 # FIXME: Malformed long protocol lines will not be detected # by the client in this case so we rely on the ping/pong interval # from the client to give up instead. - data = f'PING\r\nWRONG hello 1 {reply}' + data = f"PING\r\nWRONG hello 1 {reply}" await ps.parse(data.encode()) - await ps.parse(b'AAAAA 0') + await ps.parse(b"AAAAA 0") self.assertEqual(ps.state, AWAITING_CONTROL_LINE) - await ps.parse(b'\r\n\r\n') - await ps.parse(b'\r\n\r\n') + await ps.parse(b"\r\n\r\n") + await ps.parse(b"\r\n\r\n") ps = Parser(nc) - reply = 'A' * 2043 - data = f'PING\r\nWRONG hello 1 {reply}' + reply = "A" * 2043 + data = f"PING\r\nWRONG hello 1 {reply}" with self.assertRaises(ProtocolError): await ps.parse(data.encode()) - await ps.parse(b'AAAAA 0') + await ps.parse(b"AAAAA 0") self.assertEqual(ps.state, AWAITING_CONTROL_LINE) - await ps.parse(b'\r\n\r\n') - await ps.parse(b'\r\n\r\n') + await ps.parse(b"\r\n\r\n") + await ps.parse(b"\r\n\r\n") diff --git a/tests/utils.py b/tests/utils.py index 43dd27fa..d200937a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,6 +13,7 @@ try: import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except: pass @@ -72,9 +73,13 @@ def start(self): self.bin_name = str(Path(THIS_DIR).joinpath(self.bin_name)) cmd = [ - self.bin_name, "-p", - "%d" % self.port, "-m", - "%d" % self.http_port, "-a", "127.0.0.1" + self.bin_name, + "-p", + "%d" % self.port, + "-m", + "%d" % self.http_port, + "-a", + "127.0.0.1", ] if self.user: cmd.append("--user") @@ -95,24 +100,24 @@ def start(self): cmd.append(f"-sd={self.store_dir}") if self.tls: - cmd.append('--tls') - cmd.append('--tlscert') - cmd.append(get_config_file('certs/server-cert.pem')) - cmd.append('--tlskey') - cmd.append(get_config_file('certs/server-key.pem')) - cmd.append('--tlsverify') - cmd.append('--tlscacert') - cmd.append(get_config_file('certs/ca.pem')) + cmd.append("--tls") + cmd.append("--tlscert") + cmd.append(get_config_file("certs/server-cert.pem")) + cmd.append("--tlskey") + cmd.append(get_config_file("certs/server-key.pem")) + cmd.append("--tlsverify") + cmd.append("--tlscacert") + cmd.append(get_config_file("certs/ca.pem")) if self.cluster_listen is not None: - cmd.append('--cluster_listen') + cmd.append("--cluster_listen") cmd.append(self.cluster_listen) if len(self.routes) > 0: - cmd.append('--routes') - cmd.append(','.join(self.routes)) - cmd.append('--cluster_name') - cmd.append('CLUSTER') + cmd.append("--routes") + cmd.append(",".join(self.routes)) + cmd.append("--cluster_name") + cmd.append("CLUSTER") if self.config_file is not None: cmd.append("--config") @@ -242,10 +247,10 @@ def setUp(self): purpose=ssl.Purpose.SERVER_AUTH ) # self.ssl_ctx.protocol = ssl.PROTOCOL_TLSv1_2 - self.ssl_ctx.load_verify_locations(get_config_file('certs/ca.pem')) + self.ssl_ctx.load_verify_locations(get_config_file("certs/ca.pem")) self.ssl_ctx.load_cert_chain( - certfile=get_config_file('certs/client-cert.pem'), - keyfile=get_config_file('certs/client-key.pem') + certfile=get_config_file("certs/client-cert.pem"), + keyfile=get_config_file("certs/client-key.pem"), ) def tearDown(self): @@ -261,7 +266,7 @@ def setUp(self): self.natsd = NATSD( port=4224, - config_file=get_config_file('conf/tls_handshake_first.conf') + config_file=get_config_file("conf/tls_handshake_first.conf") ) start_natsd(self.natsd) @@ -269,10 +274,10 @@ def setUp(self): purpose=ssl.Purpose.SERVER_AUTH ) # self.ssl_ctx.protocol = ssl.PROTOCOL_TLSv1_2 - self.ssl_ctx.load_verify_locations(get_config_file('certs/ca.pem')) + self.ssl_ctx.load_verify_locations(get_config_file("certs/ca.pem")) self.ssl_ctx.load_cert_chain( - certfile=get_config_file('certs/client-cert.pem'), - keyfile=get_config_file('certs/client-key.pem') + certfile=get_config_file("certs/client-cert.pem"), + keyfile=get_config_file("certs/client-key.pem"), ) def tearDown(self): @@ -302,10 +307,10 @@ def setUp(self): purpose=ssl.Purpose.SERVER_AUTH ) # self.ssl_ctx.protocol = ssl.PROTOCOL_TLSv1_2 - self.ssl_ctx.load_verify_locations(get_config_file('certs/ca.pem')) + self.ssl_ctx.load_verify_locations(get_config_file("certs/ca.pem")) self.ssl_ctx.load_cert_chain( - certfile=get_config_file('certs/client-cert.pem'), - keyfile=get_config_file('certs/client-key.pem') + certfile=get_config_file("certs/client-cert.pem"), + keyfile=get_config_file("certs/client-key.pem"), ) def tearDown(self): @@ -322,14 +327,15 @@ def setUp(self): self.loop = asyncio.new_event_loop() routes = [ - "nats://127.0.0.1:6223", "nats://127.0.0.1:6224", - "nats://127.0.0.1:6225" + "nats://127.0.0.1:6223", + "nats://127.0.0.1:6224", + "nats://127.0.0.1:6225", ] server1 = NATSD( port=4223, http_port=8223, cluster_listen="nats://127.0.0.1:6223", - routes=routes + routes=routes, ) self.server_pool.append(server1) @@ -337,7 +343,7 @@ def setUp(self): port=4224, http_port=8224, cluster_listen="nats://127.0.0.1:6224", - routes=routes + routes=routes, ) self.server_pool.append(server2) @@ -345,7 +351,7 @@ def setUp(self): port=4225, http_port=8225, cluster_listen="nats://127.0.0.1:6225", - routes=routes + routes=routes, ) self.server_pool.append(server3) @@ -377,7 +383,7 @@ def setUp(self): password="bar", http_port=8223, cluster_listen="nats://127.0.0.1:6223", - routes=routes + routes=routes, ) self.server_pool.append(server1) @@ -387,7 +393,7 @@ def setUp(self): password="bar", http_port=8224, cluster_listen="nats://127.0.0.1:6224", - routes=routes + routes=routes, ) self.server_pool.append(server2) @@ -397,7 +403,7 @@ def setUp(self): password="bar", http_port=8225, cluster_listen="nats://127.0.0.1:6225", - routes=routes + routes=routes, ) self.server_pool.append(server3) @@ -497,10 +503,10 @@ def setUp(self): self.ssl_ctx = ssl.create_default_context( purpose=ssl.Purpose.SERVER_AUTH ) - self.ssl_ctx.load_verify_locations(get_config_file('certs/ca.pem')) + self.ssl_ctx.load_verify_locations(get_config_file("certs/ca.pem")) self.ssl_ctx.load_cert_chain( - certfile=get_config_file('certs/client-cert.pem'), - keyfile=get_config_file('certs/client-key.pem') + certfile=get_config_file("certs/client-cert.pem"), + keyfile=get_config_file("certs/client-key.pem"), ) server = NATSD( @@ -527,7 +533,7 @@ def setUp(self): server = NATSD( port=4222, with_jetstream=True, - config_file=get_config_file("conf/js-limits.conf") + config_file=get_config_file("conf/js-limits.conf"), ) self.server_pool.append(server) for natsd in self.server_pool: @@ -562,7 +568,7 @@ def tearDown(self): def start_natsd(natsd: NATSD): natsd.start() - endpoint = f'127.0.0.1:{natsd.http_port}' + endpoint = f"127.0.0.1:{natsd.http_port}" retries = 0 while True: if retries > 100: @@ -570,7 +576,7 @@ def start_natsd(natsd: NATSD): try: httpclient = http.client.HTTPConnection(endpoint, timeout=5) - httpclient.request('GET', '/varz') + httpclient.request("GET", "/varz") response = httpclient.getresponse() if response.status == 200: break @@ -606,6 +612,7 @@ def wrapper(test_case, *args, **kw): return wrapper + def async_debug_test(test_case_fun, timeout=3600): @wraps(test_case_fun)