From 5317d0322e95147fbc9eb77290fe804e3e643043 Mon Sep 17 00:00:00 2001 From: Casper Beyer Date: Wed, 25 Sep 2024 13:57:31 +0200 Subject: [PATCH] Run ruff format --- benchmark/inbox_perf.py | 10 +- benchmark/latency_perf.py | 105 +-- benchmark/parser_perf.py | 38 +- benchmark/pub_perf.py | 46 +- benchmark/pub_sub_perf.py | 50 +- benchmark/sub_perf.py | 16 +- examples/advanced.py | 33 +- examples/basic.py | 21 +- examples/client.py | 18 +- examples/clustered.py | 9 +- examples/common/args.py | 25 +- examples/component.py | 25 +- examples/connect.py | 39 +- examples/context-manager.py | 13 +- examples/drain-sub.py | 30 +- examples/example.py | 20 +- examples/jetstream.py | 4 +- examples/kv.py | 21 +- examples/nats-pub/__main__.py | 20 +- examples/nats-req/__main__.py | 14 +- examples/nats-sub/__main__.py | 20 +- examples/publish.py | 16 +- examples/service.py | 30 +- examples/simple.py | 5 +- examples/subscribe.py | 25 +- examples/tls.py | 36 +- examples/wildcard.py | 17 +- nats/__init__.py | 3 +- nats/aio/client.py | 339 +++++---- nats/aio/errors.py | 20 + nats/aio/msg.py | 45 +- nats/aio/subscription.py | 34 +- nats/aio/transport.py | 32 +- nats/errors.py | 35 +- nats/js/api.py | 125 ++-- nats/js/client.py | 144 ++-- nats/js/errors.py | 23 +- nats/js/kv.py | 23 +- nats/js/manager.py | 108 ++- nats/js/object_store.py | 53 +- nats/nuid.py | 2 +- nats/protocol/command.py | 31 +- nats/protocol/parser.py | 84 +-- tests/test_client.py | 876 +++++++++++----------- tests/test_client_async_await.py | 67 +- tests/test_client_nkeys.py | 55 +- tests/test_client_v2.py | 50 +- tests/test_client_websocket.py | 132 ++-- tests/test_js.py | 1206 +++++++++++++++--------------- tests/test_micro_service.py | 18 +- tests/test_nuid.py | 9 +- tests/test_parser.py | 60 +- tests/utils.py | 150 ++-- 53 files changed, 2194 insertions(+), 2236 deletions(-) diff --git a/benchmark/inbox_perf.py b/benchmark/inbox_perf.py index 0ef756ae..a882c39b 100644 --- a/benchmark/inbox_perf.py +++ b/benchmark/inbox_perf.py @@ -2,21 +2,23 @@ from nats.nuid import NUID -INBOX_PREFIX = bytearray(b'_INBOX.') +INBOX_PREFIX = bytearray(b"_INBOX.") + def gen_inboxes_nuid(n): nuid = NUID() for i in range(0, n): inbox = INBOX_PREFIX[:] inbox.extend(nuid.next()) - inbox.extend(b'.') + inbox.extend(b".") inbox.extend(nuid.next()) -if __name__ == '__main__': + +if __name__ == "__main__": benchs = [ "gen_inboxes_nuid(100000)", "gen_inboxes_nuid(1000000)", - ] + ] for bench in benchs: print(f"=== {bench}") prof.run(bench) diff --git a/benchmark/latency_perf.py b/benchmark/latency_perf.py index 7ed3da7b..e177bc1d 100644 --- a/benchmark/latency_perf.py +++ b/benchmark/latency_perf.py @@ -10,66 +10,71 @@ HASH_MODULO = 250 try: - import uvloop - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + import uvloop + + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except: - pass + pass + def show_usage(): - message = """ + message = """ Usage: latency_perf [options] options: -n ITERATIONS Iterations to spec (default: 1000) -S SUBJECT Send subject (default: (test) """ - print(message) + print(message) + def show_usage_and_die(): - show_usage() - sys.exit(1) + show_usage() + sys.exit(1) + async def main(): - parser = argparse.ArgumentParser() - parser.add_argument('-n', '--iterations', default=DEFAULT_ITERATIONS, type=int) - parser.add_argument('-S', '--subject', default='test') - parser.add_argument('--servers', default=[], action='append') - args = parser.parse_args() - - servers = args.servers - if len(args.servers) < 1: - servers = ["nats://127.0.0.1:4222"] - - try: - nc = await nats.connect(servers) - except Exception as e: - sys.stderr.write(f"ERROR: {e}") - show_usage_and_die() - - async def handler(msg): - await nc.publish(msg.reply, b'') - await nc.subscribe(args.subject, cb=handler) - - # Start the benchmark - start = time.monotonic() - to_send = args.iterations - - print("Sending {} request/responses on [{}]".format( - args.iterations, args.subject)) - while to_send > 0: - to_send -= 1 - if to_send == 0: - break - - await nc.request(args.subject, b'') - if (to_send % HASH_MODULO) == 0: - sys.stdout.write("#") - sys.stdout.flush() - - duration = time.monotonic() - start - ms = "%.3f" % ((duration/args.iterations) * 1000) - print(f"\nTest completed : {ms} ms avg request/response latency") - await nc.close() - -if __name__ == '__main__': - asyncio.run(main()) + parser = argparse.ArgumentParser() + parser.add_argument("-n", "--iterations", default=DEFAULT_ITERATIONS, type=int) + parser.add_argument("-S", "--subject", default="test") + parser.add_argument("--servers", default=[], action="append") + args = parser.parse_args() + + servers = args.servers + if len(args.servers) < 1: + servers = ["nats://127.0.0.1:4222"] + + try: + nc = await nats.connect(servers) + except Exception as e: + sys.stderr.write(f"ERROR: {e}") + show_usage_and_die() + + async def handler(msg): + await nc.publish(msg.reply, b"") + + await nc.subscribe(args.subject, cb=handler) + + # Start the benchmark + start = time.monotonic() + to_send = args.iterations + + print("Sending {} request/responses on [{}]".format(args.iterations, args.subject)) + while to_send > 0: + to_send -= 1 + if to_send == 0: + break + + await nc.request(args.subject, b"") + if (to_send % HASH_MODULO) == 0: + sys.stdout.write("#") + sys.stdout.flush() + + duration = time.monotonic() - start + ms = "%.3f" % ((duration / args.iterations) * 1000) + print(f"\nTest completed : {ms} ms avg request/response latency") + await nc.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmark/parser_perf.py b/benchmark/parser_perf.py index 9ef0f53b..e67af94a 100644 --- a/benchmark/parser_perf.py +++ b/benchmark/parser_perf.py @@ -5,21 +5,20 @@ class DummyNatsClient: - def __init__(self): self._subs = {} self._pongs = [] self._pings_outstanding = 0 self._pongs_received = 0 - self._server_info = {"max_payload": 1048576, "auth_required": False } + self._server_info = {"max_payload": 1048576, "auth_required": False} self.stats = { - 'in_msgs': 0, - 'out_msgs': 0, - 'in_bytes': 0, - 'out_bytes': 0, - 'reconnects': 0, - 'errors_received': 0 - } + "in_msgs": 0, + "out_msgs": 0, + "in_bytes": 0, + "out_bytes": 0, + "reconnects": 0, + "errors_received": 0, + } async def _send_command(self, cmd): pass @@ -31,31 +30,34 @@ async def _process_ping(self): pass async def _process_msg(self, sid, subject, reply, data): - self.stats['in_msgs'] += 1 - self.stats['in_bytes'] += len(data) + self.stats["in_msgs"] += 1 + self.stats["in_bytes"] += len(data) async def _process_err(self, err=None): pass + def generate_msg(subject, nbytes, reply=""): msg = [] protocol_line = "MSG {subject} 1 {reply} {nbytes}\r\n".format( - subject=subject, reply=reply, nbytes=nbytes).encode() + subject=subject, reply=reply, nbytes=nbytes + ).encode() msg.append(protocol_line) - msg.append(b'A' * nbytes) - msg.append(b'r\n') - return b''.join(msg) + msg.append(b"A" * nbytes) + msg.append(b"r\n") + return b"".join(msg) + def parse_msgs(max_msgs=1, nbytes=1): - buf = b''.join([generate_msg("foo", nbytes) for i in range(0, max_msgs)]) + buf = b"".join([generate_msg("foo", nbytes) for i in range(0, max_msgs)]) print("--- buffer size: {}".format(len(buf))) loop = asyncio.get_event_loop() ps = Parser(DummyNatsClient()) loop.run_until_complete(ps.parse(buf)) print("--- stats: ", ps.nc.stats) -if __name__ == '__main__': +if __name__ == "__main__": benchs = [ "parse_msgs(max_msgs=10000, nbytes=1)", "parse_msgs(max_msgs=100000, nbytes=1)", @@ -71,7 +73,7 @@ def parse_msgs(max_msgs=1, nbytes=1): "parse_msgs(max_msgs=100000, nbytes=8192)", "parse_msgs(max_msgs=10000, nbytes=16384)", "parse_msgs(max_msgs=100000, nbytes=16384)", - ] + ] for bench in benchs: print(f"=== {bench}") diff --git a/benchmark/pub_perf.py b/benchmark/pub_perf.py index b90065c1..6e2c4f11 100644 --- a/benchmark/pub_perf.py +++ b/benchmark/pub_perf.py @@ -7,10 +7,11 @@ import nats try: - import uvloop - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + import uvloop + + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except: - pass + pass DEFAULT_FLUSH_TIMEOUT = 30 DEFAULT_NUM_MSGS = 100000 @@ -18,6 +19,7 @@ DEFAULT_BATCH_SIZE = 100 HASH_MODULO = 1000 + def show_usage(): message = """ Usage: pub_perf [options] @@ -30,24 +32,26 @@ def show_usage(): """ print(message) + def show_usage_and_die(): show_usage() sys.exit(1) + async def main(): parser = argparse.ArgumentParser() - parser.add_argument('-n', '--count', default=DEFAULT_NUM_MSGS, type=int) - parser.add_argument('-s', '--size', default=DEFAULT_MSG_SIZE, type=int) - parser.add_argument('-S', '--subject', default='test') - parser.add_argument('-b', '--batch', default=DEFAULT_BATCH_SIZE, type=int) - parser.add_argument('--servers', default=[], action='append') + parser.add_argument("-n", "--count", default=DEFAULT_NUM_MSGS, type=int) + parser.add_argument("-s", "--size", default=DEFAULT_MSG_SIZE, type=int) + parser.add_argument("-S", "--subject", default="test") + parser.add_argument("-b", "--batch", default=DEFAULT_BATCH_SIZE, type=int) + parser.add_argument("--servers", default=[], action="append") args = parser.parse_args() data = [] for i in range(0, args.size): s = "%01x" % randint(0, 15) data.append(s.encode()) - payload = b''.join(data) + payload = b"".join(data) servers = args.servers if len(args.servers) < 1: @@ -55,7 +59,7 @@ async def main(): # Make sure we're connected to a server first.. try: - nc = await nats.connect(servers, pending_size=1024*1024) + nc = await nats.connect(servers, pending_size=1024 * 1024) except Exception as e: sys.stderr.write(f"ERROR: {e}") show_usage_and_die() @@ -64,8 +68,11 @@ async def main(): start = time.time() to_send = args.count - print("Sending {} messages of size {} bytes on [{}]".format( - args.count, args.size, args.subject)) + print( + "Sending {} messages of size {} bytes on [{}]".format( + args.count, args.size, args.subject + ) + ) while to_send > 0: for i in range(0, args.batch): to_send -= 1 @@ -86,11 +93,14 @@ async def main(): print(f"Server flush timeout after {DEFAULT_FLUSH_TIMEOUT}") elapsed = time.time() - start - mbytes = "%.1f" % (((args.size * args.count)/elapsed) / (1024*1024)) - print("\nTest completed : {} msgs/sec ({}) MB/sec".format( - args.count/elapsed, - mbytes)) + mbytes = "%.1f" % (((args.size * args.count) / elapsed) / (1024 * 1024)) + print( + "\nTest completed : {} msgs/sec ({}) MB/sec".format( + args.count / elapsed, mbytes + ) + ) await nc.close() -if __name__ == '__main__': - asyncio.run(main()) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmark/pub_sub_perf.py b/benchmark/pub_sub_perf.py index 996b5a76..e7a6a6cb 100644 --- a/benchmark/pub_sub_perf.py +++ b/benchmark/pub_sub_perf.py @@ -7,10 +7,11 @@ import nats try: - import uvloop - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + import uvloop + + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except: - pass + pass DEFAULT_FLUSH_TIMEOUT = 30 DEFAULT_NUM_MSGS = 100000 @@ -18,6 +19,7 @@ DEFAULT_BATCH_SIZE = 100 HASH_MODULO = 1000 + def show_usage(): message = """ Usage: pub_sub_perf [options] @@ -30,24 +32,26 @@ def show_usage(): """ print(message) + def show_usage_and_die(): show_usage() sys.exit(1) + async def main(): parser = argparse.ArgumentParser() - parser.add_argument('-n', '--count', default=DEFAULT_NUM_MSGS, type=int) - parser.add_argument('-s', '--size', default=DEFAULT_MSG_SIZE, type=int) - parser.add_argument('-S', '--subject', default='test') - parser.add_argument('-b', '--batch', default=DEFAULT_BATCH_SIZE, type=int) - parser.add_argument('--servers', default=[], action='append') + parser.add_argument("-n", "--count", default=DEFAULT_NUM_MSGS, type=int) + parser.add_argument("-s", "--size", default=DEFAULT_MSG_SIZE, type=int) + parser.add_argument("-S", "--subject", default="test") + parser.add_argument("-b", "--batch", default=DEFAULT_BATCH_SIZE, type=int) + parser.add_argument("--servers", default=[], action="append") args = parser.parse_args() data = [] for i in range(0, args.size): s = "%01x" % randint(0, 15) data.append(s.encode()) - payload = b''.join(data) + payload = b"".join(data) servers = args.servers if len(args.servers) < 1: @@ -61,20 +65,25 @@ async def main(): show_usage_and_die() received = 0 + async def handler(msg): nonlocal received received += 1 if (received % HASH_MODULO) == 0: sys.stdout.write("*") sys.stdout.flush() + await nc.subscribe(args.subject, cb=handler) # Start the benchmark start = time.time() to_send = args.count - print("Sending {} messages of size {} bytes on [{}]".format( - args.count, args.size, args.subject)) + print( + "Sending {} messages of size {} bytes on [{}]".format( + args.count, args.size, args.subject + ) + ) while to_send > 0: for i in range(0, args.batch): to_send -= 1 @@ -97,13 +106,16 @@ async def handler(msg): print(f"Server flush timeout after {DEFAULT_FLUSH_TIMEOUT}") elapsed = time.time() - start - mbytes = "%.1f" % (((args.size * args.count)/elapsed) / (1024*1024)) - print("\nTest completed : {} msgs/sec sent ({}) MB/sec".format( - args.count/elapsed, - mbytes)) - - print("Received {} messages ({} msgs/sec)".format(received, received/elapsed)) + mbytes = "%.1f" % (((args.size * args.count) / elapsed) / (1024 * 1024)) + print( + "\nTest completed : {} msgs/sec sent ({}) MB/sec".format( + args.count / elapsed, mbytes + ) + ) + + print("Received {} messages ({} msgs/sec)".format(received, received / elapsed)) await nc.close() -if __name__ == '__main__': - asyncio.run(main()) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmark/sub_perf.py b/benchmark/sub_perf.py index 8ca47697..61e190a9 100644 --- a/benchmark/sub_perf.py +++ b/benchmark/sub_perf.py @@ -12,6 +12,7 @@ DEFAULT_BATCH_SIZE = 100 HASH_MODULO = 1000 + def show_usage(): message = """ Usage: sub_perf [options] @@ -22,15 +23,17 @@ def show_usage(): """ print(message) + def show_usage_and_die(): show_usage() sys.exit(1) + async def main(): parser = argparse.ArgumentParser() - parser.add_argument('-n', '--count', default=DEFAULT_NUM_MSGS, type=int) - parser.add_argument('-S', '--subject', default='test') - parser.add_argument('--servers', default=[], action='append') + parser.add_argument("-n", "--count", default=DEFAULT_NUM_MSGS, type=int) + parser.add_argument("-S", "--subject", default="test") + parser.add_argument("--servers", default=[], action="append") args = parser.parse_args() servers = args.servers @@ -73,10 +76,11 @@ async def handler(msg): await asyncio.sleep(0.1) elapsed = time.monotonic() - start - print("\nTest completed : {} msgs/sec sent".format(args.count/elapsed)) + print("\nTest completed : {} msgs/sec sent".format(args.count / elapsed)) - print("Received {} messages ({} msgs/sec)".format(received, received/elapsed)) + print("Received {} messages ({} msgs/sec)".format(received, received / elapsed)) await nc.close() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/advanced.py b/examples/advanced.py index 38e6ac8e..c9f81e6b 100644 --- a/examples/advanced.py +++ b/examples/advanced.py @@ -5,25 +5,26 @@ async def main(): async def disconnected_cb(): - print('Got disconnected!') + print("Got disconnected!") async def reconnected_cb(): - print(f'Got reconnected to {nc.connected_url.netloc}') + print(f"Got reconnected to {nc.connected_url.netloc}") async def error_cb(e): - print(f'There was an error: {e}') + print(f"There was an error: {e}") async def closed_cb(): - print('Connection is closed') + print("Connection is closed") try: # Setting explicit list of servers in a cluster. - nc = await nats.connect("localhost:4222", - error_cb=error_cb, - reconnected_cb=reconnected_cb, - disconnected_cb=disconnected_cb, - closed_cb=closed_cb, - ) + nc = await nats.connect( + "localhost:4222", + error_cb=error_cb, + reconnected_cb=reconnected_cb, + disconnected_cb=disconnected_cb, + closed_cb=closed_cb, + ) except NoServersError as e: print(e) return @@ -38,11 +39,14 @@ async def request_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) # Signal the server to stop sending messages after we got 10 already. - resp = await nc.request("help.please", b'help') + resp = await nc.request("help.please", b"help") print("Response:", resp) try: @@ -58,5 +62,6 @@ async def request_handler(msg): # handle any pending messages inflight that the server may have sent. await nc.drain() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/basic.py b/examples/basic.py index 5f5a6f07..dc3f856e 100644 --- a/examples/basic.py +++ b/examples/basic.py @@ -19,25 +19,27 @@ async def message_handler(msg): # Stop receiving after 2 messages. await sub.unsubscribe(limit=2) - await nc.publish("foo", b'Hello') - await nc.publish("foo", b'World') - await nc.publish("foo", b'!!!!!') + await nc.publish("foo", b"Hello") + await nc.publish("foo", b"World") + await nc.publish("foo", b"!!!!!") # Synchronous style with iterator also supported. sub = await nc.subscribe("bar") - await nc.publish("bar", b'First') - await nc.publish("bar", b'Second') + await nc.publish("bar", b"First") + await nc.publish("bar", b"Second") try: async for msg in sub.messages: - print(f"Received a message on '{msg.subject} {msg.reply}': {msg.data.decode()}") + print( + f"Received a message on '{msg.subject} {msg.reply}': {msg.data.decode()}" + ) await sub.unsubscribe() except Exception as e: pass async def help_request(msg): print(f"Received a message on '{msg.subject} {msg.reply}': {msg.data.decode()}") - await nc.publish(msg.reply, b'I can help') + await nc.publish(msg.reply, b"I can help") # Use queue named 'workers' for distributing requests # among subscribers. @@ -46,7 +48,7 @@ async def help_request(msg): # Send a request and expect a single response # and trigger timeout if not faster than 500 ms. try: - response = await nc.request("help", b'help me', timeout=0.5) + response = await nc.request("help", b"help me", timeout=0.5) print(f"Received response: {response.data.decode()}") except TimeoutError: print("Request timed out") @@ -57,5 +59,6 @@ async def help_request(msg): # Terminate connection to NATS. await nc.drain() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/client.py b/examples/client.py index ba59287d..ba53cf69 100644 --- a/examples/client.py +++ b/examples/client.py @@ -13,8 +13,9 @@ async def message_handler(self, msg): print(f"[Received on '{msg.subject}']: {msg.data.decode()}") async def request_handler(self, msg): - print("[Request on '{} {}']: {}".format(msg.subject, msg.reply, - msg.data.decode())) + print( + "[Request on '{} {}']: {}".format(msg.subject, msg.reply, msg.data.decode()) + ) await self.nc.publish(msg.reply, b"I can help!") async def start(self): @@ -30,17 +31,16 @@ async def start(self): sub = await nc.subscribe("discover", "", self.message_handler) await sub.unsubscribe(2) - await nc.publish("discover", b'hello') - await nc.publish("discover", b'world') + await nc.publish("discover", b"hello") + await nc.publish("discover", b"world") # Following 2 messages won't be received. - await nc.publish("discover", b'again') - await nc.publish("discover", b'!!!!!') + await nc.publish("discover", b"again") + await nc.publish("discover", b"!!!!!") except ConnectionClosedError: print("Connection closed prematurely") if nc.is_connected: - # Subscription using a 'workers' queue so that only a single subscriber # gets a request at a time. await nc.subscribe("help", "workers", self.request_handler) @@ -49,7 +49,7 @@ async def start(self): # Make a request expecting a single response within 500 ms, # otherwise raising a timeout error. start_time = datetime.now() - response = await nc.request("help", b'help please', 0.500) + response = await nc.request("help", b"help please", 0.500) end_time = datetime.now() print(f"[Response]: {response.data}") print("[Duration]: {}".format(end_time - start_time)) @@ -73,6 +73,6 @@ async def start(self): print("Disconnected.") -if __name__ == '__main__': +if __name__ == "__main__": c = Client(NATS()) asyncio.run(c.start()) diff --git a/examples/clustered.py b/examples/clustered.py index 8356a117..c5227702 100644 --- a/examples/clustered.py +++ b/examples/clustered.py @@ -6,7 +6,6 @@ async def run(): - # Setup pool of servers from a NATS cluster. options = { "servers": [ @@ -65,14 +64,13 @@ async def subscribe_handler(msg): for i in range(0, max_messages): try: - await nc.publish(f"help.{i}", b'A') + await nc.publish(f"help.{i}", b"A") await nc.flush(0.500) except ErrConnectionClosed as e: print("Connection closed prematurely.") break except ErrTimeout as e: - print("Timeout occurred when publishing msg i={}: {}".format( - i, e)) + print("Timeout occurred when publishing msg i={}: {}".format(i, e)) end_time = datetime.now() await nc.drain() @@ -88,7 +86,8 @@ async def subscribe_handler(msg): if err is not None: print(f"Last Error: {err}") -if __name__ == '__main__': + +if __name__ == "__main__": loop = asyncio.get_event_loop() loop.run_until_complete(run()) loop.close() diff --git a/examples/common/args.py b/examples/common/args.py index 71ef938e..530b9e7e 100644 --- a/examples/common/args.py +++ b/examples/common/args.py @@ -3,16 +3,19 @@ def get_args(description, epilog=""): parser = argparse.ArgumentParser(description=description, epilog=epilog) - parser.add_argument("-s", "--servers", nargs='*', - default=["nats://localhost:4222"], - help="List of servers to connect. A demo server " - "is available at nats://demo.nats.io:4222. " - "However, it is very likely that the demo server " - "will see traffic from clients other than yours." - "To avoid this, start your own locally and pass " - "its address. " - "This option is default set to nats://localhost:4222 " - "which mean that, by default, you should run your " - "own server locally" + parser.add_argument( + "-s", + "--servers", + nargs="*", + default=["nats://localhost:4222"], + help="List of servers to connect. A demo server " + "is available at nats://demo.nats.io:4222. " + "However, it is very likely that the demo server " + "will see traffic from clients other than yours." + "To avoid this, start your own locally and pass " + "its address. " + "This option is default set to nats://localhost:4222 " + "which mean that, by default, you should run your " + "own server locally", ) return parser.parse_known_args() diff --git a/examples/component.py b/examples/component.py index 0cff75a8..abe14055 100644 --- a/examples/component.py +++ b/examples/component.py @@ -3,6 +3,7 @@ import signal import nats + class Component: def __init__(self): self._nc = None @@ -21,9 +22,9 @@ async def close(self): asyncio.get_running_loop().stop() async def subscribe(self, *args, **kwargs): - if 'max_tasks' in kwargs: - sem = asyncio.Semaphore(kwargs['max_tasks']) - callback = kwargs['cb'] + if "max_tasks" in kwargs: + sem = asyncio.Semaphore(kwargs["max_tasks"]) + callback = kwargs["cb"] async def release(msg): try: @@ -35,16 +36,17 @@ async def cb(msg): await sem.acquire() asyncio.create_task(release(msg)) - kwargs['cb'] = cb - + kwargs["cb"] = cb + # Add some concurrency to the handler. - del kwargs['max_tasks'] + del kwargs["max_tasks"] return await self._nc.subscribe(*args, **kwargs) async def publish(self, *args): await self._nc.publish(*args) + async def main(): c = Component() await c.connect("nats://localhost:4222") @@ -53,7 +55,7 @@ async def handler(msg): # Cause some head of line blocking await asyncio.sleep(0.5) print(time.ctime(), f"Received on 'test': {msg.data.decode()}") - + async def wildcard_handler(msg): print(time.ctime(), f"Received on '>': {msg.data.decode()}") @@ -84,12 +86,15 @@ def signal_handler(): task.cancel() asyncio.create_task(c.close()) - for sig in ('SIGINT', 'SIGTERM'): - asyncio.get_running_loop().add_signal_handler(getattr(signal, sig), signal_handler) + for sig in ("SIGINT", "SIGTERM"): + asyncio.get_running_loop().add_signal_handler( + getattr(signal, sig), signal_handler + ) await c.run_forever() -if __name__ == '__main__': + +if __name__ == "__main__": try: asyncio.run(main()) except: diff --git a/examples/connect.py b/examples/connect.py index 90bd8f29..92a45d52 100644 --- a/examples/connect.py +++ b/examples/connect.py @@ -5,32 +5,35 @@ async def main(): async def disconnected_cb(): - print('Got disconnected!') + print("Got disconnected!") async def reconnected_cb(): - print('Got reconnected to NATS...') + print("Got reconnected to NATS...") async def error_cb(e): - print(f'There was an error: {e}') + print(f"There was an error: {e}") async def closed_cb(): - print('Connection is closed') + print("Connection is closed") arguments, _ = args.get_args("Run a connect example.") - nc = await nats.connect(arguments.servers, - error_cb=error_cb, - reconnected_cb=reconnected_cb, - disconnected_cb=disconnected_cb, - closed_cb=closed_cb, - ) + nc = await nats.connect( + arguments.servers, + error_cb=error_cb, + reconnected_cb=reconnected_cb, + disconnected_cb=disconnected_cb, + closed_cb=closed_cb, + ) async def handler(msg): - print(f'Received a message on {msg.subject} {msg.reply}: {msg.data}') - await msg.respond(b'OK') - sub = await nc.subscribe('help.please', cb=handler) - - resp = await nc.request('help.please', b'help') - print('Response:', resp) - -if __name__ == '__main__': + print(f"Received a message on {msg.subject} {msg.reply}: {msg.data}") + await msg.respond(b"OK") + + sub = await nc.subscribe("help.please", cb=handler) + + resp = await nc.request("help.please", b"help") + print("Response:", resp) + + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/context-manager.py b/examples/context-manager.py index 1e86b16d..570c853e 100644 --- a/examples/context-manager.py +++ b/examples/context-manager.py @@ -4,7 +4,6 @@ async def main(): - is_done = asyncio.Future() async def closed_cb(): @@ -12,15 +11,18 @@ async def closed_cb(): is_done.set_result(True) arguments, _ = args.get_args("Run a context manager example.") - async with (await nats.connect(arguments.servers, closed_cb=closed_cb)) as nc: + async with await nats.connect(arguments.servers, closed_cb=closed_cb) as nc: print(f"Connected to NATS at {nc.connected_url.netloc}...") async def subscribe_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) await nc.subscribe("discover", cb=subscribe_handler) await nc.flush() @@ -40,5 +42,6 @@ async def subscribe_handler(msg): await asyncio.wait_for(is_done, 60.0) -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/drain-sub.py b/examples/drain-sub.py index 7cd5fab4..0201f2ff 100644 --- a/examples/drain-sub.py +++ b/examples/drain-sub.py @@ -13,25 +13,31 @@ async def message_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) # Simple publisher and async subscriber via coroutine. sub = await nc.subscribe("foo", cb=message_handler) # Stop receiving after 2 messages. await sub.unsubscribe(2) - await nc.publish("foo", b'Hello') - await nc.publish("foo", b'World') - await nc.publish("foo", b'!!!!!') + await nc.publish("foo", b"Hello") + await nc.publish("foo", b"World") + await nc.publish("foo", b"!!!!!") async def help_request(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) - await nc.publish(reply, b'I can help') + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) + await nc.publish(reply, b"I can help") # Use queue named 'workers' for distributing requests # among subscribers. @@ -52,7 +58,7 @@ async def drain_sub(): # Send multiple requests and drain the subscription. requests = [] for i in range(0, 100): - request = nc.request("help", b'help me', 0.5) + request = nc.request("help", b"help me", 0.5) requests.append(request) # Wait for all the responses @@ -62,13 +68,13 @@ async def drain_sub(): print("Received {count} responses!".format(count=len(responses))) for response in responses[:5]: - print("Received response: {message}".format( - message=response.data.decode())) + print("Received response: {message}".format(message=response.data.decode())) except: pass # Terminate connection to NATS. await nc.close() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/example.py b/examples/example.py index 91020a1c..e1b02ef4 100644 --- a/examples/example.py +++ b/examples/example.py @@ -22,22 +22,22 @@ async def message_handler(msg): sub = await nc.subscribe("discover", "", message_handler) await sub.unsubscribe(2) - await nc.publish("discover", b'hello') - await nc.publish("discover", b'world') + await nc.publish("discover", b"hello") + await nc.publish("discover", b"world") # Following 2 messages won't be received. - await nc.publish("discover", b'again') - await nc.publish("discover", b'!!!!!') + await nc.publish("discover", b"again") + await nc.publish("discover", b"!!!!!") except ConnectionClosedError: print("Connection closed prematurely") async def request_handler(msg): - print("[Request on '{} {}']: {}".format(msg.subject, msg.reply, - msg.data.decode())) - await nc.publish(msg.reply, b'OK') + print( + "[Request on '{} {}']: {}".format(msg.subject, msg.reply, msg.data.decode()) + ) + await nc.publish(msg.reply, b"OK") if nc.is_connected: - # Subscription using a 'workers' queue so that only a single subscriber # gets a request at a time. await nc.subscribe("help", "workers", cb=request_handler) @@ -45,7 +45,7 @@ async def request_handler(msg): try: # Make a request expecting a single response within 500 ms, # otherwise raising a timeout error. - msg = await nc.request("help", b'help please', 0.500) + msg = await nc.request("help", b"help please", 0.500) print(f"[Response]: {msg.data}") # Make a roundtrip to the server to ensure messages @@ -67,5 +67,5 @@ async def request_handler(msg): print("Disconnected.") -if __name__ == '__main__': +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/jetstream.py b/examples/jetstream.py index ca24a78b..3ee43be8 100644 --- a/examples/jetstream.py +++ b/examples/jetstream.py @@ -43,6 +43,7 @@ async def qsub_a(msg): async def qsub_b(msg): print("QSUB B:", msg) await msg.ack() + await js.subscribe("foo", "workers", cb=qsub_a) await js.subscribe("foo", "workers", cb=qsub_b) @@ -65,5 +66,6 @@ async def qsub_b(msg): await nc.close() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/kv.py b/examples/kv.py index faae5330..e8b54e86 100644 --- a/examples/kv.py +++ b/examples/kv.py @@ -7,25 +7,26 @@ async def main(): js = nc.jetstream() # Create a KV - kv = await js.create_key_value(bucket='MY_KV') + kv = await js.create_key_value(bucket="MY_KV") # Set and retrieve a value - await kv.put('hello', b'world') - await kv.put('goodbye', b'farewell') - await kv.put('greetings', b'hi') - await kv.put('greeting', b'hey') + await kv.put("hello", b"world") + await kv.put("goodbye", b"farewell") + await kv.put("greetings", b"hi") + await kv.put("greeting", b"hey") # Retrieve and print the value of 'hello' - entry = await kv.get('hello') - print(f'KeyValue.Entry: key={entry.key}, value={entry.value}') + entry = await kv.get("hello") + print(f"KeyValue.Entry: key={entry.key}, value={entry.value}") # KeyValue.Entry: key=hello, value=world # Retrieve keys with filters - filtered_keys = await kv.keys(filters=['hello', 'greet']) - print(f'Filtered Keys: {filtered_keys}') + filtered_keys = await kv.keys(filters=["hello", "greet"]) + print(f"Filtered Keys: {filtered_keys}") # Expected Output: ['hello', 'greetings'] await nc.close() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/nats-pub/__main__.py b/examples/nats-pub/__main__.py index 42432556..e57823ef 100644 --- a/examples/nats-pub/__main__.py +++ b/examples/nats-pub/__main__.py @@ -38,11 +38,11 @@ async def run(): parser = argparse.ArgumentParser() # e.g. nats-pub -s demo.nats.io hello "world" - parser.add_argument('subject', default='hello', nargs='?') - parser.add_argument('-d', '--data', default="hello world") - parser.add_argument('-s', '--servers', default="") - parser.add_argument('--creds', default="") - parser.add_argument('--token', default="") + parser.add_argument("subject", default="hello", nargs="?") + parser.add_argument("-d", "--data", default="hello world") + parser.add_argument("-s", "--servers", default="") + parser.add_argument("--creds", default="") + parser.add_argument("--token", default="") args, unknown = parser.parse_known_args() data = args.data @@ -55,10 +55,7 @@ async def error_cb(e): async def reconnected_cb(): print("Got reconnected to NATS...") - options = { - "error_cb": error_cb, - "reconnected_cb": reconnected_cb - } + options = {"error_cb": error_cb, "reconnected_cb": reconnected_cb} if len(args.creds) > 0: options["user_credentials"] = args.creds @@ -68,7 +65,7 @@ async def reconnected_cb(): try: if len(args.servers) > 0: - options['servers'] = args.servers + options["servers"] = args.servers nc = await nats.connect(**options) except Exception as e: @@ -80,7 +77,8 @@ async def reconnected_cb(): await nc.flush() await nc.drain() -if __name__ == '__main__': + +if __name__ == "__main__": loop = asyncio.get_event_loop() try: loop.run_until_complete(run()) diff --git a/examples/nats-req/__main__.py b/examples/nats-req/__main__.py index d1d59432..dd69806e 100644 --- a/examples/nats-req/__main__.py +++ b/examples/nats-req/__main__.py @@ -38,10 +38,10 @@ async def run(): parser = argparse.ArgumentParser() # e.g. nats-req -s demo.nats.io hello world - parser.add_argument('subject', default='hello', nargs='?') - parser.add_argument('-d', '--data', default="hello world") - parser.add_argument('-s', '--servers', default="") - parser.add_argument('--creds', default="") + parser.add_argument("subject", default="hello", nargs="?") + parser.add_argument("-d", "--data", default="hello world") + parser.add_argument("-s", "--servers", default="") + parser.add_argument("--creds", default="") args, unknown = parser.parse_known_args() data = args.data @@ -62,7 +62,7 @@ async def reconnected_cb(): nc = None try: if len(args.servers) > 0: - options['servers'] = args.servers + options["servers"] = args.servers nc = await nats.connect(**options) except Exception as e: @@ -70,7 +70,7 @@ async def reconnected_cb(): show_usage_and_die() async def req_callback(msg): - await msg.respond(b'a response') + await msg.respond(b"a response") await nc.subscribe(args.subject, cb=req_callback) @@ -87,7 +87,7 @@ async def req_callback(msg): await nc.drain() -if __name__ == '__main__': +if __name__ == "__main__": loop = asyncio.get_event_loop() try: loop.run_until_complete(run()) diff --git a/examples/nats-sub/__main__.py b/examples/nats-sub/__main__.py index 4ae7136e..13510c8f 100644 --- a/examples/nats-sub/__main__.py +++ b/examples/nats-sub/__main__.py @@ -39,10 +39,10 @@ async def run(): parser = argparse.ArgumentParser() # e.g. nats-sub hello -s nats://127.0.0.1:4222 - parser.add_argument('subject', default='hello', nargs='?') - parser.add_argument('-s', '--servers', default="") - parser.add_argument('-q', '--queue', default="") - parser.add_argument('--creds', default="") + parser.add_argument("subject", default="hello", nargs="?") + parser.add_argument("-s", "--servers", default="") + parser.add_argument("-q", "--queue", default="") + parser.add_argument("--creds", default="") args = parser.parse_args() async def error_cb(e): @@ -69,7 +69,7 @@ async def subscribe_handler(msg): options = { "error_cb": error_cb, "closed_cb": closed_cb, - "reconnected_cb": reconnected_cb + "reconnected_cb": reconnected_cb, } if len(args.creds) > 0: @@ -78,7 +78,7 @@ async def subscribe_handler(msg): nc = None try: if len(args.servers) > 0: - options['servers'] = args.servers + options["servers"] = args.servers nc = await nats.connect(**options) except Exception as e: @@ -92,13 +92,15 @@ def signal_handler(): return asyncio.create_task(nc.drain()) - for sig in ('SIGINT', 'SIGTERM'): - asyncio.get_running_loop().add_signal_handler(getattr(signal, sig), signal_handler) + for sig in ("SIGINT", "SIGTERM"): + asyncio.get_running_loop().add_signal_handler( + getattr(signal, sig), signal_handler + ) await nc.subscribe(args.subject, args.queue, subscribe_handler) -if __name__ == '__main__': +if __name__ == "__main__": loop = asyncio.get_event_loop() loop.run_until_complete(run()) try: diff --git a/examples/publish.py b/examples/publish.py index b8cc3f21..ce933c57 100644 --- a/examples/publish.py +++ b/examples/publish.py @@ -4,8 +4,9 @@ async def main(): - arguments, _ = args.get_args("Run a publish example.", - "Usage: python examples/publish.py") + arguments, _ = args.get_args( + "Run a publish example.", "Usage: python examples/publish.py" + ) nc = await nats.connect(arguments.servers) # Publish as message with an inbox. @@ -13,13 +14,13 @@ async def main(): sub = await nc.subscribe("hello") # Simple publishing - await nc.publish("hello", b'Hello World!') + await nc.publish("hello", b"Hello World!") # Publish with a reply - await nc.publish("hello", b'Hello World!', reply=inbox) - + await nc.publish("hello", b"Hello World!", reply=inbox) + # Publish with a reply - await nc.publish("hello", b'With Headers', headers={'Foo':'Bar'}) + await nc.publish("hello", b"With Headers", headers={"Foo": "Bar"}) while True: try: @@ -32,5 +33,6 @@ async def main(): print("Data :", msg.data) print("Headers:", msg.header) -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/service.py b/examples/service.py index 70bfe79f..56a7c5f4 100644 --- a/examples/service.py +++ b/examples/service.py @@ -17,8 +17,10 @@ def signal_handler(): asyncio.create_task(nc.close()) asyncio.create_task(stop()) - for sig in ('SIGINT', 'SIGTERM'): - asyncio.get_running_loop().add_signal_handler(getattr(signal, sig), signal_handler) + for sig in ("SIGINT", "SIGTERM"): + asyncio.get_running_loop().add_signal_handler( + getattr(signal, sig), signal_handler + ) async def disconnected_cb(): print("Got disconnected...") @@ -26,18 +28,23 @@ async def disconnected_cb(): async def reconnected_cb(): print("Got reconnected...") - await nc.connect("127.0.0.1", - reconnected_cb=reconnected_cb, - disconnected_cb=disconnected_cb, - max_reconnect_attempts=-1) + await nc.connect( + "127.0.0.1", + reconnected_cb=reconnected_cb, + disconnected_cb=disconnected_cb, + max_reconnect_attempts=-1, + ) async def help_request(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) - await nc.publish(reply, b'I can help') + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) + await nc.publish(reply, b"I can help") # Use queue named 'workers' for distributing requests # among subscribers. @@ -47,12 +54,13 @@ async def help_request(msg): for i in range(1, 1000000): await asyncio.sleep(1) try: - response = await nc.request("help", b'hi') + response = await nc.request("help", b"hi") print(response) except Exception as e: print("Error:", e) -if __name__ == '__main__': + +if __name__ == "__main__": loop = asyncio.get_event_loop() try: loop.run_until_complete(main()) diff --git a/examples/simple.py b/examples/simple.py index 220ef1f3..ee7c71cb 100644 --- a/examples/simple.py +++ b/examples/simple.py @@ -10,7 +10,7 @@ async def main(): sub = await nc.subscribe("foo") # Publish a message to 'foo' - await nc.publish("foo", b'Hello from Python!') + await nc.publish("foo", b"Hello from Python!") # Process a message msg = await sub.next_msg() @@ -19,5 +19,6 @@ async def main(): # Close NATS connection await nc.close() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/subscribe.py b/examples/subscribe.py index 64b0fe60..b98ec162 100644 --- a/examples/subscribe.py +++ b/examples/subscribe.py @@ -13,10 +13,7 @@ async def closed_cb(): asyncio.get_running_loop().stop() arguments, _ = args.get_args("Run a subscription example.") - options = { - "servers": arguments.servers, - "closed_cb": closed_cb - } + options = {"servers": arguments.servers, "closed_cb": closed_cb} await nc.connect(**options) print(f"Connected to NATS at {nc.connected_url.netloc}...") @@ -25,9 +22,12 @@ async def subscribe_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) - await msg.respond(b'I can help!') + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) + await msg.respond(b"I can help!") # Basic subscription to receive all published messages # which are being sent to a single topic 'discover' @@ -43,14 +43,17 @@ def signal_handler(): print("Disconnecting...") asyncio.create_task(nc.close()) - for sig in ('SIGINT', 'SIGTERM'): - asyncio.get_running_loop().add_signal_handler(getattr(signal, sig), signal_handler) + for sig in ("SIGINT", "SIGTERM"): + asyncio.get_running_loop().add_signal_handler( + getattr(signal, sig), signal_handler + ) - await nc.request("help", b'help') + await nc.request("help", b"help") await asyncio.sleep(30) -if __name__ == '__main__': + +if __name__ == "__main__": try: asyncio.run(main()) except: diff --git a/examples/tls.py b/examples/tls.py index 26762988..669e9361 100644 --- a/examples/tls.py +++ b/examples/tls.py @@ -10,35 +10,42 @@ async def main(): # Your local server needs to configured with the server certificate in the ../tests/certs directory. ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) ssl_ctx.minimum_version = ssl.PROTOCOL_TLSv1_2 - ssl_ctx.load_verify_locations('../tests/certs/ca.pem') + ssl_ctx.load_verify_locations("../tests/certs/ca.pem") ssl_ctx.load_cert_chain( - certfile='../tests/certs/client-cert.pem', - keyfile='../tests/certs/client-key.pem') + certfile="../tests/certs/client-cert.pem", + keyfile="../tests/certs/client-key.pem", + ) await nc.connect(servers=["nats://127.0.0.1:4222"], tls=ssl_ctx) async def message_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) # Simple publisher and async subscriber via coroutine. sid = await nc.subscribe("foo", cb=message_handler) # Stop receiving after 2 messages. await nc.auto_unsubscribe(sid, 2) - await nc.publish("foo", b'Hello') - await nc.publish("foo", b'World') - await nc.publish("foo", b'!!!!!') + await nc.publish("foo", b"Hello") + await nc.publish("foo", b"World") + await nc.publish("foo", b"!!!!!") async def help_request(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) - await nc.publish(reply, b'I can help') + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) + await nc.publish(reply, b"I can help") # Use queue named 'workers' for distributing requests # among subscribers. @@ -47,9 +54,8 @@ async def help_request(msg): # Send a request and expect a single response # and trigger timeout if not faster than 50 ms. try: - response = await nc.timed_request("help", b'help me', 0.050) - print("Received response: {message}".format( - message=response.data.decode())) + response = await nc.timed_request("help", b"help me", 0.050) + print("Received response: {message}".format(message=response.data.decode())) except TimeoutError: print("Request timed out") @@ -57,7 +63,7 @@ async def help_request(msg): await nc.close() -if __name__ == '__main__': +if __name__ == "__main__": loop = asyncio.get_event_loop() loop.run_until_complete(main()) loop.close() diff --git a/examples/wildcard.py b/examples/wildcard.py index e0c308a0..a82e7109 100644 --- a/examples/wildcard.py +++ b/examples/wildcard.py @@ -6,16 +6,20 @@ async def run(loop): nc = NATS() - arguments, _ = args.get_args("Run the wildcard example.", - "Usage: python examples/wildcard.py") + arguments, _ = args.get_args( + "Run the wildcard example.", "Usage: python examples/wildcard.py" + ) await nc.connect(arguments.servers) async def message_handler(msg): subject = msg.subject reply = msg.reply data = msg.data.decode() - print("Received a message on '{subject} {reply}': {data}".format( - subject=subject, reply=reply, data=data)) + print( + "Received a message on '{subject} {reply}': {data}".format( + subject=subject, reply=reply, data=data + ) + ) # "*" matches any token, at any level of the subject. await nc.subscribe("foo.*.baz", cb=message_handler) @@ -26,12 +30,13 @@ async def message_handler(msg): await nc.subscribe("foo.>", cb=message_handler) # Matches all of the above. - await nc.publish("foo.bar.baz", b'Hello World') + await nc.publish("foo.bar.baz", b"Hello World") # Gracefully close the connection. await nc.drain() -if __name__ == '__main__': + +if __name__ == "__main__": loop = asyncio.get_event_loop() loop.run_until_complete(run(loop)) loop.close() diff --git a/nats/__init__.py b/nats/__init__.py index 32d01162..cde97347 100644 --- a/nats/__init__.py +++ b/nats/__init__.py @@ -18,8 +18,7 @@ async def connect( - servers: Union[str, List[str]] = ["nats://localhost:4222"], - **options + servers: Union[str, List[str]] = ["nats://localhost:4222"], **options ) -> NATS: """ :param servers: List of servers to connect. diff --git a/nats/aio/client.py b/nats/aio/client.py index 76fa0480..15d83185 100644 --- a/nats/aio/client.py +++ b/nats/aio/client.py @@ -69,26 +69,26 @@ ) from .transport import TcpTransport, Transport, WebSocketTransport -__version__ = '2.9.0' -__lang__ = 'python3' +__version__ = "2.9.0" +__lang__ = "python3" _logger = logging.getLogger(__name__) PROTOCOL = 1 -INFO_OP = b'INFO' -CONNECT_OP = b'CONNECT' -PING_OP = b'PING' -PONG_OP = b'PONG' -OK_OP = b'+OK' -ERR_OP = b'-ERR' -_CRLF_ = b'\r\n' +INFO_OP = b"INFO" +CONNECT_OP = b"CONNECT" +PING_OP = b"PING" +PONG_OP = b"PONG" +OK_OP = b"+OK" +ERR_OP = b"-ERR" +_CRLF_ = b"\r\n" _CRLF_LEN_ = len(_CRLF_) -_SPC_ = b' ' +_SPC_ = b" " _SPC_BYTE_ = 32 EMPTY = "" PING_PROTO = PING_OP + _CRLF_ PONG_PROTO = PONG_OP + _CRLF_ -DEFAULT_INBOX_PREFIX = b'_INBOX' +DEFAULT_INBOX_PREFIX = b"_INBOX" DEFAULT_PENDING_SIZE = 2 * 1024 * 1024 DEFAULT_BUFFER_SIZE = 32768 @@ -103,7 +103,7 @@ DEFAULT_DRAIN_TIMEOUT = 30 # in seconds MAX_CONTROL_LINE_SIZE = 1024 -NATS_HDR_LINE = bytearray(b'NATS/1.0') +NATS_HDR_LINE = bytearray(b"NATS/1.0") NATS_HDR_LINE_SIZE = len(NATS_HDR_LINE) NO_RESPONDERS_STATUS = "503" CTRL_STATUS = "100" @@ -138,7 +138,6 @@ class Srv: class ServerVersion: - def __init__(self, server_version: str) -> None: self._server_version = server_version self._major_version: Optional[int] = None @@ -148,10 +147,10 @@ def __init__(self, server_version: str) -> None: # TODO(@orsinium): use cached_property def parse_version(self) -> None: - v = (self._server_version).split('-') + v = (self._server_version).split("-") if len(v) > 1: self._dev_version = v[1] - tokens = v[0].split('.') + tokens = v[0].split(".") n = len(tokens) if n > 1: self._major_version = int(tokens[0]) @@ -182,7 +181,7 @@ def patch(self) -> int: def dev(self) -> str: if not self._dev_version: self.parse_version() - return self._dev_version or '' + return self._dev_version or "" def __repr__(self) -> str: return f"" @@ -193,7 +192,7 @@ async def _default_error_callback(ex: Exception) -> None: Provides a default way to handle async errors if the user does not provide one. """ - _logger.error('nats: encountered error', exc_info=ex) + _logger.error("nats: encountered error", exc_info=ex) class Client: @@ -287,12 +286,12 @@ def __init__(self) -> None: self.options: Dict[str, Any] = {} self.stats = { - 'in_msgs': 0, - 'out_msgs': 0, - 'in_bytes': 0, - 'out_bytes': 0, - 'reconnects': 0, - 'errors_received': 0, + "in_msgs": 0, + "out_msgs": 0, + "in_bytes": 0, + "out_bytes": 0, + "reconnects": 0, + "errors_received": 0, } async def connect( @@ -420,8 +419,13 @@ async def subscribe_handler(msg): """ - for cb in [error_cb, disconnected_cb, closed_cb, reconnected_cb, - discovered_server_cb]: + for cb in [ + error_cb, + disconnected_cb, + closed_cb, + reconnected_cb, + discovered_server_cb, + ]: if cb and not asyncio.iscoroutinefunction(cb): raise errors.InvalidCallbackTypeError @@ -461,12 +465,12 @@ async def subscribe_handler(msg): self.options["token"] = token self.options["connect_timeout"] = connect_timeout self.options["drain_timeout"] = drain_timeout - self.options['tls_handshake_first'] = tls_handshake_first + self.options["tls_handshake_first"] = tls_handshake_first if tls: - self.options['tls'] = tls + self.options["tls"] = tls if tls_hostname: - self.options['tls_hostname'] = tls_hostname + self.options["tls_hostname"] = tls_hostname # Check if the username or password was set in the server URI server_auth_configured = False @@ -502,7 +506,9 @@ async def subscribe_handler(msg): try: await self._select_next_server() await self._process_connect_init() - assert self._current_server, "the current server must be set by _select_next_server" + assert ( + self._current_server + ), "the current server must be set by _select_next_server" self._current_server.reconnects = 0 break except errors.NoServersError as e: @@ -542,7 +548,7 @@ def _setup_nkeys_jwt_connect(self) -> None: def user_cb() -> bytearray: contents = None - with open(creds[0], 'rb') as f: + with open(creds[0], "rb") as f: contents = bytearray(os.fstat(f.fileno()).st_size) f.readinto(contents) # type: ignore[attr-defined] return contents @@ -551,7 +557,7 @@ def user_cb() -> bytearray: def sig_cb(nonce: str) -> bytes: seed = None - with open(creds[1], 'rb') as f: + with open(creds[1], "rb") as f: seed = bytearray(os.fstat(f.fileno()).st_size) f.readinto(seed) # type: ignore[attr-defined] kp = nkeys.from_seed(seed) @@ -587,7 +593,6 @@ def sig_cb(nonce: str) -> bytes: self._signature_cb = sig_cb def _read_creds_user_nkey(self, creds: str | UserString) -> bytearray: - def get_user_seed(f): for line in f: # Detect line where the NKEY would start and end, @@ -616,7 +621,6 @@ def get_user_seed(f): return get_user_seed(f) def _read_creds_user_jwt(self, creds: str | RawCredentials): - def get_user_jwt(f): user_jwt = None while True: @@ -625,7 +629,7 @@ def get_user_jwt(f): user_jwt = bytearray(f.readline()) break # Remove trailing line break but reusing same memory view. - return user_jwt[:len(user_jwt) - 1] + return user_jwt[: len(user_jwt) - 1] if isinstance(creds, UserString): return get_user_jwt(BytesIO(creds.data.encode())) @@ -634,7 +638,9 @@ def get_user_jwt(f): return get_user_jwt(f) def _setup_nkeys_seed_connect(self) -> None: - assert self._nkeys_seed or self._nkeys_seed_str, "Client.connect must be called first" + assert ( + self._nkeys_seed or self._nkeys_seed_str + ), "Client.connect must be called first" import os import nkeys @@ -645,7 +651,7 @@ def _get_nkeys_seed() -> nkeys.KeyPair: seed = bytearray(self._nkeys_seed_str.encode()) else: creds = self._nkeys_seed - with open(creds, 'rb') as f: + with open(creds, "rb") as f: seed = bytearray(os.fstat(f.fileno()).st_size) f.readinto(seed) # type: ignore[attr-defined] key_pair = nkeys.from_seed(seed) @@ -686,25 +692,26 @@ async def _close(self, status: int, do_cbs: bool = True) -> None: # Kick the flusher once again so that Task breaks and avoid pending futures. await self._flush_pending() - if self._reading_task is not None and not self._reading_task.cancelled( - ): + if self._reading_task is not None and not self._reading_task.cancelled(): self._reading_task.cancel() - if self._ping_interval_task is not None and not self._ping_interval_task.cancelled( + if ( + self._ping_interval_task is not None + and not self._ping_interval_task.cancelled() ): self._ping_interval_task.cancel() - if self._flusher_task is not None and not self._flusher_task.cancelled( - ): + if self._flusher_task is not None and not self._flusher_task.cancelled(): self._flusher_task.cancel() - if self._reconnection_task is not None and not self._reconnection_task.done( - ): + if self._reconnection_task is not None and not self._reconnection_task.done(): self._reconnection_task.cancel() # Wait for the reconnection task to be done which should be soon. try: - if self._reconnection_task_future is not None and not self._reconnection_task_future.cancelled( + if ( + self._reconnection_task_future is not None + and not self._reconnection_task_future.cancelled() ): await asyncio.wait_for( self._reconnection_task_future, @@ -788,9 +795,7 @@ async def drain(self) -> None: self._status = Client.DRAINING_SUBS try: - await asyncio.wait_for( - drain_is_done, self.options["drain_timeout"] - ) + await asyncio.wait_for(drain_is_done, self.options["drain_timeout"]) except asyncio.TimeoutError: drain_is_done.exception() drain_is_done.cancel() @@ -805,9 +810,9 @@ async def drain(self) -> None: async def publish( self, subject: str, - payload: bytes = b'', - reply: str = '', - headers: Optional[Dict[str, str]] = None + payload: bytes = b"", + reply: str = "", + headers: Optional[Dict[str, str]] = None, ) -> None: """ Publishes a NATS message. @@ -861,16 +866,17 @@ async def main(): payload_size = len(payload) if not self.is_connected: - if self._max_pending_size <= 0 or payload_size + self._pending_data_size > self._max_pending_size: + if ( + self._max_pending_size <= 0 + or payload_size + self._pending_data_size > self._max_pending_size + ): # Cannot publish during a reconnection when the buffering is disabled, # or if pending buffer is already full. raise errors.OutboundBufferLimitError if payload_size > self._max_payload: raise errors.MaxPayloadError - await self._send_publish( - subject, reply, payload, payload_size, headers - ) + await self._send_publish(subject, reply, payload, payload_size, headers) async def _send_publish( self, @@ -900,15 +906,15 @@ async def _send_publish( # Skip empty keys continue hdr.extend(key.encode()) - hdr.extend(b': ') + hdr.extend(b": ") value = v.strip() hdr.extend(value.encode()) hdr.extend(_CRLF_) hdr.extend(_CRLF_) pub_cmd = prot_command.hpub_cmd(subject, reply, hdr, payload) - self.stats['out_msgs'] += 1 - self.stats['out_bytes'] += payload_size + self.stats["out_msgs"] += 1 + self.stats["out_bytes"] += payload_size await self._send_command(pub_cmd) if self._flush_queue is not None and self._flush_queue.empty(): await self._flush_pending() @@ -931,10 +937,10 @@ async def subscribe( If a callback isn't provided, messages can be retrieved via an asynchronous iterator on the returned subscription object. """ - if not subject or (' ' in subject): + if not subject or (" " in subject): raise errors.BadSubjectError - if queue and (' ' in queue): + if queue and (" " in queue): raise errors.BadSubjectError if self.is_closed: @@ -979,17 +985,15 @@ async def _init_request_sub(self) -> None: self._resp_map = {} self._resp_sub_prefix = self._inbox_prefix[:] - self._resp_sub_prefix.extend(b'.') + self._resp_sub_prefix.extend(b".") self._resp_sub_prefix.extend(self._nuid.next()) - self._resp_sub_prefix.extend(b'.') + self._resp_sub_prefix.extend(b".") resp_mux_subject = self._resp_sub_prefix[:] - resp_mux_subject.extend(b'*') - await self.subscribe( - resp_mux_subject.decode(), cb=self._request_sub_callback - ) + resp_mux_subject.extend(b"*") + await self.subscribe(resp_mux_subject.decode(), cb=self._request_sub_callback) async def _request_sub_callback(self, msg: Msg) -> None: - token = msg.subject[len(self._inbox_prefix) + 22 + 2:] + token = msg.subject[len(self._inbox_prefix) + 22 + 2 :] future = self._resp_map.get(token) if not future: @@ -1001,7 +1005,7 @@ async def _request_sub_callback(self, msg: Msg) -> None: async def request( self, subject: str, - payload: bytes = b'', + payload: bytes = b"", timeout: float = 0.5, old_style: bool = False, headers: Optional[Dict[str, Any]] = None, @@ -1014,15 +1018,15 @@ async def request( """ if old_style: # FIXME: Support headers in old style requests. - return await self._request_old_style( - subject, payload, timeout=timeout - ) + return await self._request_old_style(subject, payload, timeout=timeout) else: msg = await self._request_new_style( subject, payload, timeout=timeout, headers=headers ) - if msg.headers and msg.headers.get(nats.js.api.Header.STATUS - ) == NO_RESPONDERS_STATUS: + if ( + msg.headers + and msg.headers.get(nats.js.api.Header.STATUS) == NO_RESPONDERS_STATUS + ): raise errors.NoRespondersError return msg @@ -1052,9 +1056,7 @@ async def _request_new_style( self._resp_map[token.decode()] = future # Publish the request - await self.publish( - subject, payload, reply=inbox.decode(), headers=headers - ) + await self.publish(subject, payload, reply=inbox.decode(), headers=headers) # Wait for the response or give up on timeout. try: @@ -1074,7 +1076,7 @@ def new_inbox(self) -> str: msg = sub.next_msg() """ next_inbox = self._inbox_prefix[:] - next_inbox.extend(b'.') + next_inbox.extend(b".") next_inbox.extend(self._nuid.next()) return next_inbox.decode() @@ -1097,8 +1099,7 @@ async def _request_old_style( try: msg = await asyncio.wait_for(future, timeout) if msg.headers: - if msg.headers.get(nats.js.api.Header.STATUS - ) == NO_RESPONDERS_STATUS: + if msg.headers.get(nats.js.api.Header.STATUS) == NO_RESPONDERS_STATUS: raise errors.NoRespondersError return msg except asyncio.TimeoutError: @@ -1198,8 +1199,7 @@ def is_connecting(self) -> bool: @property def is_draining(self) -> bool: return ( - self._status == Client.DRAINING_SUBS - or self._status == Client.DRAINING_PUBS + self._status == Client.DRAINING_SUBS or self._status == Client.DRAINING_PUBS ) @property @@ -1220,11 +1220,11 @@ def connected_server_version(self) -> ServerVersion: def ssl_context(self) -> ssl.SSLContext: ssl_context: Optional[ssl.SSLContext] = None if "tls" in self.options: - ssl_context = self.options.get('tls') + ssl_context = self.options.get("tls") else: ssl_context = ssl.create_default_context() if ssl_context is None: - raise errors.Error('nats: no ssl context provided') + raise errors.Error("nats: no ssl context provided") return ssl_context async def _send_command(self, cmd: bytes, priority: bool = False) -> None: @@ -1233,7 +1233,10 @@ async def _send_command(self, cmd: bytes, priority: bool = False) -> None: else: self._pending.append(cmd) self._pending_data_size += len(cmd) - if self._max_pending_size > 0 and self._pending_data_size > self._max_pending_size: + if ( + self._max_pending_size > 0 + and self._pending_data_size > self._max_pending_size + ): # Only flush force timeout on publish await self._flush_pending(force_flush=True) @@ -1298,10 +1301,14 @@ def _setup_server_pool(self, connect_url: Union[List[str]]) -> None: except ValueError: raise errors.Error("nats: invalid connect url option") # make sure protocols aren't mixed - if not (all(server.uri.scheme in ("nats", "tls") - for server in self._server_pool) - or all(server.uri.scheme in ("ws", "wss") - for server in self._server_pool)): + if not ( + all( + server.uri.scheme in ("nats", "tls") for server in self._server_pool + ) + or all( + server.uri.scheme in ("ws", "wss") for server in self._server_pool + ) + ): raise errors.Error( "nats: mixing of websocket and non websocket URLs is not allowed" ) @@ -1329,8 +1336,10 @@ async def _select_next_server(self) -> None: # Not yet exceeded max_reconnect_attempts so can still use # this server in the future. self._server_pool.append(s) - if s.last_attempt is not None and now < s.last_attempt + self.options[ - "reconnect_time_wait"]: + if ( + s.last_attempt is not None + and now < s.last_attempt + self.options["reconnect_time_wait"] + ): # Backoff connecting to server if we attempted recently. await asyncio.sleep(self.options["reconnect_time_wait"]) try: @@ -1347,13 +1356,13 @@ async def _select_next_server(self) -> None: s.uri, ssl_context=self.ssl_context, buffer_size=DEFAULT_BUFFER_SIZE, - connect_timeout=self.options['connect_timeout'] + connect_timeout=self.options["connect_timeout"], ) else: await self._transport.connect( s.uri, buffer_size=DEFAULT_BUFFER_SIZE, - connect_timeout=self.options['connect_timeout'] + connect_timeout=self.options["connect_timeout"], ) self._current_server = s break @@ -1409,7 +1418,9 @@ async def _process_op_err(self, e: Exception) -> None: self._status = Client.RECONNECTING self._ps.reset() - if self._reconnection_task is not None and not self._reconnection_task.cancelled( + if ( + self._reconnection_task is not None + and not self._reconnection_task.cancelled() ): # Cancel the previous task in case it may still be running. self._reconnection_task.cancel() @@ -1424,16 +1435,16 @@ async def _process_op_err(self, e: Exception) -> None: async def _attempt_reconnect(self) -> None: assert self._current_server, "Client.connect must be called first" - if self._reading_task is not None and not self._reading_task.cancelled( - ): + if self._reading_task is not None and not self._reading_task.cancelled(): self._reading_task.cancel() - if self._ping_interval_task is not None and not self._ping_interval_task.cancelled( + if ( + self._ping_interval_task is not None + and not self._ping_interval_task.cancelled() ): self._ping_interval_task.cancel() - if self._flusher_task is not None and not self._flusher_task.cancelled( - ): + if self._flusher_task is not None and not self._flusher_task.cancelled(): self._flusher_task.cancel() if self._transport is not None: @@ -1450,8 +1461,7 @@ async def _attempt_reconnect(self) -> None: if self.is_closed: return - if "dont_randomize" not in self.options or not self.options[ - "dont_randomize"]: + if "dont_randomize" not in self.options or not self.options["dont_randomize"]: shuffle(self._server_pool) # Create a future that the client can use to control waiting @@ -1487,9 +1497,7 @@ async def _attempt_reconnect(self) -> None: # auto unsubscribe the number of messages we have left max_msgs = sub._max_msgs - sub._received - sub_cmd = prot_command.sub_cmd( - sub._subject, sub._queue, sid - ) + sub_cmd = prot_command.sub_cmd(sub._subject, sub._queue, sid) self._transport.write(sub_cmd) if max_msgs > 0: @@ -1525,24 +1533,26 @@ async def _attempt_reconnect(self) -> None: except asyncio.CancelledError: break - if self._reconnection_task_future is not None and not self._reconnection_task_future.cancelled( + if ( + self._reconnection_task_future is not None + and not self._reconnection_task_future.cancelled() ): self._reconnection_task_future.set_result(True) def _connect_command(self) -> bytes: - ''' + """ Generates a JSON string with the params to be used when sending CONNECT to the server. ->> CONNECT {"lang": "python3"} - ''' + """ options = { "verbose": self.options["verbose"], "pedantic": self.options["pedantic"], "lang": __lang__, "version": __version__, - "protocol": PROTOCOL + "protocol": PROTOCOL, } if "headers" in self._server_info: options["headers"] = self._server_info["headers"] @@ -1560,8 +1570,10 @@ def _connect_command(self) -> bytes: options["nkey"] = self._public_nkey # In case there is no password, then consider handle # sending a token instead. - elif self.options["user"] is not None and self.options[ - "password"] is not None: + elif ( + self.options["user"] is not None + and self.options["password"] is not None + ): options["user"] = self.options["user"] options["pass"] = self.options["password"] elif self.options["token"] is not None: @@ -1579,7 +1591,7 @@ def _connect_command(self) -> bytes: options["echo"] = not self.options["no_echo"] connect_opts = json.dumps(options, sort_keys=True) - return b''.join([CONNECT_OP + _SPC_ + connect_opts.encode() + _CRLF_]) + return b"".join([CONNECT_OP + _SPC_ + connect_opts.encode() + _CRLF_]) async def _process_ping(self) -> None: """ @@ -1598,8 +1610,7 @@ async def _process_pong(self) -> None: self._pongs_received += 1 self._pings_outstanding = 0 - def _is_control_message(self, data, header: Dict[str, - str]) -> Optional[str]: + def _is_control_message(self, data, header: Dict[str, str]) -> Optional[str]: if len(data) > 0: return None status = header.get(nats.js.api.Header.STATUS) @@ -1628,9 +1639,9 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]: # if raw_headers[0] == _SPC_BYTE_: # Special handling for status messages. - line = headers[len(NATS_HDR_LINE) + 1:] + line = headers[len(NATS_HDR_LINE) + 1 :] status = line[:STATUS_MSG_LEN] - desc = line[STATUS_MSG_LEN + 1:len(line) - _CRLF_LEN_ - _CRLF_LEN_] + desc = line[STATUS_MSG_LEN + 1 : len(line) - _CRLF_LEN_ - _CRLF_LEN_] stripped_status = status.strip().decode() # Process as status only when it is a valid integer. @@ -1640,7 +1651,7 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]: # Move the raw_headers to end of line i = raw_headers.find(_CRLF_) - raw_headers = raw_headers[i + _CRLF_LEN_:] + raw_headers = raw_headers[i + _CRLF_LEN_ :] if len(desc) > 0: # Heartbeat messages can have both headers and inline status, @@ -1648,9 +1659,7 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]: i = desc.find(_CRLF_) if i > 0: hdr[nats.js.api.Header.DESCRIPTION] = desc[:i].decode() - parsed_hdr = self._hdr_parser.parsebytes( - desc[i + _CRLF_LEN_:] - ) + parsed_hdr = self._hdr_parser.parsebytes(desc[i + _CRLF_LEN_ :]) for k, v in parsed_hdr.items(): hdr[k] = v else: @@ -1665,15 +1674,14 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]: # # NATS/1.0\r\nfoo: bar\r\nhello: world # - raw_headers = headers[NATS_HDR_LINE_SIZE + _CRLF_LEN_:] + raw_headers = headers[NATS_HDR_LINE_SIZE + _CRLF_LEN_ :] try: if parse_email: parsed_hdr = parse_email(raw_headers).headers else: parsed_hdr = { k.strip(): v.strip() - for k, v in self._hdr_parser.parsebytes(raw_headers - ).items() + for k, v in self._hdr_parser.parsebytes(raw_headers).items() } if hdr: hdr.update(parsed_hdr) @@ -1706,8 +1714,8 @@ async def _process_msg( Process MSG sent by server. """ payload_size = len(data) - self.stats['in_msgs'] += 1 - self.stats['in_bytes'] += payload_size + self.stats["in_msgs"] += 1 + self.stats["in_bytes"] += payload_size sub = self._subs.get(sid) if not sub: @@ -1780,17 +1788,17 @@ async def _process_msg( try: sub._pending_size += payload_size # allow setting pending_bytes_limit to 0 to disable - if sub._pending_bytes_limit > 0 and sub._pending_size >= sub._pending_bytes_limit: + if ( + sub._pending_bytes_limit > 0 + and sub._pending_size >= sub._pending_bytes_limit + ): # Subtract the bytes since the message will be thrown away # so it would not be pending data. sub._pending_size -= payload_size await self._error_cb( errors.SlowConsumerError( - subject=msg.subject, - reply=msg.reply, - sid=sid, - sub=sub + subject=msg.subject, reply=msg.reply, sid=sid, sub=sub ) ) return @@ -1859,23 +1867,26 @@ def _process_info( with latest updates from cluster to enable server discovery. """ assert self._current_server, "Client.connect must be called first" - if 'connect_urls' in info: - if info['connect_urls']: + if "connect_urls" in info: + if info["connect_urls"]: connect_urls = [] - for connect_url in info['connect_urls']: - scheme = '' - if self._current_server.uri.scheme == 'tls': - scheme = 'tls' + for connect_url in info["connect_urls"]: + scheme = "" + if self._current_server.uri.scheme == "tls": + scheme = "tls" else: - scheme = 'nats' + scheme = "nats" uri = urlparse(f"{scheme}://{connect_url}") srv = Srv(uri) srv.discovered = True # Check whether we should reuse the original hostname. - if 'tls_required' in self._server_info and self._server_info['tls_required'] \ - and self._host_is_ip(uri.hostname): + if ( + "tls_required" in self._server_info + and self._server_info["tls_required"] + and self._host_is_ip(uri.hostname) + ): srv.tls_name = self._current_server.uri.hostname # Filter for any similar server in the server pool already. @@ -1891,7 +1902,11 @@ def _process_info( for srv in connect_urls: self._server_pool.append(srv) - if not initial_connection and connect_urls and self._discovered_server_cb: + if ( + not initial_connection + and connect_urls + and self._discovered_server_cb + ): self._discovered_server_cb() def _host_is_ip(self, connect_url: Optional[str]) -> bool: @@ -1922,13 +1937,13 @@ async def _process_connect_init(self) -> None: else: hostname = self._current_server.uri.hostname - handshake_first = self.options['tls_handshake_first'] + handshake_first = self.options["tls_handshake_first"] if handshake_first: await self._transport.connect_tls( hostname, self.ssl_context, DEFAULT_BUFFER_SIZE, - self.options['connect_timeout'], + self.options["connect_timeout"], ) connection_completed = self._transport.readline() @@ -1955,17 +1970,20 @@ async def _process_connect_init(self) -> None: self._process_info(srv_info, initial_connection=True) - if 'version' in self._server_info: - self._current_server.server_version = self._server_info['version'] + if "version" in self._server_info: + self._current_server.server_version = self._server_info["version"] - if 'max_payload' in self._server_info: + if "max_payload" in self._server_info: self._max_payload = self._server_info["max_payload"] - if 'client_id' in self._server_info: + if "client_id" in self._server_info: self._client_id = self._server_info["client_id"] - if 'tls_required' in self._server_info and self._server_info[ - 'tls_required'] and self._current_server.uri.scheme != "ws": + if ( + "tls_required" in self._server_info + and self._server_info["tls_required"] + and self._current_server.uri.scheme != "ws" + ): if not handshake_first: await self._transport.drain() # just in case something is left @@ -1974,7 +1992,7 @@ async def _process_connect_init(self) -> None: hostname, self.ssl_context, DEFAULT_BUFFER_SIZE, - self.options['connect_timeout'], + self.options["connect_timeout"], ) # Refresh state of parser upon reconnect. @@ -1987,9 +2005,7 @@ async def _process_connect_init(self) -> None: await self._transport.drain() if self.options["verbose"]: future = self._transport.readline() - next_op = await asyncio.wait_for( - future, self.options["connect_timeout"] - ) + next_op = await asyncio.wait_for(future, self.options["connect_timeout"]) if OK_OP in next_op: # Do nothing pass @@ -2000,15 +2016,13 @@ async def _process_connect_init(self) -> None: # FIXME: Maybe handling could be more special here, # checking for errors.AuthorizationError for example. # await self._process_err(err_msg) - raise errors.Error("nats: " + err_msg.rstrip('\r\n')) + raise errors.Error("nats: " + err_msg.rstrip("\r\n")) self._transport.write(PING_PROTO) await self._transport.drain() future = self._transport.readline() - next_op = await asyncio.wait_for( - future, self.options["connect_timeout"] - ) + next_op = await asyncio.wait_for(future, self.options["connect_timeout"]) if PONG_PROTO in next_op: self._status = Client.CONNECTED @@ -2019,14 +2033,12 @@ async def _process_connect_init(self) -> None: # FIXME: Maybe handling could be more special here, # checking for ErrAuthorization for example. # await self._process_err(err_msg) - raise errors.Error("nats: " + err_msg.rstrip('\r\n')) + raise errors.Error("nats: " + err_msg.rstrip("\r\n")) if PONG_PROTO in next_op: self._status = Client.CONNECTED - self._reading_task = asyncio.get_running_loop().create_task( - self._read_loop() - ) + self._reading_task = asyncio.get_running_loop().create_task(self._read_loop()) self._pongs = [] self._pings_outstanding = 0 self._ping_interval_task = asyncio.get_running_loop().create_task( @@ -2034,13 +2046,9 @@ async def _process_connect_init(self) -> None: ) # Task for kicking the flusher queue - self._flusher_task = asyncio.get_running_loop().create_task( - self._flusher() - ) + self._flusher_task = asyncio.get_running_loop().create_task(self._flusher()) - async def _send_ping( - self, future: Optional[asyncio.Future] = None - ) -> None: + async def _send_ping(self, future: Optional[asyncio.Future] = None) -> None: assert self._transport, "Client.connect must be called first" if future is None: future = asyncio.Future() @@ -2085,8 +2093,7 @@ async def _ping_interval(self) -> None: continue try: self._pings_outstanding += 1 - if self._pings_outstanding > self.options[ - "max_outstanding_pings"]: + if self._pings_outstanding > self.options["max_outstanding_pings"]: await self._process_op_err(ErrStaleConnection()) return await self._send_ping() @@ -2124,7 +2131,7 @@ async def _read_loop(self) -> None: except asyncio.CancelledError: break except Exception as ex: - _logger.error('nats: encountered error', exc_info=ex) + _logger.error("nats: encountered error", exc_info=ex) break # except asyncio.InvalidStateError: # pass diff --git a/nats/aio/errors.py b/nats/aio/errors.py index d7cc7aca..66fd68d4 100644 --- a/nats/aio/errors.py +++ b/nats/aio/errors.py @@ -24,6 +24,7 @@ class NatsError(nats.errors.Error): Please use `nats.errors.Error` instead. """ + pass @@ -34,6 +35,7 @@ class ErrConnectionClosed(nats.errors.ConnectionClosedError): Please use `nats.errors.ConnectionClosedError` instead. """ + pass @@ -44,6 +46,7 @@ class ErrDrainTimeout(nats.errors.DrainTimeoutError): Please use `nats.errors.DrainTimeoutError` instead. """ + pass @@ -54,6 +57,7 @@ class ErrInvalidUserCredentials(nats.errors.InvalidUserCredentialsError): Please use `nats.errors.InvalidUserCredentialsError` instead. """ + pass @@ -64,6 +68,7 @@ class ErrInvalidCallbackType(nats.errors.InvalidCallbackTypeError): Please use `nats.errors.InvalidCallbackTypeError` instead. """ + pass @@ -74,6 +79,7 @@ class ErrConnectionReconnecting(nats.errors.ConnectionReconnectingError): Please use `nats.errors.ConnectionReconnectingError` instead. """ + pass @@ -84,6 +90,7 @@ class ErrConnectionDraining(nats.errors.ConnectionDrainingError): Please use `nats.errors.ConnectionDrainingError` instead. """ + pass @@ -94,6 +101,7 @@ class ErrMaxPayload(nats.errors.MaxPayloadError): Please use `nats.errors.MaxPayloadError` instead. """ + pass @@ -104,6 +112,7 @@ class ErrStaleConnection(nats.errors.StaleConnectionError): Please use `nats.errors.StaleConnectionError` instead. """ + pass @@ -114,6 +123,7 @@ class ErrJsonParse(nats.errors.JsonParseError): Please use `nats.errors.JsonParseError` instead. """ + pass @@ -124,6 +134,7 @@ class ErrSecureConnRequired(nats.errors.SecureConnRequiredError): Please use `nats.errors.SecureConnRequiredError` instead. """ + pass @@ -134,6 +145,7 @@ class ErrSecureConnWanted(nats.errors.SecureConnWantedError): Please use `nats.errors.SecureConnWantedError` instead. """ + pass @@ -144,6 +156,7 @@ class ErrSecureConnFailed(nats.errors.SecureConnFailedError): Please use `nats.errors.SecureConnFailedError` instead. """ + pass @@ -154,6 +167,7 @@ class ErrBadSubscription(nats.errors.BadSubscriptionError): Please use `nats.errors.BadSubscriptionError` instead. """ + pass @@ -164,6 +178,7 @@ class ErrBadSubject(nats.errors.BadSubjectError): Please use `nats.errors.BadSubjectError` instead. """ + pass @@ -174,6 +189,7 @@ class ErrSlowConsumer(nats.errors.SlowConsumerError): Please use `nats.errors.SlowConsumerError` instead. """ + pass @@ -184,6 +200,7 @@ class ErrTimeout(nats.errors.TimeoutError): Please use `nats.errors.TimeoutError` instead. """ + pass @@ -194,6 +211,7 @@ class ErrBadTimeout(nats.errors.BadTimeoutError): Please use `nats.errors.BadTimeoutError` instead. """ + pass @@ -204,6 +222,7 @@ class ErrAuthorization(nats.errors.AuthorizationError): Please use `nats.errors.AuthorizationError` instead. """ + pass @@ -214,4 +233,5 @@ class ErrNoServers(nats.errors.NoServersError): Please use `nats.errors.NoServersError` instead. """ + pass diff --git a/nats/aio/msg.py b/nats/aio/msg.py index 8eccd991..ddded8af 100644 --- a/nats/aio/msg.py +++ b/nats/aio/msg.py @@ -40,10 +40,11 @@ class Msg: """ Msg represents a message delivered by NATS. """ + _client: NATS - subject: str = '' - reply: str = '' - data: bytes = b'' + subject: str = "" + reply: str = "" + data: bytes = b"" headers: Optional[Dict[str, str]] = None _metadata: Optional[Metadata] = None @@ -57,8 +58,8 @@ class Ack: Term = b"+TERM" # Reply metadata... - Prefix0 = '$JS' - Prefix1 = 'ACK' + Prefix0 = "$JS" + Prefix1 = "ACK" Domain = 2 AccHash = 3 Stream = 4 @@ -82,7 +83,7 @@ def sid(self) -> int: sid returns the subscription ID from a message. """ if self._sid is None: - raise Error('sid not set') + raise Error("sid not set") return self._sid async def respond(self, data: bytes) -> None: @@ -90,9 +91,9 @@ async def respond(self, data: bytes) -> None: respond replies to the inbox of the message if there is one. """ if not self.reply: - raise Error('no reply subject available') + raise Error("no reply subject available") if not self._client: - raise Error('client not set') + raise Error("client not set") await self._client.publish(self.reply, data, headers=self.headers) @@ -122,9 +123,9 @@ async def nak(self, delay: Union[int, float, None] = None) -> None: payload = Msg.Ack.Nak json_args = dict() if delay: - json_args['delay'] = int(delay * 10**9) # from seconds to ns + json_args["delay"] = int(delay * 10**9) # from seconds to ns if json_args: - payload += (b' ' + json.dumps(json_args).encode()) + payload += b" " + json.dumps(json_args).encode() await self._client.publish(self.reply, payload) self._ackd = True @@ -133,7 +134,7 @@ async def in_progress(self) -> None: in_progress acknowledges a message delivered by JetStream is still being worked on. Unlike other types of acks, an in-progress ack (+WPI) can be done multiple times. """ - if self.reply is None or self.reply == '': + if self.reply is None or self.reply == "": raise NotJSMessageError await self._client.publish(self.reply, Msg.Ack.Progress) @@ -165,7 +166,7 @@ def _get_metadata_fields(self, reply: Optional[str]) -> List[str]: return Msg.Metadata._get_metadata_fields(reply) def _check_reply(self) -> None: - if self.reply is None or self.reply == '': + if self.reply is None or self.reply == "": raise NotJSMessageError if self._ackd: raise MsgAlreadyAckdError(self) @@ -184,6 +185,7 @@ class Metadata: - consumer is the name of the consumer. """ + sequence: SequencePair num_pending: int num_delivered: int @@ -197,6 +199,7 @@ class SequencePair: """ SequencePair represents a pair of consumer and stream sequence. """ + consumer: int stream: int @@ -204,23 +207,21 @@ class SequencePair: def _get_metadata_fields(cls, reply: Optional[str]) -> List[str]: if not reply: raise NotJSMessageError - tokens = reply.split('.') - if (len(tokens) == _V1_TOKEN_COUNT or - len(tokens) >= _V2_TOKEN_COUNT-1) and \ - tokens[0] == Msg.Ack.Prefix0 and \ - tokens[1] == Msg.Ack.Prefix1: + tokens = reply.split(".") + if ( + (len(tokens) == _V1_TOKEN_COUNT or len(tokens) >= _V2_TOKEN_COUNT - 1) + and tokens[0] == Msg.Ack.Prefix0 + and tokens[1] == Msg.Ack.Prefix1 + ): return tokens raise NotJSMessageError @classmethod def _from_reply(cls, reply: str) -> Msg.Metadata: - """Construct the metadata from the reply string - """ + """Construct the metadata from the reply string""" tokens = cls._get_metadata_fields(reply) if len(tokens) == _V1_TOKEN_COUNT: - t = datetime.datetime.fromtimestamp( - int(tokens[7]) / 1_000_000_000.0 - ) + t = datetime.datetime.fromtimestamp(int(tokens[7]) / 1_000_000_000.0) return cls( sequence=Msg.Metadata.SequencePair( stream=int(tokens[5]), diff --git a/nats/aio/subscription.py b/nats/aio/subscription.py index a322c68f..4139f6fe 100644 --- a/nats/aio/subscription.py +++ b/nats/aio/subscription.py @@ -26,6 +26,7 @@ from uuid import uuid4 from nats import errors + # Default Pending Limits of Subscriptions from nats.aio.msg import Msg @@ -63,8 +64,8 @@ def __init__( self, conn, id: int = 0, - subject: str = '', - queue: str = '', + subject: str = "", + queue: str = "", cb: Optional[Callable[[Msg], Awaitable[None]]] = None, future: Optional[asyncio.Future] = None, max_msgs: int = 0, @@ -123,7 +124,7 @@ def messages(self) -> AsyncIterator[Msg]: subscription. :: - + nc = await nats.connect() sub = await nc.subscribe('foo') # Use `async for` which implicitly awaits messages @@ -176,9 +177,7 @@ async def next_msg(self, timeout: Optional[float] = 1.0) -> Msg: raise errors.ConnectionClosedError if self._cb: - raise errors.Error( - 'nats: next_msg cannot be used in async subscriptions' - ) + raise errors.Error("nats: next_msg cannot be used in async subscriptions") task_name = str(uuid4()) try: @@ -210,11 +209,10 @@ def _start(self, error_cb): Creates the resources for the subscription to start processing messages. """ if self._cb: - if not asyncio.iscoroutinefunction(self._cb) and \ - not (hasattr(self._cb, "func") and asyncio.iscoroutinefunction(self._cb.func)): - raise errors.Error( - "nats: must use coroutine for subscriptions" - ) + if not asyncio.iscoroutinefunction(self._cb) and not ( + hasattr(self._cb, "func") and asyncio.iscoroutinefunction(self._cb.func) + ): + raise errors.Error("nats: must use coroutine for subscriptions") self._wait_for_msgs_task = asyncio.get_running_loop().create_task( self._wait_for_msgs(error_cb) @@ -281,8 +279,7 @@ async def unsubscribe(self, limit: int = 0): raise errors.BadSubscriptionError self._max_msgs = limit - if limit == 0 or (self._received >= limit - and self._pending_queue.empty()): + if limit == 0 or (self._received >= limit and self._pending_queue.empty()): self._closed = True self._stop_processing() self._conn._remove_sub(self._id) @@ -328,14 +325,17 @@ async def _wait_for_msgs(self, error_cb) -> None: self._pending_queue.task_done() # Apply auto unsubscribe checks after having processed last msg. - if self._max_msgs > 0 and self._received >= self._max_msgs and self._pending_queue.empty: + if ( + self._max_msgs > 0 + and self._received >= self._max_msgs + and self._pending_queue.empty + ): self._stop_processing() except asyncio.CancelledError: break class _SubscriptionMessageIterator: - def __init__(self, sub: Subscription) -> None: self._sub: Subscription = sub self._queue: asyncio.Queue[Msg] = sub._pending_queue @@ -351,9 +351,7 @@ def __aiter__(self) -> _SubscriptionMessageIterator: async def __anext__(self) -> Msg: get_task = asyncio.get_running_loop().create_task(self._queue.get()) tasks: List[asyncio.Future] = [get_task, self._unsubscribed_future] - finished, _ = await asyncio.wait( - tasks, return_when=asyncio.FIRST_COMPLETED - ) + finished, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) sub = self._sub if get_task in finished: diff --git a/nats/aio/transport.py b/nats/aio/transport.py index 9bde8475..3531fcc2 100644 --- a/nats/aio/transport.py +++ b/nats/aio/transport.py @@ -15,11 +15,8 @@ class Transport(abc.ABC): - @abc.abstractmethod - async def connect( - self, uri: ParseResult, buffer_size: int, connect_timeout: int - ): + async def connect(self, uri: ParseResult, buffer_size: int, connect_timeout: int): """ Connects to a server using the implemented transport. The uri passed is of type ParseResult that can be obtained calling urllib.parse.urlparse. @@ -108,22 +105,20 @@ def __bool__(self): class TcpTransport(Transport): - def __init__(self): self._bare_io_reader: Optional[asyncio.StreamReader] = None self._io_reader: Optional[asyncio.StreamReader] = None self._bare_io_writer: Optional[asyncio.StreamWriter] = None self._io_writer: Optional[asyncio.StreamWriter] = None - async def connect( - self, uri: ParseResult, buffer_size: int, connect_timeout: int - ): + async def connect(self, uri: ParseResult, buffer_size: int, connect_timeout: int): r, w = await asyncio.wait_for( asyncio.open_connection( host=uri.hostname, port=uri.port, limit=buffer_size, - ), connect_timeout + ), + connect_timeout, ) # We keep a reference to the initial transport we used when # establishing the connection in case we later upgrade to TLS @@ -142,7 +137,7 @@ async def connect_tls( buffer_size: int, connect_timeout: int, ) -> None: - assert self._io_writer, f'{type(self).__name__}.connect must be called first' + assert self._io_writer, f"{type(self).__name__}.connect must be called first" # manually recreate the stream reader/writer with a tls upgraded transport reader = asyncio.StreamReader() @@ -152,7 +147,7 @@ async def connect_tls( protocol, ssl_context, # hostname here will be passed directly as string - server_hostname=uri if isinstance(uri, str) else uri.hostname + server_hostname=uri if isinstance(uri, str) else uri.hostname, ) transport = await asyncio.wait_for(transport_future, connect_timeout) writer = asyncio.StreamWriter( @@ -167,7 +162,7 @@ def writelines(self, payload): return self._io_writer.writelines(payload) async def read(self, buffer_size: int): - assert self._io_reader, f'{type(self).__name__}.connect must be called first' + assert self._io_reader, f"{type(self).__name__}.connect must be called first" return await self._io_reader.read(buffer_size) async def readline(self): @@ -190,7 +185,6 @@ def __bool__(self): class WebSocketTransport(Transport): - def __init__(self): if not aiohttp: raise ImportError( @@ -202,13 +196,9 @@ def __init__(self): self._close_task = asyncio.Future() self._using_tls: Optional[bool] = None - async def connect( - self, uri: ParseResult, buffer_size: int, connect_timeout: int - ): + async def connect(self, uri: ParseResult, buffer_size: int, connect_timeout: int): # for websocket library, the uri must contain the scheme already - self._ws = await self._client.ws_connect( - uri.geturl(), timeout=connect_timeout - ) + self._ws = await self._client.ws_connect(uri.geturl(), timeout=connect_timeout) self._using_tls = False async def connect_tls( @@ -226,7 +216,7 @@ async def connect_tls( self._ws = await self._client.ws_connect( uri if isinstance(uri, str) else uri.geturl(), ssl=ssl_context, - timeout=connect_timeout + timeout=connect_timeout, ) self._using_tls = True @@ -244,7 +234,7 @@ async def readline(self): data = await self._ws.receive() if data.type == aiohttp.WSMsgType.CLOSED: # if the connection terminated abruptly, return empty binary data to raise unexpected EOF - return b'' + return b"" return data.data async def drain(self): diff --git a/nats/errors.py b/nats/errors.py index 953a4ca2..6127653c 100644 --- a/nats/errors.py +++ b/nats/errors.py @@ -25,154 +25,130 @@ class Error(Exception): class TimeoutError(Error, asyncio.TimeoutError): - def __str__(self) -> str: return "nats: timeout" class NoRespondersError(Error): - def __str__(self) -> str: return "nats: no responders available for request" class StaleConnectionError(Error): - def __str__(self) -> str: return "nats: stale connection" class OutboundBufferLimitError(Error): - def __str__(self) -> str: return "nats: outbound buffer limit exceeded" class UnexpectedEOF(StaleConnectionError): - def __str__(self) -> str: return "nats: unexpected EOF" class FlushTimeoutError(TimeoutError): - def __str__(self) -> str: return "nats: flush timeout" class ConnectionClosedError(Error): - def __str__(self) -> str: return "nats: connection closed" class SecureConnRequiredError(Error): - def __str__(self) -> str: return "nats: secure connection required" class SecureConnWantedError(Error): - def __str__(self) -> str: return "nats: secure connection not available" class SecureConnFailedError(Error): - def __str__(self) -> str: return "nats: secure connection failed" class BadSubscriptionError(Error): - def __str__(self) -> str: return "nats: invalid subscription" class BadSubjectError(Error): - def __str__(self) -> str: return "nats: invalid subject" class SlowConsumerError(Error): - - def __init__( - self, subject: str, reply: str, sid: int, sub: Subscription - ) -> None: + def __init__(self, subject: str, reply: str, sid: int, sub: Subscription) -> None: self.subject = subject self.reply = reply self.sid = sid self.sub = sub def __str__(self) -> str: - return "nats: slow consumer, messages dropped subject: " \ - f"{self.subject}, sid: {self.sid}, sub: {self.sub}" + return ( + "nats: slow consumer, messages dropped subject: " + f"{self.subject}, sid: {self.sid}, sub: {self.sub}" + ) class BadTimeoutError(Error): - def __str__(self) -> str: return "nats: timeout invalid" class AuthorizationError(Error): - def __str__(self) -> str: return "nats: authorization failed" class NoServersError(Error): - def __str__(self) -> str: return "nats: no servers available for connection" class JsonParseError(Error): - def __str__(self) -> str: return "nats: connect message, json parse err" class MaxPayloadError(Error): - def __str__(self) -> str: return "nats: maximum payload exceeded" class DrainTimeoutError(TimeoutError): - def __str__(self) -> str: return "nats: draining connection timed out" class ConnectionDrainingError(Error): - def __str__(self) -> str: return "nats: connection draining" class ConnectionReconnectingError(Error): - def __str__(self) -> str: return "nats: connection reconnecting" class InvalidUserCredentialsError(Error): - def __str__(self) -> str: return "nats: invalid user credentials" class InvalidCallbackTypeError(Error): - def __str__(self) -> str: return "nats: callbacks must be coroutine functions" class ProtocolError(Error): - def __str__(self) -> str: return "nats: protocol error" @@ -188,7 +164,6 @@ def __str__(self) -> str: class MsgAlreadyAckdError(Error): - def __init__(self, msg=None) -> None: self._msg = msg diff --git a/nats/js/api.py b/nats/js/api.py index 3205b024..d96f25a0 100644 --- a/nats/js/api.py +++ b/nats/js/api.py @@ -36,7 +36,7 @@ class Header(str, Enum): DEFAULT_PREFIX = "$JS.API" -INBOX_PREFIX = b'_INBOX.' +INBOX_PREFIX = b"_INBOX." class StatusCode(str, Enum): @@ -58,8 +58,7 @@ class Base: @staticmethod def _convert(resp: Dict[str, Any], field: str, type: type[Base]) -> None: - """Convert the field into the given type in place. - """ + """Convert the field into the given type in place.""" data = resp.get(field, None) if data is None: resp[field] = None @@ -70,8 +69,7 @@ def _convert(resp: Dict[str, Any], field: str, type: type[Base]) -> None: @staticmethod def _convert_nanoseconds(resp: Dict[str, Any], field: str) -> None: - """Convert the given field from nanoseconds to seconds in place. - """ + """Convert the given field from nanoseconds to seconds in place.""" val = resp.get(field, None) if val is not None: val = val / _NANOSECOND @@ -79,8 +77,7 @@ def _convert_nanoseconds(resp: Dict[str, Any], field: str) -> None: @staticmethod def _to_nanoseconds(val: Optional[float]) -> Optional[int]: - """Convert the value from seconds to nanoseconds. - """ + """Convert the value from seconds to nanoseconds.""" if val is None: # We use 0 to avoid sending null to Go servers. return 0 @@ -99,13 +96,11 @@ def from_response(cls: type[_B], resp: Dict[str, Any]) -> _B: return cls(**params) def evolve(self: _B, **params) -> _B: - """Return a copy of the instance with the passed values replaced. - """ + """Return a copy of the instance with the passed values replaced.""" return replace(self, **params) def as_dict(self) -> Dict[str, object]: - """Return the object converted into an API-friendly dict. - """ + """Return the object converted into an API-friendly dict.""" result = {} for field in fields(self): val = getattr(self, field.name) @@ -125,6 +120,7 @@ class PubAck(Base): """ PubAck is the response of publishing a message to JetStream. """ + stream: str seq: int domain: Optional[str] = None @@ -161,14 +157,14 @@ class StreamSource(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'external', ExternalStream) - cls._convert(resp, 'subject_transforms', SubjectTransform) + cls._convert(resp, "external", ExternalStream) + cls._convert(resp, "subject_transforms", SubjectTransform) return super().from_response(resp) def as_dict(self) -> Dict[str, object]: result = super().as_dict() if self.subject_transforms: - result['subject_transform'] = [ + result["subject_transform"] = [ tr.as_dict() for tr in self.subject_transforms ] return result @@ -202,7 +198,7 @@ class StreamState(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'lost', LostStreamData) + cls._convert(resp, "lost", LostStreamData) return super().from_response(resp) @@ -247,6 +243,7 @@ class RePublish(Base): RePublish is for republishing messages once committed to a stream. The original subject cis remapped from the subject pattern to the destination pattern. """ + src: Optional[str] = None dest: Optional[str] = None headers_only: Optional[bool] = None @@ -255,6 +252,7 @@ class RePublish(Base): @dataclass class SubjectTransform(Base): """Subject transform to apply to matching messages.""" + src: str dest: str @@ -268,6 +266,7 @@ class StreamConfig(Base): """ StreamConfig represents the configuration of a stream. """ + name: Optional[str] = None description: Optional[str] = None subjects: Optional[List[str]] = None @@ -311,24 +310,25 @@ class StreamConfig(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert_nanoseconds(resp, 'max_age') - cls._convert_nanoseconds(resp, 'duplicate_window') - cls._convert(resp, 'placement', Placement) - cls._convert(resp, 'mirror', StreamSource) - cls._convert(resp, 'sources', StreamSource) - cls._convert(resp, 'republish', RePublish) - cls._convert(resp, 'subject_transform', SubjectTransform) + cls._convert_nanoseconds(resp, "max_age") + cls._convert_nanoseconds(resp, "duplicate_window") + cls._convert(resp, "placement", Placement) + cls._convert(resp, "mirror", StreamSource) + cls._convert(resp, "sources", StreamSource) + cls._convert(resp, "republish", RePublish) + cls._convert(resp, "subject_transform", SubjectTransform) return super().from_response(resp) def as_dict(self) -> Dict[str, object]: result = super().as_dict() - result['duplicate_window'] = self._to_nanoseconds( - self.duplicate_window - ) - result['max_age'] = self._to_nanoseconds(self.max_age) + result["duplicate_window"] = self._to_nanoseconds(self.duplicate_window) + result["max_age"] = self._to_nanoseconds(self.max_age) if self.sources: - result['sources'] = [src.as_dict() for src in self.sources] - if self.compression and (self.compression != StoreCompression.NONE and self.compression != StoreCompression.S2): + result["sources"] = [src.as_dict() for src in self.sources] + if self.compression and ( + self.compression != StoreCompression.NONE + and self.compression != StoreCompression.S2 + ): raise ValueError( "nats: invalid store compression type: %s" % self.compression ) @@ -354,7 +354,7 @@ class ClusterInfo(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'replicas', PeerInfo) + cls._convert(resp, "replicas", PeerInfo) return super().from_response(resp) @@ -363,6 +363,7 @@ class StreamInfo(Base): """ StreamInfo is the latest information about a stream from JetStream. """ + config: StreamConfig state: StreamState mirror: Optional[StreamSourceInfo] = None @@ -372,11 +373,11 @@ class StreamInfo(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'config', StreamConfig) - cls._convert(resp, 'state', StreamState) - cls._convert(resp, 'mirror', StreamSourceInfo) - cls._convert(resp, 'sources', StreamSourceInfo) - cls._convert(resp, 'cluster', ClusterInfo) + cls._convert(resp, "config", StreamConfig) + cls._convert(resp, "state", StreamState) + cls._convert(resp, "mirror", StreamSourceInfo) + cls._convert(resp, "sources", StreamSourceInfo) + cls._convert(resp, "cluster", ClusterInfo) return super().from_response(resp) @@ -385,6 +386,7 @@ class StreamsListIterator(Iterable): """ StreamsListIterator is an iterator for streams list responses from JetStream. """ + def __init__(self, offset: int, total: int, streams: List[Dict[str, any]]) -> None: self.offset = offset self.total = total @@ -456,6 +458,7 @@ class ConsumerConfig(Base): References: * `Consumers `_ """ + name: Optional[str] = None durable_name: Optional[str] = None description: Optional[str] = None @@ -497,22 +500,20 @@ class ConsumerConfig(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert_nanoseconds(resp, 'ack_wait') - cls._convert_nanoseconds(resp, 'idle_heartbeat') - cls._convert_nanoseconds(resp, 'inactive_threshold') - if 'backoff' in resp: - resp['backoff'] = [val / _NANOSECOND for val in resp['backoff']] + cls._convert_nanoseconds(resp, "ack_wait") + cls._convert_nanoseconds(resp, "idle_heartbeat") + cls._convert_nanoseconds(resp, "inactive_threshold") + if "backoff" in resp: + resp["backoff"] = [val / _NANOSECOND for val in resp["backoff"]] return super().from_response(resp) def as_dict(self) -> Dict[str, object]: result = super().as_dict() - result['ack_wait'] = self._to_nanoseconds(self.ack_wait) - result['idle_heartbeat'] = self._to_nanoseconds(self.idle_heartbeat) - result['inactive_threshold'] = self._to_nanoseconds( - self.inactive_threshold - ) + result["ack_wait"] = self._to_nanoseconds(self.ack_wait) + result["idle_heartbeat"] = self._to_nanoseconds(self.idle_heartbeat) + result["inactive_threshold"] = self._to_nanoseconds(self.inactive_threshold) if self.backoff: - result['backoff'] = [self._to_nanoseconds(i) for i in self.backoff] + result["backoff"] = [self._to_nanoseconds(i) for i in self.backoff] return result @@ -529,6 +530,7 @@ class ConsumerInfo(Base): """ ConsumerInfo represents the info about the consumer. """ + name: str stream_name: str config: ConsumerConfig @@ -545,10 +547,10 @@ class ConsumerInfo(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'delivered', SequenceInfo) - cls._convert(resp, 'ack_floor', SequenceInfo) - cls._convert(resp, 'config', ConsumerConfig) - cls._convert(resp, 'cluster', ClusterInfo) + cls._convert(resp, "delivered", SequenceInfo) + cls._convert(resp, "ack_floor", SequenceInfo) + cls._convert(resp, "config", ConsumerConfig) + cls._convert(resp, "cluster", ClusterInfo) return super().from_response(resp) @@ -580,7 +582,7 @@ class Tier(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'limits', AccountLimits) + cls._convert(resp, "limits", AccountLimits) return super().from_response(resp) @@ -599,6 +601,7 @@ class AccountInfo(Base): References: * `Account Information `_ """ + # NOTE: These fields are shared with Tier type as well. memory: int storage: int @@ -612,10 +615,10 @@ class AccountInfo(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'limits', AccountLimits) - cls._convert(resp, 'api', APIStats) + cls._convert(resp, "limits", AccountLimits) + cls._convert(resp, "api", APIStats) info = super().from_response(resp) - tiers = resp.get('tiers', None) + tiers = resp.get("tiers", None) if tiers: result = {} for k, v in tiers.items(): @@ -651,6 +654,7 @@ class KeyValueConfig(Base): """ KeyValueConfig is the configuration of a KeyValue store. """ + bucket: str description: Optional[str] = None max_value_size: Optional[int] = None @@ -665,7 +669,7 @@ class KeyValueConfig(Base): def as_dict(self) -> Dict[str, object]: result = super().as_dict() - result['ttl'] = self._to_nanoseconds(self.ttl) + result["ttl"] = self._to_nanoseconds(self.ttl) return result @@ -674,6 +678,7 @@ class StreamPurgeRequest(Base): """ StreamPurgeRequest is optional request information to the purge API. """ + # Purge up to but not including sequence. seq: Optional[int] = None # Subject to match against messages for the purge command. @@ -687,6 +692,7 @@ class ObjectStoreConfig(Base): """ ObjectStoreConfig is the configurigation of an ObjectStore. """ + bucket: Optional[str] = None description: Optional[str] = None ttl: Optional[float] = None @@ -697,7 +703,7 @@ class ObjectStoreConfig(Base): def as_dict(self) -> Dict[str, object]: result = super().as_dict() - result['ttl'] = self._to_nanoseconds(self.ttl) + result["ttl"] = self._to_nanoseconds(self.ttl) return result @@ -706,6 +712,7 @@ class ObjectLink(Base): """ ObjectLink is used to embed links to other buckets and objects. """ + # Bucket is the name of the other object store. bucket: str # Name can be used to link to a single object. @@ -724,7 +731,7 @@ class ObjectMetaOptions(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'link', ObjectLink) + cls._convert(resp, "link", ObjectLink) return super().from_response(resp) @@ -733,6 +740,7 @@ class ObjectMeta(Base): """ ObjectMeta is high level information about an object. """ + name: Optional[str] = None description: Optional[str] = None headers: Optional[dict] = None @@ -741,7 +749,7 @@ class ObjectMeta(Base): @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'options', ObjectMetaOptions) + cls._convert(resp, "options", ObjectMetaOptions) return super().from_response(resp) @@ -750,6 +758,7 @@ class ObjectInfo(Base): """ ObjectInfo is meta plus instance information. """ + name: str bucket: str nuid: str @@ -779,5 +788,5 @@ def is_link(self) -> bool: @classmethod def from_response(cls, resp: Dict[str, Any]): - cls._convert(resp, 'options', ObjectMetaOptions) + cls._convert(resp, "options", ObjectMetaOptions) return super().from_response(resp) diff --git a/nats/js/client.py b/nats/js/client.py index 0ee13654..4ccad0da 100644 --- a/nats/js/client.py +++ b/nats/js/client.py @@ -26,12 +26,21 @@ from nats.aio.msg import Msg from nats.aio.subscription import Subscription from nats.js import api -from nats.js.errors import BadBucketError, BucketNotFoundError, InvalidBucketNameError, NotFoundError, FetchTimeoutError +from nats.js.errors import ( + BadBucketError, + BucketNotFoundError, + InvalidBucketNameError, + NotFoundError, + FetchTimeoutError, +) from nats.js.kv import KeyValue from nats.js.manager import JetStreamManager from nats.js.object_store import ( - VALID_BUCKET_RE, OBJ_ALL_CHUNKS_PRE_TEMPLATE, OBJ_ALL_META_PRE_TEMPLATE, - OBJ_STREAM_TEMPLATE, ObjectStore + VALID_BUCKET_RE, + OBJ_ALL_CHUNKS_PRE_TEMPLATE, + OBJ_ALL_META_PRE_TEMPLATE, + OBJ_STREAM_TEMPLATE, + ObjectStore, ) if TYPE_CHECKING: @@ -39,13 +48,13 @@ NO_RESPONDERS_STATUS = "503" -NATS_HDR_LINE = bytearray(b'NATS/1.0') +NATS_HDR_LINE = bytearray(b"NATS/1.0") NATS_HDR_LINE_SIZE = len(NATS_HDR_LINE) -_CRLF_ = b'\r\n' +_CRLF_ = b"\r\n" _CRLF_LEN_ = len(_CRLF_) KV_STREAM_TEMPLATE = "KV_{bucket}" KV_PRE_TEMPLATE = "$KV.{bucket}." -Callback = Callable[['Msg'], Awaitable[None]] +Callback = Callable[["Msg"], Awaitable[None]] # For JetStream the default pending limits are larger. DEFAULT_JS_SUB_PENDING_MSGS_LIMIT = 512 * 1024 @@ -106,7 +115,9 @@ def __init__( self._publish_async_completed_event = asyncio.Event() self._publish_async_completed_event.set() - self._publish_async_pending_semaphore = asyncio.Semaphore(publish_async_max_pending) + self._publish_async_pending_semaphore = asyncio.Semaphore( + publish_async_max_pending + ) @property def _jsm(self) -> JetStreamManager: @@ -120,19 +131,19 @@ async def _init_async_reply(self) -> None: self._publish_async_futures = {} self._async_reply_prefix = self._nc._inbox_prefix[:] - self._async_reply_prefix.extend(b'.') + self._async_reply_prefix.extend(b".") self._async_reply_prefix.extend(self._nc._nuid.next()) - self._async_reply_prefix.extend(b'.') + self._async_reply_prefix.extend(b".") async_reply_subject = self._async_reply_prefix[:] - async_reply_subject.extend(b'*') + async_reply_subject.extend(b"*") await self._nc.subscribe( async_reply_subject.decode(), cb=self._handle_async_reply ) async def _handle_async_reply(self, msg: Msg) -> None: - token = msg.subject[len(self._nc._inbox_prefix) + 22 + 2:] + token = msg.subject[len(self._nc._inbox_prefix) + 22 + 2 :] future = self._publish_async_futures.get(token) if not future: @@ -149,8 +160,8 @@ async def _handle_async_reply(self, msg: Msg) -> None: # Handle response errors try: resp = json.loads(msg.data) - if 'error' in resp: - err = nats.js.errors.APIError.from_error(resp['error']) + if "error" in resp: + err = nats.js.errors.APIError.from_error(resp["error"]) future.set_exception(err) return @@ -162,10 +173,10 @@ async def _handle_async_reply(self, msg: Msg) -> None: async def publish( self, subject: str, - payload: bytes = b'', + payload: bytes = b"", timeout: Optional[float] = None, stream: Optional[str] = None, - headers: Optional[Dict[str, Any]] = None + headers: Optional[Dict[str, Any]] = None, ) -> api.PubAck: """ publish emits a new message to JetStream and waits for acknowledgement. @@ -188,14 +199,14 @@ async def publish( raise nats.js.errors.NoStreamResponseError resp = json.loads(msg.data) - if 'error' in resp: - raise nats.js.errors.APIError.from_error(resp['error']) + if "error" in resp: + raise nats.js.errors.APIError.from_error(resp["error"]) return api.PubAck.from_response(resp) async def publish_async( self, subject: str, - payload: bytes = b'', + payload: bytes = b"", wait_stall: Optional[float] = None, stream: Optional[str] = None, headers: Optional[Dict] = None, @@ -214,7 +225,9 @@ async def publish_async( hdr[api.Header.EXPECTED_STREAM] = stream try: - await asyncio.wait_for(self._publish_async_pending_semaphore.acquire(), timeout=wait_stall) + await asyncio.wait_for( + self._publish_async_pending_semaphore.acquire(), timeout=wait_stall + ) except (asyncio.TimeoutError, asyncio.CancelledError): raise nats.js.errors.TooManyStalledMsgsError @@ -241,9 +254,7 @@ def handle_done(future): if self._publish_async_completed_event.is_set(): self._publish_async_completed_event.clear() - await self._nc.publish( - subject, payload, reply=inbox.decode(), headers=hdr - ) + await self._nc.publish(subject, payload, reply=inbox.decode(), headers=hdr) return future @@ -453,15 +464,13 @@ async def subscribe_bind( pending_msgs_limit: int = DEFAULT_JS_SUB_PENDING_MSGS_LIMIT, pending_bytes_limit: int = DEFAULT_JS_SUB_PENDING_BYTES_LIMIT, ) -> PushSubscription: - """Push-subscribe to an existing consumer. - """ + """Push-subscribe to an existing consumer.""" # By default, async subscribers wrap the original callback and # auto ack the messages as they are delivered. # # In case ack policy is none then we also do not require to ack. # - if cb and (not manual_ack) and (config.ack_policy - is not api.AckPolicy.NONE): + if cb and (not manual_ack) and (config.ack_policy is not api.AckPolicy.NONE): cb = self._auto_ack_callback(cb) if config.deliver_subject is None: raise TypeError("config.deliver_subject is required") @@ -497,7 +506,6 @@ async def subscribe_bind( @staticmethod def _auto_ack_callback(callback: Callback) -> Callback: - async def new_callback(msg: Msg) -> None: await callback(msg) try: @@ -627,7 +635,7 @@ async def main(): sub = await self._nc.subscribe( deliver.decode(), pending_msgs_limit=pending_msgs_limit, - pending_bytes_limit=pending_bytes_limit + pending_bytes_limit=pending_bytes_limit, ) consumer_name = None # In nats.py v2.7.0 changing the first arg to be 'consumer' instead of 'durable', @@ -664,8 +672,11 @@ def _is_processable_msg(cls, status: Optional[str], msg: Msg) -> bool: @classmethod def _is_temporary_error(cls, status: Optional[str]) -> bool: - if status == api.StatusCode.NO_MESSAGES or status == api.StatusCode.CONFLICT \ - or status == api.StatusCode.REQUEST_TIMEOUT: + if ( + status == api.StatusCode.NO_MESSAGES + or status == api.StatusCode.CONFLICT + or status == api.StatusCode.REQUEST_TIMEOUT + ): return True else: return False @@ -678,14 +689,14 @@ def _is_heartbeat(cls, status: Optional[str]) -> bool: return False @classmethod - def _time_until(cls, timeout: Optional[float], - start_time: float) -> Optional[float]: + def _time_until( + cls, timeout: Optional[float], start_time: float + ) -> Optional[float]: if timeout is None: return None return timeout - (time.monotonic() - start_time) - class _JSI(): - + class _JSI: def __init__( self, js: JetStreamContext, @@ -773,8 +784,7 @@ async def check_flow_control_response(self): except asyncio.CancelledError: break - async def check_for_sequence_mismatch(self, - msg: Msg) -> Optional[bool]: + async def check_for_sequence_mismatch(self, msg: Msg) -> Optional[bool]: self._active = True if not self._cmeta: return None @@ -792,9 +802,7 @@ async def check_for_sequence_mismatch(self, sseq = int(tokens[5]) # stream sequence if self._ordered: - did_reset = await self.reset_ordered_consumer( - self._sseq + 1 - ) + did_reset = await self.reset_ordered_consumer(self._sseq + 1) else: ecs = nats.js.errors.ConsumerSequenceMismatchError( stream_resume_sequence=sseq, @@ -850,9 +858,7 @@ async def reset_ordered_consumer(self, sseq: Optional[int]) -> bool: async def recreate_consumer(self) -> None: try: cinfo = await self._js._jsm.add_consumer( - self._stream, - config=self._ccreq, - timeout=self._js._timeout + self._stream, config=self._ccreq, timeout=self._js._timeout ) self._psub._consumer = cinfo.name except Exception as err: @@ -978,7 +984,7 @@ def __init__( self._stream = stream self._consumer = consumer prefix = self._js._prefix - self._nms = f'{prefix}.CONSUMER.MSG.NEXT.{stream}.{consumer}' + self._nms = f"{prefix}.CONSUMER.MSG.NEXT.{stream}.{consumer}" self._deliver = deliver.decode() @property @@ -1018,16 +1024,14 @@ async def consumer_info(self) -> api.ConsumerInfo: """ consumer_info gets the current info of the consumer from this subscription. """ - info = await self._js._jsm.consumer_info( - self._stream, self._consumer - ) + info = await self._js._jsm.consumer_info(self._stream, self._consumer) return info async def fetch( self, batch: int = 1, timeout: Optional[float] = 5, - heartbeat: Optional[float] = None + heartbeat: Optional[float] = None, ) -> List[Msg]: """ fetch makes a request to JetStream to be delivered a set of messages. @@ -1066,9 +1070,7 @@ async def main(): if timeout is not None and timeout <= 0: raise ValueError("nats: invalid fetch timeout") - expires = int( - timeout * 1_000_000_000 - ) - 100_000 if timeout else None + expires = int(timeout * 1_000_000_000) - 100_000 if timeout else None if batch == 1: msg = await self._fetch_one(expires, timeout, heartbeat) return [msg] @@ -1079,7 +1081,7 @@ async def _fetch_one( self, expires: Optional[int], timeout: Optional[float], - heartbeat: Optional[float] = None + heartbeat: Optional[float] = None, ) -> Msg: queue = self._sub._pending_queue @@ -1100,11 +1102,11 @@ async def _fetch_one( # Make lingering request with expiration and wait for response. next_req = {} - next_req['batch'] = 1 + next_req["batch"] = 1 if expires: - next_req['expires'] = int(expires) + next_req["expires"] = int(expires) if heartbeat: - next_req['idle_heartbeat'] = int( + next_req["idle_heartbeat"] = int( heartbeat * 1_000_000_000 ) # to nanoseconds @@ -1118,9 +1120,7 @@ async def _fetch_one( got_any_response = False while True: try: - deadline = JetStreamContext._time_until( - timeout, start_time - ) + deadline = JetStreamContext._time_until(timeout, start_time) # Wait for the response or raise timeout. msg = await self._sub.next_msg(timeout=deadline) @@ -1140,9 +1140,7 @@ async def _fetch_one( else: return msg except asyncio.TimeoutError: - deadline = JetStreamContext._time_until( - timeout, start_time - ) + deadline = JetStreamContext._time_until(timeout, start_time) if deadline is not None and deadline < 0: # No response from the consumer could have been # due to a reconnect while the fetch request, @@ -1157,7 +1155,7 @@ async def _fetch_n( batch: int, expires: Optional[int], timeout: Optional[float], - heartbeat: Optional[float] = None + heartbeat: Optional[float] = None, ) -> List[Msg]: msgs = [] queue = self._sub._pending_queue @@ -1185,14 +1183,14 @@ async def _fetch_n( # First request: Use no_wait to synchronously get as many available # based on the batch size until server sends 'No Messages' status msg. next_req = {} - next_req['batch'] = needed + next_req["batch"] = needed if expires: - next_req['expires'] = expires + next_req["expires"] = expires if heartbeat: - next_req['idle_heartbeat'] = int( + next_req["idle_heartbeat"] = int( heartbeat * 1_000_000_000 ) # to nanoseconds - next_req['no_wait'] = True + next_req["no_wait"] = True await self._nc.publish( self._nms, json.dumps(next_req).encode(), @@ -1223,9 +1221,7 @@ async def _fetch_n( try: for i in range(0, needed): - deadline = JetStreamContext._time_until( - timeout, start_time - ) + deadline = JetStreamContext._time_until(timeout, start_time) msg = await self._sub.next_msg(timeout=deadline) status = JetStreamContext.is_status_msg(msg) if status == api.StatusCode.NO_MESSAGES: @@ -1251,11 +1247,11 @@ async def _fetch_n( # Second request: lingering request that will block until new messages # are made available and delivered to the client. next_req = {} - next_req['batch'] = needed + next_req["batch"] = needed if expires: - next_req['expires'] = expires + next_req["expires"] = expires if heartbeat: - next_req['idle_heartbeat'] = int( + next_req["idle_heartbeat"] = int( heartbeat * 1_000_000_000 ) # to nanoseconds @@ -1313,9 +1309,7 @@ async def _fetch_n( # Wait for the rest of the messages to be delivered to the internal pending queue. try: for _ in range(needed): - deadline = JetStreamContext._time_until( - timeout, start_time - ) + deadline = JetStreamContext._time_until(timeout, start_time) if deadline is not None and deadline < 0: return msgs @@ -1360,7 +1354,7 @@ async def key_value(self, bucket: str) -> KeyValue: stream=stream, pre=KV_PRE_TEMPLATE.format(bucket=bucket), js=self, - direct=bool(si.config.allow_direct) + direct=bool(si.config.allow_direct), ) async def create_key_value( diff --git a/nats/js/errors.py b/nats/js/errors.py index de22fc25..207fe949 100644 --- a/nats/js/errors.py +++ b/nats/js/errors.py @@ -33,7 +33,7 @@ def __init__(self, description: Optional[str] = None) -> None: self.description = description def __str__(self) -> str: - desc = '' + desc = "" if self.description: desc = self.description return f"nats: JetStream.{self.__class__.__name__} {desc}" @@ -44,6 +44,7 @@ class APIError(Error): """ An Error that is the result of interacting with NATS JetStream. """ + code: Optional[int] err_code: Optional[int] description: Optional[str] @@ -77,7 +78,7 @@ def from_msg(cls, msg: Msg) -> NoReturn: @classmethod def from_error(cls, err: Dict[str, Any]): - code = err['code'] + code = err["code"] if code == 503: raise ServiceUnavailableError(**err) elif code == 500: @@ -100,6 +101,7 @@ class ServiceUnavailableError(APIError): """ A 503 error """ + pass @@ -107,6 +109,7 @@ class ServerError(APIError): """ A 500 error """ + pass @@ -114,6 +117,7 @@ class NotFoundError(APIError): """ A 404 error """ + pass @@ -121,6 +125,7 @@ class BadRequestError(APIError): """ A 400 error. """ + pass @@ -161,7 +166,7 @@ def __init__( self, stream_resume_sequence=None, consumer_sequence=None, - last_consumer_sequence=None + last_consumer_sequence=None, ) -> None: self.stream_resume_sequence = stream_resume_sequence self.consumer_sequence = consumer_sequence @@ -179,6 +184,7 @@ class BucketNotFoundError(NotFoundError): """ When attempted to bind to a JetStream KeyValue that does not exist. """ + pass @@ -190,6 +196,7 @@ class KeyValueError(APIError): """ Raised when there is an issue interacting with the KeyValue store. """ + pass @@ -236,13 +243,11 @@ def __str__(self) -> str: class NoKeysError(KeyValueError): - def __str__(self) -> str: return "nats: no keys found" class KeyHistoryTooLargeError(KeyValueError): - def __str__(self) -> str: return "nats: history limited to a max of 64" @@ -251,6 +256,7 @@ class InvalidBucketNameError(Error): """ Raised when trying to create a KV or OBJ bucket with invalid name. """ + pass @@ -258,6 +264,7 @@ class InvalidObjectNameError(Error): """ Raised when trying to put an object in Object Store with invalid key. """ + pass @@ -265,6 +272,7 @@ class BadObjectMetaError(Error): """ Raised when trying to read corrupted metadata from Object Store. """ + pass @@ -272,6 +280,7 @@ class LinkIsABucketError(Error): """ Raised when trying to get object from Object Store that is a bucket. """ + pass @@ -279,6 +288,7 @@ class DigestMismatchError(Error): """ Raised when getting an object from Object Store that has a different digest than expected. """ + pass @@ -286,6 +296,7 @@ class ObjectNotFoundError(NotFoundError): """ When attempted to lookup an Object that does not exist. """ + pass @@ -293,6 +304,7 @@ class ObjectDeletedError(NotFoundError): """ When attempted to do an operation to an Object that does not exist. """ + pass @@ -300,4 +312,5 @@ class ObjectAlreadyExists(Error): """ When attempted to do an operation to an Object that already exist. """ + pass diff --git a/nats/js/kv.py b/nats/js/kv.py index 9ab077d1..5e6efd5a 100644 --- a/nats/js/kv.py +++ b/nats/js/kv.py @@ -72,6 +72,7 @@ class Entry: """ An entry from a KeyValue store in JetStream. """ + bucket: str key: str value: Optional[bytes] @@ -85,6 +86,7 @@ class BucketStatus: """ BucketStatus is the status of a KeyValue bucket. """ + stream_info: api.StreamInfo bucket: str @@ -218,9 +220,7 @@ async def create(self, key: str, value: bytes) -> int: return pa - async def update( - self, key: str, value: bytes, last: Optional[int] = None - ) -> int: + async def update(self, key: str, value: bytes, last: Optional[int] = None) -> int: """ update will update the value iff the latest revision matches. """ @@ -231,9 +231,7 @@ async def update( pa = None try: - pa = await self._js.publish( - f"{self._pre}{key}", value, headers=hdrs - ) + pa = await self._js.publish(f"{self._pre}{key}", value, headers=hdrs) except nats.js.errors.APIError as err: # Check for a BadRequest::KeyWrongLastSequenceError error code. if err.err_code == 10071: @@ -285,9 +283,7 @@ async def purge_deletes(self, olderthan: int = 30 * 60) -> bool: duration = datetime.datetime.now() - entry.created if olderthan > 0 and olderthan > duration.total_seconds(): keep = 1 - await self._js.purge_stream( - self._stream, subject=subject, keep=keep - ) + await self._js.purge_stream(self._stream, subject=subject, keep=keep) return True async def status(self) -> BucketStatus: @@ -298,10 +294,11 @@ async def status(self) -> BucketStatus: return KeyValue.BucketStatus(stream_info=info, bucket=self._name) class KeyWatcher: - def __init__(self, js): self._js = js - self._updates: asyncio.Queue[KeyValue.Entry | None] = asyncio.Queue(maxsize=256) + self._updates: asyncio.Queue[KeyValue.Entry | None] = asyncio.Queue( + maxsize=256 + ) self._sub = None self._pending: Optional[int] = None @@ -432,7 +429,7 @@ async def watch_updates(msg): # keys() uses this if ignore_deletes: - if (op == KV_PURGE or op == KV_DEL): + if op == KV_PURGE or op == KV_DEL: if meta.num_pending == 0 and not watcher._init_done: await watcher._updates.put(None) watcher._init_done = True @@ -440,7 +437,7 @@ async def watch_updates(msg): entry = KeyValue.Entry( bucket=self._name, - key=msg.subject[len(self._pre):], + key=msg.subject[len(self._pre) :], value=msg.data, revision=meta.sequence.stream, delta=meta.num_pending, diff --git a/nats/js/manager.py b/nats/js/manager.py index 98648655..9a2be323 100644 --- a/nats/js/manager.py +++ b/nats/js/manager.py @@ -26,9 +26,9 @@ if TYPE_CHECKING: from nats import NATS -NATS_HDR_LINE = bytearray(b'NATS/1.0') +NATS_HDR_LINE = bytearray(b"NATS/1.0") NATS_HDR_LINE_SIZE = len(NATS_HDR_LINE) -_CRLF_ = b'\r\n' +_CRLF_ = b"\r\n" _CRLF_LEN_ = len(_CRLF_) @@ -50,7 +50,7 @@ def __init__( async def account_info(self) -> api.AccountInfo: resp = await self._api_request( - f"{self._prefix}.INFO", b'', timeout=self._timeout + f"{self._prefix}.INFO", b"", timeout=self._timeout ) return api.AccountInfo.from_response(resp) @@ -64,26 +64,28 @@ async def find_stream_name_by_subject(self, subject: str) -> str: info = await self._api_request( req_sub, req_data.encode(), timeout=self._timeout ) - if not info['streams']: + if not info["streams"]: raise NotFoundError - return info['streams'][0] + return info["streams"][0] - async def stream_info(self, name: str, subjects_filter: Optional[str] = None) -> api.StreamInfo: + async def stream_info( + self, name: str, subjects_filter: Optional[str] = None + ) -> api.StreamInfo: """ Get the latest StreamInfo by stream name. """ - req_data = '' + req_data = "" if subjects_filter: req_data = json.dumps({"subjects_filter": subjects_filter}) resp = await self._api_request( - f"{self._prefix}.STREAM.INFO.{name}", req_data.encode(), timeout=self._timeout + f"{self._prefix}.STREAM.INFO.{name}", + req_data.encode(), + timeout=self._timeout, ) return api.StreamInfo.from_response(resp) async def add_stream( - self, - config: Optional[api.StreamConfig] = None, - **params + self, config: Optional[api.StreamConfig] = None, **params ) -> api.StreamInfo: """ add_stream creates a stream. @@ -117,9 +119,7 @@ async def add_stream( return api.StreamInfo.from_response(resp) async def update_stream( - self, - config: Optional[api.StreamConfig] = None, - **params + self, config: Optional[api.StreamConfig] = None, **params ) -> api.StreamInfo: """ update_stream updates a stream. @@ -145,33 +145,31 @@ async def delete_stream(self, name: str) -> bool: resp = await self._api_request( f"{self._prefix}.STREAM.DELETE.{name}", timeout=self._timeout ) - return resp['success'] + return resp["success"] async def purge_stream( self, name: str, seq: Optional[int] = None, subject: Optional[str] = None, - keep: Optional[int] = None + keep: Optional[int] = None, ) -> bool: """ Purge a stream by name. """ stream_req: Dict[str, Any] = {} if seq: - stream_req['seq'] = seq + stream_req["seq"] = seq if subject: - stream_req['filter'] = subject + stream_req["filter"] = subject if keep: - stream_req['keep'] = keep + stream_req["keep"] = keep req = json.dumps(stream_req) resp = await self._api_request( - f"{self._prefix}.STREAM.PURGE.{name}", - req.encode(), - timeout=self._timeout + f"{self._prefix}.STREAM.PURGE.{name}", req.encode(), timeout=self._timeout ) - return resp['success'] + return resp["success"] async def consumer_info( self, stream: str, consumer: str, timeout: Optional[float] = None @@ -180,9 +178,7 @@ async def consumer_info( if timeout is None: timeout = self._timeout resp = await self._api_request( - f"{self._prefix}.CONSUMER.INFO.{stream}.{consumer}", - b'', - timeout=timeout + f"{self._prefix}.CONSUMER.INFO.{stream}.{consumer}", b"", timeout=timeout ) return api.ConsumerInfo.from_response(resp) @@ -196,7 +192,7 @@ async def streams_info(self, offset=0) -> List[api.StreamInfo]: timeout=self._timeout, ) streams = [] - for stream in resp['streams']: + for stream in resp["streams"]: stream_info = api.StreamInfo.from_response(stream) streams.append(stream_info) return streams @@ -230,7 +226,7 @@ async def add_consumer( req_data = json.dumps(req).encode() resp = None - subject = '' + subject = "" version = self._nc.connected_server_version consumer_name_supported = version.major >= 2 and version.minor >= 9 if consumer_name_supported and config.name: @@ -252,15 +248,13 @@ async def add_consumer( async def delete_consumer(self, stream: str, consumer: str) -> bool: resp = await self._api_request( f"{self._prefix}.CONSUMER.DELETE.{stream}.{consumer}", - b'', - timeout=self._timeout + b"", + timeout=self._timeout, ) - return resp['success'] + return resp["success"] async def consumers_info( - self, - stream: str, - offset: Optional[int] = None + self, stream: str, offset: Optional[int] = None ) -> List[api.ConsumerInfo]: """ consumers_info retrieves a list of consumers. Consumers list limit is 256 for more @@ -270,13 +264,11 @@ async def consumers_info( """ resp = await self._api_request( f"{self._prefix}.CONSUMER.LIST.{stream}", - b'' if offset is None else json.dumps({ - "offset": offset - }).encode(), + b"" if offset is None else json.dumps({"offset": offset}).encode(), timeout=self._timeout, ) consumers = [] - for consumer in resp['consumers']: + for consumer in resp["consumers"]: consumer_info = api.ConsumerInfo.from_response(consumer) consumers.append(consumer_info) return consumers @@ -295,23 +287,23 @@ async def get_msg( req_subject = None req: Dict[str, Any] = {} if seq: - req['seq'] = seq + req["seq"] = seq if subject: - req['seq'] = None - req.pop('seq', None) - req['last_by_subj'] = subject + req["seq"] = None + req.pop("seq", None) + req["last_by_subj"] = subject if next: - req['seq'] = seq - req['last_by_subj'] = None - req.pop('last_by_subj', None) - req['next_by_subj'] = subject + req["seq"] = seq + req["last_by_subj"] = None + req.pop("last_by_subj", None) + req["next_by_subj"] = subject data = json.dumps(req) if direct: # $JS.API.DIRECT.GET.KV_{stream_name}.$KV.TEST.{key} if subject and (seq is None): # last_by_subject type request requires no payload. - data = '' + data = "" req_subject = f"{self._prefix}.DIRECT.GET.{stream_name}.{subject}" else: req_subject = f"{self._prefix}.DIRECT.GET.{stream_name}" @@ -328,10 +320,10 @@ async def get_msg( req_subject, data.encode(), timeout=self._timeout ) - raw_msg = api.RawStreamMsg.from_response(resp_data['message']) + raw_msg = api.RawStreamMsg.from_response(resp_data["message"]) if raw_msg.hdrs: hdrs = base64.b64decode(raw_msg.hdrs) - raw_headers = hdrs[NATS_HDR_LINE_SIZE + _CRLF_LEN_:] + raw_headers = hdrs[NATS_HDR_LINE_SIZE + _CRLF_LEN_ :] parsed_headers = self._hdr_parser.parsebytes(raw_headers) headers = None if len(parsed_headers.items()) > 0: @@ -351,18 +343,18 @@ async def get_msg( def _lift_msg_to_raw_msg(self, msg) -> api.RawStreamMsg: if not msg.data: msg.data = None - status = msg.headers.get('Status') + status = msg.headers.get("Status") if status: - if status == '404': + if status == "404": raise NotFoundError else: raise APIError.from_msg(msg) raw_msg = api.RawStreamMsg() - subject = msg.headers['Nats-Subject'] + subject = msg.headers["Nats-Subject"] raw_msg.subject = subject - seq = msg.headers.get('Nats-Sequence') + seq = msg.headers.get("Nats-Sequence") if seq: raw_msg.seq = int(seq) raw_msg.data = msg.data @@ -375,10 +367,10 @@ async def delete_msg(self, stream_name: str, seq: int) -> bool: delete_msg retrieves a message from a stream based on the sequence ID. """ req_subject = f"{self._prefix}.STREAM.MSG.DELETE.{stream_name}" - req = {'seq': seq} + req = {"seq": seq} data = json.dumps(req) resp = await self._api_request(req_subject, data.encode()) - return resp['success'] + return resp["success"] async def get_last_msg( self, @@ -394,7 +386,7 @@ async def get_last_msg( async def _api_request( self, req_subject: str, - req: bytes = b'', + req: bytes = b"", timeout: float = 5, ) -> Dict[str, Any]: try: @@ -404,7 +396,7 @@ async def _api_request( raise ServiceUnavailableError # Check for API errors. - if 'error' in resp: - raise APIError.from_error(resp['error']) + if "error" in resp: + raise APIError.from_error(resp["error"]) return resp diff --git a/nats/js/object_store.py b/nats/js/object_store.py index 5eb10612..db38de7b 100644 --- a/nats/js/object_store.py +++ b/nats/js/object_store.py @@ -25,8 +25,13 @@ import nats.errors from nats.js import api from nats.js.errors import ( - BadObjectMetaError, DigestMismatchError, ObjectAlreadyExists, - ObjectDeletedError, ObjectNotFoundError, NotFoundError, LinkIsABucketError + BadObjectMetaError, + DigestMismatchError, + ObjectAlreadyExists, + ObjectDeletedError, + ObjectNotFoundError, + NotFoundError, + LinkIsABucketError, ) from nats.js.kv import MSG_ROLLUP_SUBJECT @@ -71,6 +76,7 @@ class ObjectStoreStatus: """ ObjectStoreStatus is the status of a ObjectStore bucket. """ + stream_info: api.StreamInfo bucket: str @@ -140,7 +146,7 @@ async def get_info( meta = OBJ_META_PRE_TEMPLATE.format( bucket=self._name, - obj=base64.urlsafe_b64encode(bytes(obj, "utf-8")).decode() + obj=base64.urlsafe_b64encode(bytes(obj, "utf-8")).decode(), ) stream = OBJ_STREAM_TEMPLATE.format(bucket=self._name) @@ -196,12 +202,8 @@ async def get( if info.size == 0: return result - chunk_subj = OBJ_CHUNKS_PRE_TEMPLATE.format( - bucket=self._name, obj=info.nuid - ) - sub = await self._js.subscribe( - subject=chunk_subj, ordered_consumer=True - ) + chunk_subj = OBJ_CHUNKS_PRE_TEMPLATE.format(bucket=self._name, obj=info.nuid) + sub = await self._js.subscribe(subject=chunk_subj, ordered_consumer=True) h = sha256() @@ -209,7 +211,7 @@ async def get( executor_fn = None if writeinto: executor = asyncio.get_running_loop().run_in_executor - if hasattr(writeinto, 'buffer'): + if hasattr(writeinto, "buffer"): executor_fn = writeinto.buffer.write else: executor_fn = writeinto.write @@ -281,10 +283,10 @@ async def put( data = io.BytesIO(data.encode()) elif isinstance(data, bytes): data = io.BytesIO(data) - elif hasattr(data, 'readinto') or isinstance(data, io.BufferedIOBase): + elif hasattr(data, "readinto") or isinstance(data, io.BufferedIOBase): # Need to delegate to a threaded executor to avoid blocking. executor = asyncio.get_running_loop().run_in_executor - elif hasattr(data, 'buffer') or isinstance(data, io.TextIOWrapper): + elif hasattr(data, "buffer") or isinstance(data, io.TextIOWrapper): data = data.buffer else: raise TypeError("nats: invalid type for object store") @@ -298,7 +300,7 @@ async def put( nuid=newnuid.decode(), size=0, chunks=0, - mtime=datetime.now(timezone.utc).isoformat() + mtime=datetime.now(timezone.utc).isoformat(), ) h = sha256() chunk = bytearray(meta.options.max_chunk_size) @@ -334,14 +336,14 @@ async def put( # Prepare the meta message. meta_subj = OBJ_META_PRE_TEMPLATE.format( bucket=self._name, - obj=base64.urlsafe_b64encode(bytes(obj, "utf-8")).decode() + obj=base64.urlsafe_b64encode(bytes(obj, "utf-8")).decode(), ) # Publish the meta message. try: await self._js.publish( meta_subj, json.dumps(info.as_dict()).encode(), - headers={api.Header.ROLLUP: MSG_ROLLUP_SUBJECT} + headers={api.Header.ROLLUP: MSG_ROLLUP_SUBJECT}, ) except Exception as err: await self._js.purge_stream(self._stream, subject=chunk_subj) @@ -403,14 +405,14 @@ async def update_meta( # Prepare the meta message. meta_subj = OBJ_META_PRE_TEMPLATE.format( bucket=self._name, - obj=base64.urlsafe_b64encode(bytes(name, "utf-8")).decode() + obj=base64.urlsafe_b64encode(bytes(name, "utf-8")).decode(), ) # Publish the meta message. try: await self._js.publish( meta_subj, json.dumps(info.as_dict()).encode(), - headers={api.Header.ROLLUP: MSG_ROLLUP_SUBJECT} + headers={api.Header.ROLLUP: MSG_ROLLUP_SUBJECT}, ) except Exception as err: raise err @@ -421,7 +423,6 @@ async def update_meta( await self._js.purge_stream(self._stream, subject=meta_subj) class ObjectWatcher: - def __init__(self, js): self._js = js self._updates = asyncio.Queue(maxsize=256) @@ -466,7 +467,9 @@ async def watch( """ watch for changes in the underlying store and receive meta information updates. """ - all_meta = OBJ_ALL_META_PRE_TEMPLATE.format(bucket=self._name, ) + all_meta = OBJ_ALL_META_PRE_TEMPLATE.format( + bucket=self._name, + ) watcher = ObjectStore.ObjectWatcher(self) async def watch_updates(msg): @@ -515,28 +518,26 @@ async def delete(self, name: str) -> ObjectResult: raise BadObjectMetaError # Purge chunks for the object. - chunk_subj = OBJ_CHUNKS_PRE_TEMPLATE.format( - bucket=self._name, obj=info.nuid - ) + chunk_subj = OBJ_CHUNKS_PRE_TEMPLATE.format(bucket=self._name, obj=info.nuid) # Reset meta values. info.deleted = True info.size = 0 info.chunks = 0 - info.digest = '' - info.mtime = '' + info.digest = "" + info.mtime = "" # Prepare the meta message. meta_subj = OBJ_META_PRE_TEMPLATE.format( bucket=self._name, - obj=base64.urlsafe_b64encode(bytes(obj, "utf-8")).decode() + obj=base64.urlsafe_b64encode(bytes(obj, "utf-8")).decode(), ) # Publish the meta message. try: await self._js.publish( meta_subj, json.dumps(info.as_dict()).encode(), - headers={api.Header.ROLLUP: MSG_ROLLUP_SUBJECT} + headers={api.Header.ROLLUP: MSG_ROLLUP_SUBJECT}, ) finally: await self._js.purge_stream(self._stream, subject=chunk_subj) diff --git a/nats/nuid.py b/nats/nuid.py index e8f4470f..7896d642 100644 --- a/nats/nuid.py +++ b/nats/nuid.py @@ -18,7 +18,7 @@ from secrets import randbelow, token_bytes from sys import maxsize as MaxInt -DIGITS = b'0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' +DIGITS = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" BASE = 62 PREFIX_LENGTH = 12 SEQ_LENGTH = 10 diff --git a/nats/protocol/command.py b/nats/protocol/command.py index 80573685..95d8a9a9 100644 --- a/nats/protocol/command.py +++ b/nats/protocol/command.py @@ -2,31 +2,38 @@ from typing import Callable -PUB_OP = 'PUB' -HPUB_OP = 'HPUB' -SUB_OP = 'SUB' -UNSUB_OP = 'UNSUB' -_CRLF_ = '\r\n' +PUB_OP = "PUB" +HPUB_OP = "HPUB" +SUB_OP = "SUB" +UNSUB_OP = "UNSUB" +_CRLF_ = "\r\n" Command = Callable[..., bytes] def pub_cmd(subject, reply, payload) -> bytes: - return f'{PUB_OP} {subject} {reply} {len(payload)}{_CRLF_}'.encode( - ) + payload + _CRLF_.encode() + return ( + f"{PUB_OP} {subject} {reply} {len(payload)}{_CRLF_}".encode() + + payload + + _CRLF_.encode() + ) def hpub_cmd(subject, reply, hdr, payload) -> bytes: hdr_len = len(hdr) total_size = len(payload) + hdr_len - return f'{HPUB_OP} {subject} {reply} {hdr_len} {total_size}{_CRLF_}'.encode( - ) + hdr + payload + _CRLF_.encode() + return ( + f"{HPUB_OP} {subject} {reply} {hdr_len} {total_size}{_CRLF_}".encode() + + hdr + + payload + + _CRLF_.encode() + ) def sub_cmd(subject, queue, sid) -> bytes: - return f'{SUB_OP} {subject} {queue} {sid}{_CRLF_}'.encode() + return f"{SUB_OP} {subject} {queue} {sid}{_CRLF_}".encode() def unsub_cmd(sid, limit) -> bytes: - limit_s = '' if limit == 0 else f'{limit}' - return f'{UNSUB_OP} {sid} {limit_s}{_CRLF_}'.encode() + limit_s = "" if limit == 0 else f"{limit}" + return f"{UNSUB_OP} {sid} {limit_s}{_CRLF_}".encode() diff --git a/nats/protocol/parser.py b/nats/protocol/parser.py index 20bf3948..ac5a70cc 100644 --- a/nats/protocol/parser.py +++ b/nats/protocol/parser.py @@ -24,31 +24,31 @@ from nats.errors import ProtocolError MSG_RE = re.compile( - b'\\AMSG\\s+([^\\s]+)\\s+([^\\s]+)\\s+(([^\\s]+)[^\\S\r\n]+)?(\\d+)\r\n' + b"\\AMSG\\s+([^\\s]+)\\s+([^\\s]+)\\s+(([^\\s]+)[^\\S\r\n]+)?(\\d+)\r\n" ) HMSG_RE = re.compile( - b'\\AHMSG\\s+([^\\s]+)\\s+([^\\s]+)\\s+(([^\\s]+)[^\\S\r\n]+)?([\\d]+)\\s+(\\d+)\r\n' + b"\\AHMSG\\s+([^\\s]+)\\s+([^\\s]+)\\s+(([^\\s]+)[^\\S\r\n]+)?([\\d]+)\\s+(\\d+)\r\n" ) -OK_RE = re.compile(b'\\A\\+OK\\s*\r\n') -ERR_RE = re.compile(b'\\A-ERR\\s+(\'.+\')?\r\n') -PING_RE = re.compile(b'\\APING\\s*\r\n') -PONG_RE = re.compile(b'\\APONG\\s*\r\n') -INFO_RE = re.compile(b'\\AINFO\\s+([^\r\n]+)\r\n') - -INFO_OP = b'INFO' -CONNECT_OP = b'CONNECT' -PUB_OP = b'PUB' -MSG_OP = b'MSG' -HMSG_OP = b'HMSG' -SUB_OP = b'SUB' -UNSUB_OP = b'UNSUB' -PING_OP = b'PING' -PONG_OP = b'PONG' -OK_OP = b'+OK' -ERR_OP = b'-ERR' -MSG_END = b'\n' -_CRLF_ = b'\r\n' -_SPC_ = b' ' +OK_RE = re.compile(b"\\A\\+OK\\s*\r\n") +ERR_RE = re.compile(b"\\A-ERR\\s+('.+')?\r\n") +PING_RE = re.compile(b"\\APING\\s*\r\n") +PONG_RE = re.compile(b"\\APONG\\s*\r\n") +INFO_RE = re.compile(b"\\AINFO\\s+([^\r\n]+)\r\n") + +INFO_OP = b"INFO" +CONNECT_OP = b"CONNECT" +PUB_OP = b"PUB" +MSG_OP = b"MSG" +HMSG_OP = b"HMSG" +SUB_OP = b"SUB" +UNSUB_OP = b"UNSUB" +PING_OP = b"PING" +PONG_OP = b"PONG" +OK_OP = b"+OK" +ERR_OP = b"-ERR" +MSG_END = b"\n" +_CRLF_ = b"\r\n" +_SPC_ = b" " OK = OK_OP + _CRLF_ PING = PING_OP + _CRLF_ @@ -72,7 +72,6 @@ class Parser: - def __init__(self, nc=None) -> None: self.nc = nc self.reset() @@ -87,7 +86,7 @@ def reset(self) -> None: self.header_needed = 0 self.msg_arg: Dict[str, Any] = {} - async def parse(self, data: bytes = b''): + async def parse(self, data: bytes = b""): """ Parses the wire protocol from NATS for the client and dispatches the subscription callbacks. @@ -104,9 +103,9 @@ async def parse(self, data: bytes = b''): if reply: self.msg_arg["reply"] = reply else: - self.msg_arg["reply"] = b'' + self.msg_arg["reply"] = b"" self.needed = int(needed_bytes) - del self.buf[:msg.end()] + del self.buf[: msg.end()] self.state = AWAITING_MSG_PAYLOAD continue except Exception: @@ -115,17 +114,16 @@ async def parse(self, data: bytes = b''): msg = HMSG_RE.match(self.buf) if msg: try: - subject, sid, _, reply, header_size, needed_bytes = msg.groups( - ) + subject, sid, _, reply, header_size, needed_bytes = msg.groups() self.msg_arg["subject"] = subject self.msg_arg["sid"] = int(sid) if reply: self.msg_arg["reply"] = reply else: - self.msg_arg["reply"] = b'' + self.msg_arg["reply"] = b"" self.needed = int(needed_bytes) self.header_needed = int(header_size) - del self.buf[:msg.end()] + del self.buf[: msg.end()] self.state = AWAITING_MSG_PAYLOAD continue except Exception: @@ -134,7 +132,7 @@ async def parse(self, data: bytes = b''): ok = OK_RE.match(self.buf) if ok: # Do nothing and just skip. - del self.buf[:ok.end()] + del self.buf[: ok.end()] continue err = ERR_RE.match(self.buf) @@ -142,18 +140,18 @@ async def parse(self, data: bytes = b''): err_msg = err.groups() emsg = err_msg[0].decode().lower() await self.nc._process_err(emsg) - del self.buf[:err.end()] + del self.buf[: err.end()] continue ping = PING_RE.match(self.buf) if ping: - del self.buf[:ping.end()] + del self.buf[: ping.end()] await self.nc._process_ping() continue pong = PONG_RE.match(self.buf) if pong: - del self.buf[:pong.end()] + del self.buf[: pong.end()] await self.nc._process_pong() continue @@ -162,7 +160,7 @@ async def parse(self, data: bytes = b''): info_line = info.groups()[0] srv_info = json.loads(info_line.decode()) self.nc._process_info(srv_info) - del self.buf[:info.end()] + del self.buf[: info.end()] continue if len(self.buf) < MAX_CONTROL_LINE_SIZE and _CRLF_ in self.buf: @@ -186,21 +184,17 @@ async def parse(self, data: bytes = b''): # Consume msg payload from buffer and set next parser state. if self.header_needed > 0: - hbuf = bytes(self.buf[:self.header_needed]) - payload = bytes( - self.buf[self.header_needed:self.needed] - ) + hbuf = bytes(self.buf[: self.header_needed]) + payload = bytes(self.buf[self.header_needed : self.needed]) hdr = hbuf - del self.buf[:self.needed + CRLF_SIZE] + del self.buf[: self.needed + CRLF_SIZE] self.header_needed = 0 else: - payload = bytes(self.buf[:self.needed]) - del self.buf[:self.needed + CRLF_SIZE] + payload = bytes(self.buf[: self.needed]) + del self.buf[: self.needed + CRLF_SIZE] self.state = AWAITING_CONTROL_LINE - await self.nc._process_msg( - sid, subject, reply, payload, hdr - ) + await self.nc._process_msg(sid, subject, reply, payload, hdr) else: # Wait until we have enough bytes in buffer. break diff --git a/tests/test_client.py b/tests/test_client.py index 5548aa07..c3d58271 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -29,7 +29,6 @@ class ClientUtilsTest(unittest.TestCase): - def test_default_connect_command(self): nc = NATS() nc.options["verbose"] = False @@ -54,14 +53,13 @@ def test_default_connect_command_with_name(self): class ClientTest(SingleServerTestCase): - @async_test async def test_default_connect(self): nc = await nats.connect() - self.assertIn('server_id', nc._server_info) - self.assertIn('client_id', nc._server_info) - self.assertIn('max_payload', nc._server_info) - self.assertEqual(nc._server_info['max_payload'], nc.max_payload) + self.assertIn("server_id", nc._server_info) + self.assertIn("client_id", nc._server_info) + self.assertIn("max_payload", nc._server_info) + self.assertEqual(nc._server_info["max_payload"], nc.max_payload) self.assertTrue(nc.max_payload > 0) self.assertTrue(nc.is_connected) self.assertTrue(nc.client_id > 0) @@ -82,10 +80,9 @@ async def test_default_module_connect(self): def test_connect_syntax_sugar(self): nc = NATS() - nc._setup_server_pool([ - "nats://127.0.0.1:4222", "nats://127.0.0.1:4223", - "nats://127.0.0.1:4224" - ]) + nc._setup_server_pool( + ["nats://127.0.0.1:4222", "nats://127.0.0.1:4223", "nats://127.0.0.1:4224"] + ) self.assertEqual(3, len(nc._server_pool)) nc = NATS() @@ -214,34 +211,34 @@ async def test_publish(self): nc = NATS() await nc.connect() for i in range(0, 100): - await nc.publish(f"hello.{i}", b'A') + await nc.publish(f"hello.{i}", b"A") with self.assertRaises(nats.errors.BadSubjectError): - await nc.publish("", b'') + await nc.publish("", b"") await nc.flush() await nc.close() await asyncio.sleep(1) - self.assertEqual(100, nc.stats['out_msgs']) - self.assertEqual(100, nc.stats['out_bytes']) + self.assertEqual(100, nc.stats["out_msgs"]) + self.assertEqual(100, nc.stats["out_bytes"]) - endpoint = f'127.0.0.1:{self.server_pool[0].http_port}' + endpoint = f"127.0.0.1:{self.server_pool[0].http_port}" httpclient = http.client.HTTPConnection(endpoint, timeout=5) - httpclient.request('GET', '/varz') + httpclient.request("GET", "/varz") response = httpclient.getresponse() varz = json.loads((response.read()).decode()) - self.assertEqual(100, varz['in_msgs']) - self.assertEqual(100, varz['in_bytes']) + self.assertEqual(100, varz["in_msgs"]) + self.assertEqual(100, varz["in_bytes"]) @async_test async def test_flush(self): nc = NATS() await nc.connect() for i in range(0, 10): - await nc.publish(f"flush.{i}", b'AA') + await nc.publish(f"flush.{i}", b"AA") await nc.flush() - self.assertEqual(10, nc.stats['out_msgs']) - self.assertEqual(20, nc.stats['out_bytes']) + self.assertEqual(10, nc.stats["out_msgs"]) + self.assertEqual(20, nc.stats["out_bytes"]) await nc.close() @async_test @@ -252,14 +249,14 @@ async def test_subscribe(self): async def subscription_handler(msg): msgs.append(msg) - payload = b'hello world' + payload = b"hello world" await nc.connect() sub = await nc.subscribe("foo", cb=subscription_handler) await nc.publish("foo", payload) await nc.publish("bar", payload) with self.assertRaises(nats.errors.BadSubjectError): - await nc.publish("", b'') + await nc.publish("", b"") # Validate some of the subjects with self.assertRaises(nats.errors.BadSubjectError): @@ -279,8 +276,8 @@ async def subscription_handler(msg): self.assertEqual(1, len(msgs)) msg = msgs[0] - self.assertEqual('foo', msg.subject) - self.assertEqual('', msg.reply) + self.assertEqual("foo", msg.subject) + self.assertEqual("", msg.reply) self.assertEqual(payload, msg.data) self.assertEqual(sub._id, msg.sid) self.assertEqual(1, sub._received) @@ -291,17 +288,17 @@ async def subscription_handler(msg): with self.assertRaises(KeyError): nc._subs[sub._id] - self.assertEqual(1, nc.stats['in_msgs']) - self.assertEqual(11, nc.stats['in_bytes']) - self.assertEqual(2, nc.stats['out_msgs']) - self.assertEqual(22, nc.stats['out_bytes']) + self.assertEqual(1, nc.stats["in_msgs"]) + self.assertEqual(11, nc.stats["in_bytes"]) + self.assertEqual(2, nc.stats["out_msgs"]) + self.assertEqual(22, nc.stats["out_bytes"]) - endpoint = f'127.0.0.1:{self.server_pool[0].http_port}' + endpoint = f"127.0.0.1:{self.server_pool[0].http_port}" httpclient = http.client.HTTPConnection(endpoint, timeout=5) - httpclient.request('GET', '/connz') + httpclient.request("GET", "/connz") response = httpclient.getresponse() connz = json.loads((response.read()).decode()) - self.assertEqual(0, len(connz['connections'])) + self.assertEqual(0, len(connz["connections"])) @async_test async def test_subscribe_functools_partial(self): @@ -316,11 +313,9 @@ async def subscription_handler(arg1, msg): partial_arg = arg1 msgs.append(msg) - partial_sub_handler = functools.partial( - subscription_handler, "example" - ) + partial_sub_handler = functools.partial(subscription_handler, "example") - payload = b'hello world' + payload = b"hello world" await nc.connect() sid = await nc.subscribe("foo", cb=partial_sub_handler) @@ -331,8 +326,8 @@ async def subscription_handler(arg1, msg): self.assertEqual(1, len(msgs)) self.assertEqual(partial_arg, "example") msg = msgs[0] - self.assertEqual('foo', msg.subject) - self.assertEqual('', msg.reply) + self.assertEqual("foo", msg.subject) + self.assertEqual("", msg.reply) self.assertEqual(payload, msg.data) self.assertEqual(sid._id, msg.sid) @@ -361,7 +356,7 @@ async def subscription_handler2(msg): await nc.flush() await nc2.flush() - payload = b'hello world' + payload = b"hello world" for i in range(0, 10): await nc.publish("foo", payload) await asyncio.sleep(0.1) @@ -377,15 +372,15 @@ async def subscription_handler2(msg): await nc.close() await nc2.close() - self.assertEqual(0, nc.stats['in_msgs']) - self.assertEqual(0, nc.stats['in_bytes']) - self.assertEqual(10, nc.stats['out_msgs']) - self.assertEqual(110, nc.stats['out_bytes']) + self.assertEqual(0, nc.stats["in_msgs"]) + self.assertEqual(0, nc.stats["in_bytes"]) + self.assertEqual(10, nc.stats["out_msgs"]) + self.assertEqual(110, nc.stats["out_bytes"]) - self.assertEqual(10, nc2.stats['in_msgs']) - self.assertEqual(110, nc2.stats['in_bytes']) - self.assertEqual(0, nc2.stats['out_msgs']) - self.assertEqual(0, nc2.stats['out_bytes']) + self.assertEqual(10, nc2.stats["in_msgs"]) + self.assertEqual(110, nc2.stats["in_bytes"]) + self.assertEqual(0, nc2.stats["out_msgs"]) + self.assertEqual(0, nc2.stats["out_bytes"]) @async_test async def test_invalid_subscribe_error(self): @@ -426,7 +421,7 @@ async def subscription_handler(msg): await nc.subscribe("tests.>", cb=subscription_handler) for i in range(0, 5): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") # Wait a bit for messages to be received. await asyncio.sleep(1) @@ -447,14 +442,14 @@ async def iterator_func(sub): fut.set_result(None) await nc.connect() - sub = await nc.subscribe('tests.>') + sub = await nc.subscribe("tests.>") self.assertFalse(sub._message_iterator._unsubscribed_future.done()) asyncio.ensure_future(iterator_func(sub)) self.assertFalse(sub._message_iterator._unsubscribed_future.done()) for i in range(0, 5): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") await asyncio.sleep(0) await asyncio.wait_for(sub.drain(), 1) @@ -477,12 +472,12 @@ async def test_subscribe_iterate_unsub_comprehension(self): await nc.connect() # Make subscription that only expects a couple of messages. - sub = await nc.subscribe('tests.>') + sub = await nc.subscribe("tests.>") await sub.unsubscribe(limit=2) await nc.flush() for i in range(0, 5): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") # It should unblock once the auto unsubscribe limit is hit. msgs = [msg async for msg in sub.messages] @@ -500,12 +495,12 @@ async def test_subscribe_iterate_unsub(self): await nc.connect() # Make subscription that only expects a couple of messages. - sub = await nc.subscribe('tests.>') + sub = await nc.subscribe("tests.>") await sub.unsubscribe(limit=2) await nc.flush() for i in range(0, 5): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") # A couple of messages would be received then this will unblock. msgs = [] @@ -528,12 +523,12 @@ async def test_subscribe_auto_unsub(self): async def handler(msg): msgs.append(msg) - sub = await nc.subscribe('tests.>', cb=handler) + sub = await nc.subscribe("tests.>", cb=handler) await sub.unsubscribe(limit=2) await nc.flush() for i in range(0, 5): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") await asyncio.sleep(1) self.assertEqual(2, len(msgs)) @@ -547,7 +542,7 @@ async def test_subscribe_iterate_next_msg(self): await nc.connect() # Make subscription that only expects a couple of messages. - sub = await nc.subscribe('tests.>') + sub = await nc.subscribe("tests.>") await nc.flush() # Async generator to consume messages. @@ -561,7 +556,7 @@ async def next_msg(): return msg for i in range(0, 2): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") # A couple of messages would be received then this will unblock. msg = await next_msg() @@ -576,9 +571,9 @@ async def next_msg(): # FIXME: This message would be lost because cannot # reuse the future from the iterator that timed out. - await nc.publish(f"tests.2", b'bar') + await nc.publish(f"tests.2", b"bar") - await nc.publish(f"tests.3", b'bar') + await nc.publish(f"tests.3", b"bar") await nc.flush() # FIXME: this test is flaky @@ -596,11 +591,11 @@ async def test_subscribe_next_msg(self): nc = await nats.connect() # Make subscription that only expects a couple of messages. - sub = await nc.subscribe('tests.>') + sub = await nc.subscribe("tests.>") await nc.flush() for i in range(0, 2): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") await nc.flush() # A couple of messages would be received then this will unblock. @@ -624,8 +619,8 @@ async def test_subscribe_next_msg(self): await sub.next_msg(timeout=0.5) # Send again a couple of messages. - await nc.publish(f"tests.2", b'bar') - await nc.publish(f"tests.3", b'bar') + await nc.publish(f"tests.2", b"bar") + await nc.publish(f"tests.3", b"bar") await nc.flush() msg = await sub.next_msg() self.assertEqual("tests.2", msg.subject) @@ -655,14 +650,14 @@ async def error_cb(err): nc = await nats.connect(error_cb=error_cb) sub = await nc.subscribe( - 'tests.>', + "tests.>", pending_msgs_limit=5, pending_bytes_limit=-1, ) await nc.flush() for i in range(0, 6): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") await nc.flush() # A couple of messages would be received then this will unblock. @@ -690,13 +685,13 @@ async def test_subscribe_next_msg_with_cb_not_supported(self): nc = await nats.connect() async def handler(msg): - await msg.respond(b'OK') + await msg.respond(b"OK") - sub = await nc.subscribe('foo', cb=handler) + sub = await nc.subscribe("foo", cb=handler) await nc.flush() for i in range(0, 2): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") await nc.flush() with self.assertRaises(nats.errors.Error): @@ -732,42 +727,42 @@ async def subscription_handler(msg): await nc.connect() sub = await nc.subscribe("foo", cb=subscription_handler) - await nc.publish("foo", b'A') - await nc.publish("foo", b'B') + await nc.publish("foo", b"A") + await nc.publish("foo", b"B") # Wait a bit to receive the messages await asyncio.sleep(0.5) self.assertEqual(2, len(msgs)) await sub.unsubscribe() - await nc.publish("foo", b'C') - await nc.publish("foo", b'D') + await nc.publish("foo", b"C") + await nc.publish("foo", b"D") # Ordering should be preserved in these at least - self.assertEqual(b'A', msgs[0].data) - self.assertEqual(b'B', msgs[1].data) + self.assertEqual(b"A", msgs[0].data) + self.assertEqual(b"B", msgs[1].data) # Should not exist by now with self.assertRaises(KeyError): nc._subs[sub._id].received await asyncio.sleep(1) - endpoint = f'127.0.0.1:{self.server_pool[0].http_port}' + endpoint = f"127.0.0.1:{self.server_pool[0].http_port}" httpclient = http.client.HTTPConnection(endpoint, timeout=5) - httpclient.request('GET', '/connz') + httpclient.request("GET", "/connz") response = httpclient.getresponse() connz = json.loads((response.read()).decode()) - self.assertEqual(1, len(connz['connections'])) - self.assertEqual(0, connz['connections'][0]['subscriptions']) - self.assertEqual(4, connz['connections'][0]['in_msgs']) - self.assertEqual(4, connz['connections'][0]['in_bytes']) - self.assertEqual(2, connz['connections'][0]['out_msgs']) - self.assertEqual(2, connz['connections'][0]['out_bytes']) + self.assertEqual(1, len(connz["connections"])) + self.assertEqual(0, connz["connections"][0]["subscriptions"]) + self.assertEqual(4, connz["connections"][0]["in_msgs"]) + self.assertEqual(4, connz["connections"][0]["in_bytes"]) + self.assertEqual(2, connz["connections"][0]["out_msgs"]) + self.assertEqual(2, connz["connections"][0]["out_bytes"]) await nc.close() - self.assertEqual(2, nc.stats['in_msgs']) - self.assertEqual(2, nc.stats['in_bytes']) - self.assertEqual(4, nc.stats['out_msgs']) - self.assertEqual(4, nc.stats['out_bytes']) + self.assertEqual(2, nc.stats["in_msgs"]) + self.assertEqual(2, nc.stats["in_bytes"]) + self.assertEqual(4, nc.stats["out_msgs"]) + self.assertEqual(4, nc.stats["out_bytes"]) @async_test async def test_old_style_request(self): @@ -779,32 +774,26 @@ async def worker_handler(msg): nonlocal counter counter += 1 msgs.append(msg) - await nc.publish(msg.reply, f'Reply:{counter}'.encode()) + await nc.publish(msg.reply, f"Reply:{counter}".encode()) async def slow_worker_handler(msg): await asyncio.sleep(0.5) - await nc.publish(msg.reply, b'timeout by now...') + await nc.publish(msg.reply, b"timeout by now...") await nc.connect() await nc.subscribe("help", cb=worker_handler) await nc.subscribe("slow.help", cb=slow_worker_handler) - response = await nc.request( - "help", b'please', timeout=1, old_style=True - ) - self.assertEqual(b'Reply:1', response.data) - response = await nc.request( - "help", b'please', timeout=1, old_style=True - ) - self.assertEqual(b'Reply:2', response.data) + response = await nc.request("help", b"please", timeout=1, old_style=True) + self.assertEqual(b"Reply:1", response.data) + response = await nc.request("help", b"please", timeout=1, old_style=True) + self.assertEqual(b"Reply:2", response.data) with self.assertRaises(nats.errors.TimeoutError): - msg = await nc.request( - "slow.help", b'please', timeout=0.1, old_style=True - ) + msg = await nc.request("slow.help", b"please", timeout=0.1, old_style=True) with self.assertRaises(nats.errors.NoRespondersError): - await nc.request("nowhere", b'please', timeout=0.1, old_style=True) + await nc.request("nowhere", b"please", timeout=0.1, old_style=True) await asyncio.sleep(1) await nc.close() @@ -819,11 +808,11 @@ async def worker_handler(msg): nonlocal counter counter += 1 msgs.append(msg) - await nc.publish(msg.reply, f'Reply:{counter}'.encode()) + await nc.publish(msg.reply, f"Reply:{counter}".encode()) async def slow_worker_handler(msg): await asyncio.sleep(0.5) - await nc.publish(msg.reply, b'timeout by now...') + await nc.publish(msg.reply, b"timeout by now...") errs = [] @@ -834,13 +823,13 @@ async def err_cb(err): await nc.subscribe("help", cb=worker_handler) await nc.subscribe("slow.help", cb=slow_worker_handler) - response = await nc.request("help", b'please', timeout=1) - self.assertEqual(b'Reply:1', response.data) - response = await nc.request("help", b'please', timeout=1) - self.assertEqual(b'Reply:2', response.data) + response = await nc.request("help", b"please", timeout=1) + self.assertEqual(b"Reply:1", response.data) + response = await nc.request("help", b"please", timeout=1) + self.assertEqual(b"Reply:2", response.data) with self.assertRaises(nats.errors.TimeoutError): - await nc.request("slow.help", b'please', timeout=0.1) + await nc.request("slow.help", b"please", timeout=0.1) await asyncio.sleep(1) assert len(errs) == 0 await nc.close() @@ -850,18 +839,18 @@ async def test_requests_gather(self): nc = await nats.connect() async def worker_handler(msg): - await msg.respond(b'OK') + await msg.respond(b"OK") await nc.subscribe("foo.*", cb=worker_handler) - await nc.request("foo.A", b'') + await nc.request("foo.A", b"") msgs = await asyncio.gather( - nc.request("foo.B", b''), - nc.request("foo.C", b''), - nc.request("foo.D", b''), - nc.request("foo.E", b''), - nc.request("foo.F", b''), + nc.request("foo.B", b""), + nc.request("foo.C", b""), + nc.request("foo.D", b""), + nc.request("foo.E", b""), + nc.request("foo.F", b""), ) self.assertEqual(len(msgs), 5) @@ -872,12 +861,12 @@ async def test_custom_inbox_prefix(self): nc = NATS() async def worker_handler(msg): - self.assertTrue(msg.reply.startswith('bar.')) + self.assertTrue(msg.reply.startswith("bar.")) await msg.respond(b"OK") await nc.connect(inbox_prefix="bar") await nc.subscribe("foo", cb=worker_handler) - await nc.request("foo", b'') + await nc.request("foo", b"") await nc.close() @async_test @@ -886,15 +875,15 @@ async def test_msg_respond(self): msgs = [] async def cb1(msg): - await msg.respond(b'bar') + await msg.respond(b"bar") async def cb2(msg): msgs.append(msg) await nc.connect() - await nc.subscribe('subj1', cb=cb1) - await nc.subscribe('subj2', cb=cb2) - await nc.publish('subj1', b'foo', reply='subj2') + await nc.subscribe("subj1", cb=cb1) + await nc.subscribe("subj2", cb=cb2) + await nc.publish("subj1", b"foo", reply="subj2") await asyncio.sleep(1) self.assertEqual(1, len(msgs)) @@ -906,7 +895,7 @@ async def test_pending_data_size_tracking(self): await nc.connect() largest_pending_data_size = 0 for i in range(0, 100): - await nc.publish("example", b'A' * 100000) + await nc.publish("example", b"A" * 100000) if nc.pending_data_size > 0: largest_pending_data_size = nc.pending_data_size self.assertTrue(largest_pending_data_size > 0) @@ -938,23 +927,23 @@ async def err_cb(e): err_count += 1 options = { - 'disconnected_cb': disconnected_cb, - 'closed_cb': closed_cb, - 'reconnected_cb': reconnected_cb, - 'error_cb': err_cb, + "disconnected_cb": disconnected_cb, + "closed_cb": closed_cb, + "reconnected_cb": reconnected_cb, + "error_cb": err_cb, } await nc.connect(**options) await nc.close() with self.assertRaises(nats.errors.ConnectionClosedError): - await nc.publish("foo", b'A') + await nc.publish("foo", b"A") with self.assertRaises(nats.errors.ConnectionClosedError): await nc.subscribe("bar", "workers") with self.assertRaises(nats.errors.ConnectionClosedError): - await nc.publish("bar", b'B', reply="inbox") + await nc.publish("bar", b"B", reply="inbox") with self.assertRaises(nats.errors.ConnectionClosedError): await nc.flush() @@ -998,11 +987,11 @@ async def closed_cb(): closed_count += 1 options = { - 'dont_randomize': True, - 'disconnected_cb': disconnected_cb, - 'closed_cb': closed_cb, - 'reconnected_cb': reconnected_cb, - 'reconnect_time_wait': 0.01 + "dont_randomize": True, + "disconnected_cb": disconnected_cb, + "closed_cb": closed_cb, + "reconnected_cb": reconnected_cb, + "reconnect_time_wait": 0.01, } await nc.connect(**options) @@ -1022,7 +1011,7 @@ async def receiver_cb(msg): await nc2.flush() for i in range(0, 200): - await nc.publish(f"example.{i}", b'A' * 20) + await nc.publish(f"example.{i}", b"A" * 20) # All pending messages should have been emitted to the server # by the first connection at this point. @@ -1035,21 +1024,20 @@ async def receiver_cb(msg): class ClientReconnectTest(MultiServerAuthTestCase): - @async_test async def test_connect_with_auth(self): nc = NATS() options = { - 'servers': [ + "servers": [ "nats://foo:bar@127.0.0.1:4223", - "nats://hoge:fuga@127.0.0.1:4224" + "nats://hoge:fuga@127.0.0.1:4224", ] } await nc.connect(**options) - self.assertIn('auth_required', nc._server_info) - self.assertIn('max_payload', nc._server_info) - self.assertEqual(nc._server_info['max_payload'], nc._max_payload) + self.assertIn("auth_required", nc._server_info) + self.assertIn("max_payload", nc._server_info) + self.assertEqual(nc._server_info["max_payload"], nc._max_payload) self.assertTrue(nc.is_connected) await nc.close() self.assertTrue(nc.is_closed) @@ -1064,9 +1052,7 @@ async def test_module_connect_with_auth(self): @async_test async def test_module_connect_with_options(self): - nc = await nats.connect( - "nats://127.0.0.1:4223", user="foo", password="bar" - ) + nc = await nats.connect("nats://127.0.0.1:4223", user="foo", password="bar") self.assertTrue(nc.is_connected) await nc.drain() self.assertTrue(nc.is_closed) @@ -1082,22 +1068,24 @@ async def err_cb(e): nc = NATS() options = { - 'reconnect_time_wait': 0.2, - 'servers': ["nats://hello:world@127.0.0.1:4223", ], - 'max_reconnect_attempts': 3, - 'error_cb': err_cb, + "reconnect_time_wait": 0.2, + "servers": [ + "nats://hello:world@127.0.0.1:4223", + ], + "max_reconnect_attempts": 3, + "error_cb": err_cb, } try: await nc.connect(**options) except: pass - self.assertIn('auth_required', nc._server_info) - self.assertTrue(nc._server_info['auth_required']) + self.assertIn("auth_required", nc._server_info) + self.assertTrue(nc._server_info["auth_required"]) self.assertFalse(nc.is_connected) await nc.close() self.assertTrue(nc.is_closed) - self.assertEqual(0, nc.stats['reconnects']) + self.assertEqual(0, nc.stats["reconnects"]) self.assertTrue(len(errors) > 0) @async_test @@ -1109,10 +1097,12 @@ async def err_cb(e): errors.append(e) options = { - 'reconnect_time_wait': 0.2, - 'servers': ["nats://hello:world@127.0.0.1:4223", ], - 'max_reconnect_attempts': 3, - 'error_cb': err_cb, + "reconnect_time_wait": 0.2, + "servers": [ + "nats://hello:world@127.0.0.1:4223", + ], + "max_reconnect_attempts": 3, + "error_cb": err_cb, } with self.assertRaises(nats.errors.NoServersError): await nats.connect(**options) @@ -1134,29 +1124,25 @@ async def err_cb(e): errors.append(e) options = { - 'dont_randomize': True, - 'reconnect_time_wait': 0.5, - 'disconnected_cb': disconnected_cb, - 'error_cb': err_cb, - 'servers': [ + "dont_randomize": True, + "reconnect_time_wait": 0.5, + "disconnected_cb": disconnected_cb, + "error_cb": err_cb, + "servers": [ "nats://foo:bar@127.0.0.1:4223", - "nats://hoge:fuga@127.0.0.1:4224" + "nats://hoge:fuga@127.0.0.1:4224", ], - 'max_reconnect_attempts': -1 + "max_reconnect_attempts": -1, } await nc.connect(**options) - self.assertIn('auth_required', nc._server_info) - self.assertTrue(nc._server_info['auth_required']) + self.assertIn("auth_required", nc._server_info) + self.assertTrue(nc._server_info["auth_required"]) self.assertTrue(nc.is_connected) # Stop all servers so that there aren't any available to reconnect - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[1].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[1].stop) for i in range(0, 10): await asyncio.sleep(0) await asyncio.sleep(0.2) @@ -1179,7 +1165,7 @@ async def err_cb(e): # Many attempts but only at most 2 reconnects would have occurred, # in case it was able to reconnect to another server while it was # shutting down. - self.assertTrue(nc.stats['reconnects'] >= 1) + self.assertTrue(nc.stats["reconnects"] >= 1) # Wrap off and disconnect await nc.close() @@ -1207,23 +1193,23 @@ async def err_cb(e): errors.append(e) options = { - 'dont_randomize': True, - 'reconnect_time_wait': 0.5, - 'disconnected_cb': disconnected_cb, - 'error_cb': err_cb, - 'closed_cb': closed_cb, - 'servers': [ + "dont_randomize": True, + "reconnect_time_wait": 0.5, + "disconnected_cb": disconnected_cb, + "error_cb": err_cb, + "closed_cb": closed_cb, + "servers": [ "nats://foo:bar@127.0.0.1:4223", "nats://hoge:fuga@127.0.0.1:4224", - "nats://hello:world@127.0.0.1:4225" + "nats://hello:world@127.0.0.1:4225", ], - 'max_reconnect_attempts': 3, - 'dont_randomize': True + "max_reconnect_attempts": 3, + "dont_randomize": True, } await nc.connect(**options) - self.assertIn('auth_required', nc._server_info) - self.assertTrue(nc._server_info['auth_required']) + self.assertIn("auth_required", nc._server_info) + self.assertTrue(nc._server_info["auth_required"]) self.assertTrue(nc.is_connected) # Check number of nodes in the server pool. @@ -1231,12 +1217,8 @@ async def err_cb(e): # Stop all servers so that there aren't any available to reconnect # then start one of them again. - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[1].stop - ) - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[1].stop) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) for i in range(0, 10): await asyncio.sleep(0) await asyncio.sleep(0.1) @@ -1258,16 +1240,14 @@ async def err_cb(e): await asyncio.sleep(0) # Stop the server once again - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[1].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[1].stop) for i in range(0, 10): await asyncio.sleep(0) await asyncio.sleep(0.1) await asyncio.sleep(0) # Only reconnected successfully once to the same server. - self.assertTrue(nc.stats['reconnects'] == 1) + self.assertTrue(nc.stats["reconnects"] == 1) self.assertEqual(1, len(nc._server_pool)) # await nc.close() @@ -1297,13 +1277,13 @@ async def err_cb(e): errors.append(e) options = { - 'dont_randomize': True, - 'reconnect_time_wait': 0.5, - 'disconnected_cb': disconnected_cb, - 'error_cb': err_cb, - 'closed_cb': closed_cb, - 'max_reconnect_attempts': 3, - 'dont_randomize': True, + "dont_randomize": True, + "reconnect_time_wait": 0.5, + "disconnected_cb": disconnected_cb, + "error_cb": err_cb, + "closed_cb": closed_cb, + "max_reconnect_attempts": 3, + "dont_randomize": True, } nc = await nats.connect("nats://foo:bar@127.0.0.1:4223", **options) @@ -1343,15 +1323,15 @@ async def closed_cb(): closed_count += 1 options = { - 'servers': [ + "servers": [ "nats://foo:bar@127.0.0.1:4223", - "nats://hoge:fuga@127.0.0.1:4224" + "nats://hoge:fuga@127.0.0.1:4224", ], - 'dont_randomize': True, - 'disconnected_cb': disconnected_cb, - 'closed_cb': closed_cb, - 'reconnected_cb': reconnected_cb, - 'reconnect_time_wait': 0.01 + "dont_randomize": True, + "disconnected_cb": disconnected_cb, + "closed_cb": closed_cb, + "reconnected_cb": reconnected_cb, + "reconnect_time_wait": 0.01, } await nc.connect(**options) largest_pending_data_size = 0 @@ -1364,7 +1344,7 @@ async def cb(msg): await nc.subscribe("example.*", cb=cb) for i in range(0, 200): - await nc.publish(f"example.{i}", b'A' * 20) + await nc.publish(f"example.{i}", b"A" * 20) if nc.pending_data_size > 0: largest_pending_data_size = nc.pending_data_size if nc.pending_data_size > 100: @@ -1385,7 +1365,7 @@ async def cb(msg): await asyncio.sleep(0) await asyncio.sleep(0.2) await asyncio.sleep(0) - self.assertEqual(1, nc.stats['reconnects']) + self.assertEqual(1, nc.stats["reconnects"]) try: await nc.flush(2) except nats.errors.TimeoutError: @@ -1419,16 +1399,16 @@ async def closed_cb(): closed_count += 1 options = { - 'servers': [ + "servers": [ "nats://foo:bar@127.0.0.1:4223", - "nats://hoge:fuga@127.0.0.1:4224" + "nats://hoge:fuga@127.0.0.1:4224", ], - 'dont_randomize': True, - 'disconnected_cb': disconnected_cb, - 'closed_cb': closed_cb, - 'reconnected_cb': reconnected_cb, - 'flusher_queue_size': 100, - 'reconnect_time_wait': 0.01 + "dont_randomize": True, + "disconnected_cb": disconnected_cb, + "closed_cb": closed_cb, + "reconnected_cb": reconnected_cb, + "flusher_queue_size": 100, + "reconnect_time_wait": 0.01, } await nc.connect(**options) largest_pending_data_size = 0 @@ -1441,7 +1421,7 @@ async def cb(msg): await nc.subscribe("example.*", cb=cb) for i in range(0, 500): - await nc.publish(f"example.{i}", b'A' * 20) + await nc.publish(f"example.{i}", b"A" * 20) if nc.pending_data_size > 0: largest_pending_data_size = nc.pending_data_size if nc.pending_data_size > 100: @@ -1462,7 +1442,7 @@ async def cb(msg): await asyncio.sleep(0) await asyncio.sleep(0.2) await asyncio.sleep(0) - self.assertEqual(1, nc.stats['reconnects']) + self.assertEqual(1, nc.stats["reconnects"]) try: await nc.flush(2) except nats.errors.TimeoutError: @@ -1500,15 +1480,15 @@ async def err_cb(e): errors.append(e) options = { - 'servers': [ + "servers": [ "nats://foo:bar@127.0.0.1:4223", - "nats://hoge:fuga@127.0.0.1:4224" + "nats://hoge:fuga@127.0.0.1:4224", ], - 'disconnected_cb': disconnected_cb, - 'closed_cb': closed_cb, - 'reconnected_cb': reconnected_cb, - 'error_cb': err_cb, - 'dont_randomize': True, + "disconnected_cb": disconnected_cb, + "closed_cb": closed_cb, + "reconnected_cb": reconnected_cb, + "error_cb": err_cb, + "dont_randomize": True, } nc = await nats.connect(**options) self.assertTrue(nc.is_connected) @@ -1517,28 +1497,26 @@ async def worker_handler(msg): nonlocal counter counter += 1 if msg.reply != "": - await nc.publish(msg.reply, f'Reply:{counter}'.encode()) + await nc.publish(msg.reply, f"Reply:{counter}".encode()) await nc.subscribe("one", cb=worker_handler) await nc.subscribe("two", cb=worker_handler) await nc.subscribe("three", cb=worker_handler) - response = await nc.request("one", b'Help!', timeout=1) - self.assertEqual(b'Reply:1', response.data) + response = await nc.request("one", b"Help!", timeout=1) + self.assertEqual(b"Reply:1", response.data) # Stop the first server and connect to another one asap. - asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) # FIXME: Find better way to wait for the server to be stopped. await asyncio.sleep(0.5) - response = await nc.request("three", b'Help!', timeout=1) - self.assertEqual(b'Reply:2', response.data) + response = await nc.request("three", b"Help!", timeout=1) + self.assertEqual(b"Reply:2", response.data) await asyncio.sleep(0.5) await nc.close() - self.assertEqual(1, nc.stats['reconnects']) + self.assertEqual(1, nc.stats["reconnects"]) self.assertEqual(1, closed_count) self.assertEqual(2, disconnected_count) self.assertEqual(1, reconnected_count) @@ -1547,14 +1525,17 @@ async def worker_handler(msg): class ClientAuthTokenTest(MultiServerAuthTokenTestCase): - @async_test async def test_connect_with_auth_token(self): nc = NATS() - options = {'servers': ["nats://token@127.0.0.1:4223", ]} + options = { + "servers": [ + "nats://token@127.0.0.1:4223", + ] + } await nc.connect(**options) - self.assertIn('auth_required', nc._server_info) + self.assertIn("auth_required", nc._server_info) self.assertTrue(nc.is_connected) await nc.close() self.assertTrue(nc.is_closed) @@ -1565,11 +1546,13 @@ async def test_connect_with_auth_token_option(self): nc = NATS() options = { - 'servers': ["nats://127.0.0.1:4223", ], - 'token': "token", + "servers": [ + "nats://127.0.0.1:4223", + ], + "token": "token", } await nc.connect(**options) - self.assertIn('auth_required', nc._server_info) + self.assertIn("auth_required", nc._server_info) self.assertTrue(nc.is_connected) await nc.close() self.assertTrue(nc.is_closed) @@ -1580,16 +1563,18 @@ async def test_connect_with_bad_auth_token(self): nc = NATS() options = { - 'servers': ["nats://token@127.0.0.1:4225", ], - 'allow_reconnect': False, - 'reconnect_time_wait': 0.1, - 'max_reconnect_attempts': 1, + "servers": [ + "nats://token@127.0.0.1:4225", + ], + "allow_reconnect": False, + "reconnect_time_wait": 0.1, + "max_reconnect_attempts": 1, } # Authorization Violation with self.assertRaises(nats.errors.Error): await nc.connect(**options) - self.assertIn('auth_required', nc._server_info) + self.assertIn("auth_required", nc._server_info) self.assertFalse(nc.is_connected) @async_test @@ -1619,32 +1604,30 @@ async def worker_handler(msg): nonlocal counter counter += 1 if msg.reply != "": - await nc.publish(msg.reply, f'Reply:{counter}'.encode()) + await nc.publish(msg.reply, f"Reply:{counter}".encode()) options = { - 'servers': [ + "servers": [ "nats://token@127.0.0.1:4223", "nats://token@127.0.0.1:4224", ], - 'disconnected_cb': disconnected_cb, - 'closed_cb': closed_cb, - 'reconnected_cb': reconnected_cb, - 'dont_randomize': True + "disconnected_cb": disconnected_cb, + "closed_cb": closed_cb, + "reconnected_cb": reconnected_cb, + "dont_randomize": True, } await nc.connect(**options) await nc.subscribe("test", cb=worker_handler) - self.assertIn('auth_required', nc._server_info) + self.assertIn("auth_required", nc._server_info) self.assertTrue(nc.is_connected) # Trigger a reconnect - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(1) await nc.subscribe("test", cb=worker_handler) - response = await nc.request("test", b'data', timeout=1) - self.assertEqual(b'Reply:1', response.data) + response = await nc.request("test", b"data", timeout=1) + self.assertEqual(b"Reply:1", response.data) await nc.close() self.assertTrue(nc.is_closed) @@ -1655,14 +1638,13 @@ async def worker_handler(msg): class ClientTLSTest(TLSServerTestCase): - @async_test async def test_connect(self): nc = NATS() - await nc.connect(servers=['nats://127.0.0.1:4224'], tls=self.ssl_ctx) - self.assertEqual(nc._server_info['max_payload'], nc.max_payload) - self.assertTrue(nc._server_info['tls_required']) - self.assertTrue(nc._server_info['tls_verify']) + await nc.connect(servers=["nats://127.0.0.1:4224"], tls=self.ssl_ctx) + self.assertEqual(nc._server_info["max_payload"], nc.max_payload) + self.assertTrue(nc._server_info["tls_required"]) + self.assertTrue(nc._server_info["tls_verify"]) self.assertTrue(nc.max_payload > 0) self.assertTrue(nc.is_connected) await nc.close() @@ -1675,9 +1657,7 @@ async def test_default_connect_using_tls_scheme(self): # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): - await nc.connect( - servers=['tls://127.0.0.1:4224'], allow_reconnect=False - ) + await nc.connect(servers=["tls://127.0.0.1:4224"], allow_reconnect=False) @async_test async def test_default_connect_using_tls_scheme_in_url(self): @@ -1685,7 +1665,7 @@ async def test_default_connect_using_tls_scheme_in_url(self): # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): - await nc.connect('tls://127.0.0.1:4224', allow_reconnect=False) + await nc.connect("tls://127.0.0.1:4224", allow_reconnect=False) @async_test async def test_connect_tls_with_custom_hostname(self): @@ -1694,7 +1674,7 @@ async def test_connect_tls_with_custom_hostname(self): # Will attempt to connect using TLS with an invalid hostname. with self.assertRaises(ssl.SSLError): await nc.connect( - servers=['nats://127.0.0.1:4224'], + servers=["nats://127.0.0.1:4224"], tls=self.ssl_ctx, tls_hostname="nats.example", allow_reconnect=False, @@ -1708,32 +1688,30 @@ async def test_subscribe(self): async def subscription_handler(msg): msgs.append(msg) - payload = b'hello world' - await nc.connect(servers=['nats://127.0.0.1:4224'], tls=self.ssl_ctx) + payload = b"hello world" + await nc.connect(servers=["nats://127.0.0.1:4224"], tls=self.ssl_ctx) sub = await nc.subscribe("foo", cb=subscription_handler) await nc.publish("foo", payload) await nc.publish("bar", payload) with self.assertRaises(nats.errors.BadSubjectError): - await nc.publish("", b'') + await nc.publish("", b"") # Wait a bit for message to be received. await asyncio.sleep(0.2) self.assertEqual(1, len(msgs)) msg = msgs[0] - self.assertEqual('foo', msg.subject) - self.assertEqual('', msg.reply) + self.assertEqual("foo", msg.subject) + self.assertEqual("", msg.reply) self.assertEqual(payload, msg.data) self.assertEqual(1, sub._received) await nc.close() class ClientTLSReconnectTest(MultiTLSServerAuthTestCase): - @async_test async def test_tls_reconnect(self): - nc = NATS() disconnected_count = 0 reconnected_count = 0 @@ -1762,41 +1740,39 @@ async def worker_handler(msg): nonlocal counter counter += 1 if msg.reply != "": - await nc.publish(msg.reply, f'Reply:{counter}'.encode()) + await nc.publish(msg.reply, f"Reply:{counter}".encode()) options = { - 'servers': [ + "servers": [ "nats://foo:bar@127.0.0.1:4223", - "nats://hoge:fuga@127.0.0.1:4224" + "nats://hoge:fuga@127.0.0.1:4224", ], - 'disconnected_cb': disconnected_cb, - 'closed_cb': closed_cb, - 'reconnected_cb': reconnected_cb, - 'error_cb': err_cb, - 'dont_randomize': True, - 'tls': self.ssl_ctx + "disconnected_cb": disconnected_cb, + "closed_cb": closed_cb, + "reconnected_cb": reconnected_cb, + "error_cb": err_cb, + "dont_randomize": True, + "tls": self.ssl_ctx, } await nc.connect(**options) self.assertTrue(nc.is_connected) await nc.subscribe("example", cb=worker_handler) - response = await nc.request("example", b'Help!', timeout=1) - self.assertEqual(b'Reply:1', response.data) + response = await nc.request("example", b"Help!", timeout=1) + self.assertEqual(b"Reply:1", response.data) # Trigger a reconnect and should be fine - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(1) await nc.subscribe("example", cb=worker_handler) - response = await nc.request("example", b'Help!', timeout=1) - self.assertEqual(b'Reply:2', response.data) + response = await nc.request("example", b"Help!", timeout=1) + self.assertEqual(b"Reply:2", response.data) await nc.close() self.assertTrue(nc.is_closed) self.assertFalse(nc.is_connected) - self.assertEqual(1, nc.stats['reconnects']) + self.assertEqual(1, nc.stats["reconnects"]) self.assertEqual(1, closed_count) self.assertEqual(2, disconnected_count) self.assertEqual(1, reconnected_count) @@ -1804,20 +1780,17 @@ async def worker_handler(msg): class ClientTLSHandshakeFirstTest(TLSServerHandshakeFirstTestCase): - @async_test async def test_connect(self): - if os.environ.get('NATS_SERVER_VERSION') != 'main': + if os.environ.get("NATS_SERVER_VERSION") != "main": pytest.skip("test requires nats-server@main") nc = await nats.connect( - 'nats://127.0.0.1:4224', - tls=self.ssl_ctx, - tls_handshake_first=True + "nats://127.0.0.1:4224", tls=self.ssl_ctx, tls_handshake_first=True ) - self.assertEqual(nc._server_info['max_payload'], nc.max_payload) - self.assertTrue(nc._server_info['tls_required']) - self.assertTrue(nc._server_info['tls_verify']) + self.assertEqual(nc._server_info["max_payload"], nc.max_payload) + self.assertTrue(nc._server_info["tls_required"]) + self.assertTrue(nc._server_info["tls_verify"]) self.assertTrue(nc.max_payload > 0) self.assertTrue(nc.is_connected) await nc.close() @@ -1826,7 +1799,7 @@ async def test_connect(self): @async_test async def test_default_connect_using_tls_scheme(self): - if os.environ.get('NATS_SERVER_VERSION') != 'main': + if os.environ.get("NATS_SERVER_VERSION") != "main": pytest.skip("test requires nats-server@main") nc = NATS() @@ -1834,14 +1807,14 @@ async def test_default_connect_using_tls_scheme(self): # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): await nc.connect( - servers=['tls://127.0.0.1:4224'], + servers=["tls://127.0.0.1:4224"], allow_reconnect=False, tls_handshake_first=True, ) @async_test async def test_default_connect_using_tls_scheme_in_url(self): - if os.environ.get('NATS_SERVER_VERSION') != 'main': + if os.environ.get("NATS_SERVER_VERSION") != "main": pytest.skip("test requires nats-server@main") nc = NATS() @@ -1849,14 +1822,12 @@ async def test_default_connect_using_tls_scheme_in_url(self): # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): await nc.connect( - 'tls://127.0.0.1:4224', - allow_reconnect=False, - tls_handshake_first=True + "tls://127.0.0.1:4224", allow_reconnect=False, tls_handshake_first=True ) @async_test async def test_connect_tls_with_custom_hostname(self): - if os.environ.get('NATS_SERVER_VERSION') != 'main': + if os.environ.get("NATS_SERVER_VERSION") != "main": pytest.skip("test requires nats-server@main") nc = NATS() @@ -1864,7 +1835,7 @@ async def test_connect_tls_with_custom_hostname(self): # Will attempt to connect using TLS with an invalid hostname. with self.assertRaises(ssl.SSLError): await nc.connect( - servers=['nats://127.0.0.1:4224'], + servers=["nats://127.0.0.1:4224"], tls=self.ssl_ctx, tls_hostname="nats.example", tls_handshake_first=True, @@ -1873,7 +1844,7 @@ async def test_connect_tls_with_custom_hostname(self): @async_test async def test_subscribe(self): - if os.environ.get('NATS_SERVER_VERSION') != 'main': + if os.environ.get("NATS_SERVER_VERSION") != "main": pytest.skip("test requires nats-server@main") nc = NATS() @@ -1882,33 +1853,32 @@ async def test_subscribe(self): async def subscription_handler(msg): msgs.append(msg) - payload = b'hello world' + payload = b"hello world" await nc.connect( - servers=['nats://127.0.0.1:4224'], + servers=["nats://127.0.0.1:4224"], tls=self.ssl_ctx, - tls_handshake_first=True + tls_handshake_first=True, ) sub = await nc.subscribe("foo", cb=subscription_handler) await nc.publish("foo", payload) await nc.publish("bar", payload) with self.assertRaises(nats.errors.BadSubjectError): - await nc.publish("", b'') + await nc.publish("", b"") # Wait a bit for message to be received. await asyncio.sleep(0.2) self.assertEqual(1, len(msgs)) msg = msgs[0] - self.assertEqual('foo', msg.subject) - self.assertEqual('', msg.reply) + self.assertEqual("foo", msg.subject) + self.assertEqual("", msg.reply) self.assertEqual(payload, msg.data) self.assertEqual(1, sub._received) await nc.close() class ClusterDiscoveryTest(ClusteringTestCase): - @async_test async def test_discover_servers_on_first_connect(self): nc = NATS() @@ -1924,14 +1894,16 @@ async def test_discover_servers_on_first_connect(self): ) await asyncio.sleep(1) - options = {'servers': ["nats://127.0.0.1:4223", ]} + options = { + "servers": [ + "nats://127.0.0.1:4223", + ] + } discovered_server_cb = mock.Mock() - with mock.patch('asyncio.iscoroutinefunction', return_value=True): - await nc.connect( - **options, discovered_server_cb=discovered_server_cb - ) + with mock.patch("asyncio.iscoroutinefunction", return_value=True): + await nc.connect(**options, discovered_server_cb=discovered_server_cb) self.assertTrue(nc.is_connected) await nc.close() self.assertTrue(nc.is_closed) @@ -1943,12 +1915,14 @@ async def test_discover_servers_on_first_connect(self): async def test_discover_servers_after_first_connect(self): nc = NATS() - options = {'servers': ["nats://127.0.0.1:4223", ]} + options = { + "servers": [ + "nats://127.0.0.1:4223", + ] + } discovered_server_cb = mock.Mock() - with mock.patch('asyncio.iscoroutinefunction', return_value=True): - await nc.connect( - **options, discovered_server_cb=discovered_server_cb - ) + with mock.patch("asyncio.iscoroutinefunction", return_value=True): + await nc.connect(**options, discovered_server_cb=discovered_server_cb) # Start rest of cluster members so that we receive them # connect_urls on the first connect. @@ -1969,7 +1943,6 @@ async def test_discover_servers_after_first_connect(self): class ClusterDiscoveryReconnectTest(ClusteringDiscoveryAuthTestCase): - @async_test async def test_reconnect_to_new_server_with_auth(self): nc = NATS() @@ -1985,12 +1958,14 @@ async def err_cb(e): errors.append(e) options = { - 'servers': ["nats://127.0.0.1:4223", ], - 'reconnected_cb': reconnected_cb, - 'error_cb': err_cb, - 'reconnect_time_wait': 0.1, - 'user': "foo", - 'password': "bar", + "servers": [ + "nats://127.0.0.1:4223", + ], + "reconnected_cb": reconnected_cb, + "error_cb": err_cb, + "reconnect_time_wait": 0.1, + "user": "foo", + "password": "bar", } await nc.connect(**options) @@ -1998,18 +1973,16 @@ async def err_cb(e): await asyncio.sleep(1) async def handler(msg): - await nc.publish(msg.reply, b'ok') + await nc.publish(msg.reply, b"ok") await nc.subscribe("foo", cb=handler) # Remove first member and try to reconnect - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.wait_for(reconnected, 2) - msg = await nc.request("foo", b'hi') - self.assertEqual(b'ok', msg.data) + msg = await nc.request("foo", b"hi") + self.assertEqual(b"ok", msg.data) await nc.close() self.assertTrue(nc.is_closed) @@ -2058,25 +2031,23 @@ async def err_cb(e): await asyncio.sleep(1) async def handler(msg): - await nc.publish(msg.reply, b'ok') + await nc.publish(msg.reply, b"ok") await nc.subscribe("foo", cb=handler) - msg = await nc.request("foo", b'hi') - self.assertEqual(b'ok', msg.data) + msg = await nc.request("foo", b"hi") + self.assertEqual(b"ok", msg.data) # Remove first member and try to reconnect - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.wait_for(disconnected, 2) # Publishing while disconnected is an error if pending size is disabled. with self.assertRaises(nats.errors.OutboundBufferLimitError): - await nc.request("foo", b'hi') + await nc.request("foo", b"hi") with self.assertRaises(nats.errors.OutboundBufferLimitError): - await nc.publish("foo", b'hi') + await nc.publish("foo", b"hi") await asyncio.wait_for(reconnected, 2) @@ -2128,21 +2099,19 @@ async def err_cb(e): await asyncio.sleep(1) async def handler(msg): - await nc.publish(msg.reply, b'ok') + await nc.publish(msg.reply, b"ok") await nc.subscribe("foo", cb=handler) - msg = await nc.request("foo", b'hi') - self.assertEqual(b'ok', msg.data) + msg = await nc.request("foo", b"hi") + self.assertEqual(b"ok", msg.data) # Remove first member and try to reconnect - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.wait_for(disconnected, 2) # While reconnecting the pending data will accumulate. - await nc.publish("foo", b'bar') + await nc.publish("foo", b"bar") self.assertEqual(nc._pending_data_size, 17) # Publishing while disconnected is an error if it hits the pending size. @@ -2197,12 +2166,12 @@ async def err_cb(e): await asyncio.sleep(1) async def handler(msg): - await nc.publish(msg.reply, b'ok') + await nc.publish(msg.reply, b"ok") await nc.subscribe("foo", cb=handler) - msg = await nc.request("foo", b'hi') - self.assertEqual(b'ok', msg.data) + msg = await nc.request("foo", b"hi") + self.assertEqual(b"ok", msg.data) # Publishing while connected should trigger a force flush. payload = ("A" * 1025).encode() @@ -2255,12 +2224,12 @@ async def err_cb(e): async def handler(msg): # This becomes an async error if msg.reply: - await nc.publish(msg.reply, b'ok') + await nc.publish(msg.reply, b"ok") await nc.subscribe("foo", cb=handler) - msg = await nc.request("foo", b'hi') - self.assertEqual(b'ok', msg.data) + msg = await nc.request("foo", b"hi") + self.assertEqual(b"ok", msg.data) # Publishing while connected should trigger a force flush. payload = ("A" * 1025).encode() @@ -2281,19 +2250,17 @@ async def handler(msg): class ConnectFailuresTest(SingleServerTestCase): - @async_test async def test_empty_info_op_uses_defaults(self): - async def bad_server(reader, writer): - writer.write(b'INFO {}\r\n') + writer.write(b"INFO {}\r\n") await writer.drain() data = await reader.readline() await asyncio.sleep(0.2) writer.close() - await asyncio.start_server(bad_server, '127.0.0.1', 4555) + await asyncio.start_server(bad_server, "127.0.0.1", 4555) disconnected_count = 0 @@ -2303,8 +2270,10 @@ async def disconnected_cb(): nc = NATS() options = { - 'servers': ["nats://127.0.0.1:4555", ], - 'disconnected_cb': disconnected_cb + "servers": [ + "nats://127.0.0.1:4555", + ], + "disconnected_cb": disconnected_cb, } await nc.connect(**options) self.assertEqual(nc.max_payload, 1048576) @@ -2314,13 +2283,12 @@ async def disconnected_cb(): @async_test async def test_empty_response_from_server(self): - async def bad_server(reader, writer): - writer.write(b'') + writer.write(b"") await asyncio.sleep(0.2) writer.close() - await asyncio.start_server(bad_server, '127.0.0.1', 4555) + await asyncio.start_server(bad_server, "127.0.0.1", 4555) errors = [] @@ -2330,9 +2298,11 @@ async def error_cb(e): nc = NATS() options = { - 'servers': ["nats://127.0.0.1:4555", ], - 'error_cb': error_cb, - 'allow_reconnect': False, + "servers": [ + "nats://127.0.0.1:4555", + ], + "error_cb": error_cb, + "allow_reconnect": False, } with self.assertRaises(nats.errors.Error): @@ -2342,13 +2312,12 @@ async def error_cb(e): @async_test async def test_malformed_info_response_from_server(self): - async def bad_server(reader, writer): - writer.write(b'INF') + writer.write(b"INF") await asyncio.sleep(0.2) writer.close() - await asyncio.start_server(bad_server, '127.0.0.1', 4555) + await asyncio.start_server(bad_server, "127.0.0.1", 4555) errors = [] @@ -2358,9 +2327,11 @@ async def error_cb(e): nc = NATS() options = { - 'servers': ["nats://127.0.0.1:4555", ], - 'error_cb': error_cb, - 'allow_reconnect': False, + "servers": [ + "nats://127.0.0.1:4555", + ], + "error_cb": error_cb, + "allow_reconnect": False, } with self.assertRaises(nats.errors.Error): @@ -2370,13 +2341,12 @@ async def error_cb(e): @async_test async def test_malformed_info_json_response_from_server(self): - async def bad_server(reader, writer): - writer.write(b'INFO {\r\n') + writer.write(b"INFO {\r\n") await asyncio.sleep(0.2) writer.close() - await asyncio.start_server(bad_server, '127.0.0.1', 4555) + await asyncio.start_server(bad_server, "127.0.0.1", 4555) errors = [] @@ -2386,9 +2356,11 @@ async def error_cb(e): nc = NATS() options = { - 'servers': ["nats://127.0.0.1:4555", ], - 'error_cb': error_cb, - 'allow_reconnect': False, + "servers": [ + "nats://127.0.0.1:4555", + ], + "error_cb": error_cb, + "allow_reconnect": False, } with self.assertRaises(nats.errors.Error): @@ -2399,12 +2371,11 @@ async def error_cb(e): @async_test async def test_connect_timeout(self): - async def slow_server(reader, writer): await asyncio.sleep(1) writer.close() - await asyncio.start_server(slow_server, '127.0.0.1', 4555) + await asyncio.start_server(slow_server, "127.0.0.1", 4555) disconnected_count = 0 reconnected = asyncio.Future() @@ -2419,12 +2390,14 @@ async def reconnected_cb(): nc = NATS() options = { - 'servers': ["nats://127.0.0.1:4555", ], - 'disconnected_cb': disconnected_cb, - 'reconnected_cb': reconnected_cb, - 'connect_timeout': 0.5, - 'dont_randomize': True, - 'allow_reconnect': False, + "servers": [ + "nats://127.0.0.1:4555", + ], + "disconnected_cb": disconnected_cb, + "reconnected_cb": reconnected_cb, + "connect_timeout": 0.5, + "dont_randomize": True, + "allow_reconnect": False, } with self.assertRaises(asyncio.TimeoutError): @@ -2436,12 +2409,11 @@ async def reconnected_cb(): @async_test async def test_connect_timeout_then_connect_to_healthy_server(self): - async def slow_server(reader, writer): await asyncio.sleep(1) writer.close() - await asyncio.start_server(slow_server, '127.0.0.1', 4555) + await asyncio.start_server(slow_server, "127.0.0.1", 4555) disconnected_count = 0 reconnected = asyncio.Future() @@ -2462,15 +2434,15 @@ async def error_cb(e): nc = NATS() options = { - 'servers': [ + "servers": [ "nats://127.0.0.1:4555", "nats://127.0.0.1:4222", ], - 'disconnected_cb': disconnected_cb, - 'reconnected_cb': reconnected_cb, - 'error_cb': error_cb, - 'connect_timeout': 0.5, - 'dont_randomize': True, + "disconnected_cb": disconnected_cb, + "reconnected_cb": reconnected_cb, + "error_cb": error_cb, + "connect_timeout": 0.5, + "dont_randomize": True, } await nc.connect(**options) @@ -2479,7 +2451,7 @@ async def error_cb(e): self.assertTrue(nc.is_connected) for i in range(0, 10): - await nc.publish("foo", b'ok ok') + await nc.publish("foo", b"ok ok") await nc.flush() await nc.close() @@ -2490,7 +2462,6 @@ async def error_cb(e): class ClientDrainTest(SingleServerTestCase): - @async_test async def test_drain_subscription(self): nc = NATS() @@ -2528,7 +2499,7 @@ async def handler(msg): sub = await nc.subscribe("foo", cb=handler) for i in range(0, 200): - await nc.publish("foo", b'hi') + await nc.publish("foo", b"hi") # Relinquish control so that messages are processed. await asyncio.sleep(0) @@ -2542,7 +2513,7 @@ async def handler(msg): await asyncio.wait_for(drain_task, 1) for i in range(0, 200): - await nc.publish("foo", b'hi') + await nc.publish("foo", b"hi") # Relinquish control so that messages are processed. await asyncio.sleep(0) @@ -2564,7 +2535,7 @@ async def test_drain_subscription_with_future(self): fut = asyncio.Future() sub = await nc.subscribe("foo", future=fut) - await nc.publish("foo", b'hi') + await nc.publish("foo", b"hi") await nc.flush() drain_task = sub.drain() @@ -2599,7 +2570,7 @@ async def closed_cb(): async def handler(msg): if len(msgs) % 20 == 1: await asyncio.sleep(0.2) - await nc.publish(msg.reply, b'OK!', reply=msg.subject) + await nc.publish(msg.reply, b"OK!", reply=msg.subject) await nc2.flush() sub_foo = await nc.subscribe("foo", cb=handler) @@ -2613,15 +2584,13 @@ async def replies(msg): await nc2.subscribe("my-replies.*", cb=replies) for i in range(0, 201): await nc2.publish( - "foo", b'help', reply=f"my-replies.{nc._nuid.next().decode()}" + "foo", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) await nc2.publish( - "bar", b'help', reply=f"my-replies.{nc._nuid.next().decode()}" + "bar", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) await nc2.publish( - "quux", - b'help', - reply=f"my-replies.{nc._nuid.next().decode()}" + "quux", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) # Relinquish control so that messages are processed. @@ -2681,9 +2650,7 @@ async def closed_cb(): nonlocal drain_done drain_done.set_result(True) - await nc.connect( - closed_cb=closed_cb, error_cb=error_cb, drain_timeout=0.1 - ) + await nc.connect(closed_cb=closed_cb, error_cb=error_cb, drain_timeout=0.1) nc2 = NATS() await nc2.connect() @@ -2695,7 +2662,7 @@ async def handler(msg): await asyncio.sleep(0.2) if len(msgs) % 50 == 1: await asyncio.sleep(0.5) - await nc.publish(msg.reply, b'OK!', reply=msg.subject) + await nc.publish(msg.reply, b"OK!", reply=msg.subject) await nc2.flush() await nc.subscribe("foo", cb=handler) @@ -2709,15 +2676,13 @@ async def replies(msg): await nc2.subscribe("my-replies.*", cb=replies) for i in range(0, 201): await nc2.publish( - "foo", b'help', reply=f"my-replies.{nc._nuid.next().decode()}" + "foo", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) await nc2.publish( - "bar", b'help', reply=f"my-replies.{nc._nuid.next().decode()}" + "bar", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) await nc2.publish( - "quux", - b'help', - reply=f"my-replies.{nc._nuid.next().decode()}" + "quux", b"help", reply=f"my-replies.{nc._nuid.next().decode()}" ) # Relinquish control so that messages are processed. @@ -2744,32 +2709,30 @@ async def test_non_async_callbacks_raise_error(self): def f(): pass - for cb in ['error_cb', 'disconnected_cb', 'discovered_server_cb', - 'closed_cb', 'reconnected_cb']: - + for cb in [ + "error_cb", + "disconnected_cb", + "discovered_server_cb", + "closed_cb", + "reconnected_cb", + ]: with self.assertRaises(nats.errors.InvalidCallbackTypeError): await nc.connect( servers=["nats://127.0.0.1:4222"], max_reconnect_attempts=2, reconnect_time_wait=0.2, - **{cb: f} + **{cb: f}, ) @async_test async def test_protocol_mixing(self): nc = NATS() with self.assertRaises(nats.errors.Error): - await nc.connect( - servers=["nats://127.0.0.1:4222", "ws://127.0.0.1:8080"] - ) + await nc.connect(servers=["nats://127.0.0.1:4222", "ws://127.0.0.1:8080"]) with self.assertRaises(nats.errors.Error): - await nc.connect( - servers=["nats://127.0.0.1:4222", "wss://127.0.0.1:8080"] - ) + await nc.connect(servers=["nats://127.0.0.1:4222", "wss://127.0.0.1:8080"]) with self.assertRaises(nats.errors.Error): - await nc.connect( - servers=["tls://127.0.0.1:4222", "wss://127.0.0.1:8080"] - ) + await nc.connect(servers=["tls://127.0.0.1:4222", "wss://127.0.0.1:8080"]) @async_test async def test_drain_cancelled_errors_raised(self): @@ -2790,15 +2753,14 @@ async def cb(msg): await asyncio.sleep(0.1) with self.assertRaises(asyncio.CancelledError): with unittest.mock.patch( - "asyncio.wait_for", - unittest.mock.AsyncMock(side_effect=asyncio.CancelledError - )): + "asyncio.wait_for", + unittest.mock.AsyncMock(side_effect=asyncio.CancelledError), + ): await sub.drain() await nc.close() class NoAuthUserClientTest(NoAuthUserServerTestCase): - @async_test async def test_connect_user(self): fut = asyncio.Future() @@ -2816,22 +2778,22 @@ async def err_cb(e): sub = await nc.subscribe("foo") await nc.flush() await asyncio.sleep(0) - await nc.publish("foo", b'hello') + await nc.publish("foo", b"hello") await asyncio.wait_for(fut, 2) err = fut.result() - assert str( - err - ) == 'nats: permissions violation for subscription to "foo"' + assert str(err) == 'nats: permissions violation for subscription to "foo"' - nc2 = await nats.connect("nats://127.0.0.1:4555", ) + nc2 = await nats.connect( + "nats://127.0.0.1:4555", + ) async def cb(msg): - await msg.respond(b'pong') + await msg.respond(b"pong") sub2 = await nc2.subscribe("foo", cb=cb) await nc2.flush() - resp = await nc2.request("foo", b'ping') - assert resp.data == b'pong' + resp = await nc2.request("foo", b"ping") + assert resp.data == b"pong" await nc2.close() await nc.close() @@ -2851,29 +2813,28 @@ async def err_cb(e): sub = await nc.subscribe("foo") await nc.flush() await asyncio.sleep(0) - await nc.publish("foo", b'hello') + await nc.publish("foo", b"hello") await asyncio.wait_for(fut, 2) err = fut.result() - assert str( - err - ) == 'nats: permissions violation for subscription to "foo"' + assert str(err) == 'nats: permissions violation for subscription to "foo"' - nc2 = await nats.connect("nats://127.0.0.1:4555", ) + nc2 = await nats.connect( + "nats://127.0.0.1:4555", + ) async def cb(msg): - await msg.respond(b'pong') + await msg.respond(b"pong") sub2 = await nc2.subscribe("foo", cb=cb) await nc2.flush() - resp = await nc2.request("foo", b'ping') - assert resp.data == b'pong' + resp = await nc2.request("foo", b"ping") + assert resp.data == b"pong" await nc2.close() await nc.close() class ClientDisconnectTest(SingleServerTestCase): - @async_test async def test_close_while_disconnected(self): reconnected = asyncio.Future() @@ -2900,14 +2861,12 @@ async def disconnected_cb(): error_cb=error_cb, ) sub = await nc.subscribe("foo") - await nc.publish("foo", b'First') + await nc.publish("foo", b"First") await nc.flush() msg = await sub.next_msg() - self.assertEqual(msg.data, b'First') + self.assertEqual(msg.data, b"First") - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.wait_for(disconnected, 2) await nc.close() @@ -2915,7 +2874,8 @@ async def disconnected_cb(): disconnected_states[1] == NATS.CLOSED -if __name__ == '__main__': +if __name__ == "__main__": import sys + runner = unittest.TextTestRunner(stream=sys.stdout) unittest.main(verbosity=2, exit=False, testRunner=runner) diff --git a/tests/test_client_async_await.py b/tests/test_client_async_await.py index 05da8bd6..88c9c205 100644 --- a/tests/test_client_async_await.py +++ b/tests/test_client_async_await.py @@ -9,7 +9,6 @@ class ClientAsyncAwaitTest(SingleServerTestCase): - @async_test async def test_async_await_subscribe_async(self): nc = NATS() @@ -27,7 +26,7 @@ async def subscription_handler(msg): self.assertIsNotNone(sub) for i in range(0, 5): - await nc.publish(f"tests.{i}", b'bar') + await nc.publish(f"tests.{i}", b"bar") # Wait a bit for messages to be received. await asyncio.sleep(1) @@ -60,34 +59,34 @@ async def handler_foo(msg): async def handler_bar(msg): msgs.append(msg) if msg.reply != "": - await nc.publish(msg.reply, b'') + await nc.publish(msg.reply, b"") await nc.subscribe("bar", cb=handler_bar) - await nc.publish("foo", b'1') - await nc.publish("foo", b'2') - await nc.publish("foo", b'3') + await nc.publish("foo", b"1") + await nc.publish("foo", b"2") + await nc.publish("foo", b"3") # Will be processed before the others since no head of line # blocking among the subscriptions. - await nc.publish("bar", b'4') + await nc.publish("bar", b"4") - response = await nc.request("foo", b'hello1', timeout=1) - self.assertEqual(response.data, b'hello1hello1') + response = await nc.request("foo", b"hello1", timeout=1) + self.assertEqual(response.data, b"hello1hello1") with self.assertRaises(TimeoutError): - await nc.request("foo", b'hello2', timeout=0.1) - - await nc.publish("bar", b'5') - response = await nc.request("foo", b'hello2', timeout=1) - self.assertEqual(response.data, b'hello2hello2') - - self.assertEqual(msgs[0].data, b'1') - self.assertEqual(msgs[1].data, b'4') - self.assertEqual(msgs[2].data, b'2') - self.assertEqual(msgs[3].data, b'3') - self.assertEqual(msgs[4].data, b'hello1') - self.assertEqual(msgs[5].data, b'hello2') + await nc.request("foo", b"hello2", timeout=0.1) + + await nc.publish("bar", b"5") + response = await nc.request("foo", b"hello2", timeout=1) + self.assertEqual(response.data, b"hello2hello2") + + self.assertEqual(msgs[0].data, b"1") + self.assertEqual(msgs[1].data, b"4") + self.assertEqual(msgs[2].data, b"2") + self.assertEqual(msgs[3].data, b"3") + self.assertEqual(msgs[4].data, b"hello1") + self.assertEqual(msgs[5].data, b"hello2") self.assertEqual(len(errors), 0) await nc.close() @@ -120,17 +119,17 @@ async def handler_bar(msg): await nc.subscribe("bar", cb=handler_bar) for i in range(10): - await nc.publish("foo", f'{i}'.encode()) + await nc.publish("foo", f"{i}".encode()) # Will be processed before the others since no head of line # blocking among the subscriptions. - await nc.publish("bar", b'14') - response = await nc.request("bar", b'hi1', 2) - self.assertEqual(response.data, b'hi1hi1hi1') + await nc.publish("bar", b"14") + response = await nc.request("bar", b"hi1", 2) + self.assertEqual(response.data, b"hi1hi1hi1") self.assertEqual(len(msgs), 2) - self.assertEqual(msgs[0].data, b'14') - self.assertEqual(msgs[1].data, b'hi1') + self.assertEqual(msgs[0].data, b"14") + self.assertEqual(msgs[1].data, b"hi1") # Consumed messages but the rest were slow consumers. self.assertTrue(4 <= len(errors) <= 5) @@ -174,13 +173,13 @@ async def handler_bar(msg): # Will be processed before the others since no head of line # blocking among the subscriptions. - await nc.publish("bar", b'14') + await nc.publish("bar", b"14") - response = await nc.request("bar", b'hi1', 2) - self.assertEqual(response.data, b'hi1hi1hi1') + response = await nc.request("bar", b"hi1", 2) + self.assertEqual(response.data, b"hi1hi1hi1") self.assertEqual(len(msgs), 2) - self.assertEqual(msgs[0].data, b'14') - self.assertEqual(msgs[1].data, b'hi1') + self.assertEqual(msgs[0].data, b"14") + self.assertEqual(msgs[1].data, b"hi1") # Consumed a few messages but the rest were slow consumers. self.assertTrue(7 <= len(errors) <= 8) @@ -191,6 +190,6 @@ async def handler_bar(msg): self.assertEqual(errors[0].reply, "") # Try again a few seconds later and it should have recovered await asyncio.sleep(3) - response = await nc.request("foo", b'B', 1) - self.assertEqual(response.data, b'BB') + response = await nc.request("foo", b"B", 1) + self.assertEqual(response.data, b"BB") await nc.close() diff --git a/tests/test_client_nkeys.py b/tests/test_client_nkeys.py index 6161bd0f..c35e104c 100644 --- a/tests/test_client_nkeys.py +++ b/tests/test_client_nkeys.py @@ -6,6 +6,7 @@ try: import nkeys + nkeys_installed = True except ModuleNotFoundError: nkeys_installed = False @@ -23,23 +24,18 @@ class ClientNkeysAuthTest(NkeysServerTestCase): - @async_test async def test_nkeys_connect(self): import os config_file = get_config_file("nkeys/foo-user.nk") seed = None - with open(config_file, 'rb') as f: + with open(config_file, "rb") as f: seed = bytearray(os.fstat(f.fileno()).st_size) f.readinto(seed) args_list = [ - { - 'nkeys_seed': config_file - }, - { - 'nkeys_seed_str': seed.decode() - }, + {"nkeys_seed": config_file}, + {"nkeys_seed_str": seed.decode()}, ] for nkeys_args in args_list: if not nkeys_installed: @@ -53,33 +49,34 @@ async def error_cb(e): nonlocal future future.set_result(True) - await nc.connect(["tls://127.0.0.1:4222"], - error_cb=error_cb, - connect_timeout=10, - allow_reconnect=False, - **nkeys_args) + await nc.connect( + ["tls://127.0.0.1:4222"], + error_cb=error_cb, + connect_timeout=10, + allow_reconnect=False, + **nkeys_args, + ) async def help_handler(msg): - await nc.publish(msg.reply, b'OK!') + await nc.publish(msg.reply, b"OK!") await nc.subscribe("help", cb=help_handler) await nc.flush() - msg = await nc.request("help", b'I need help') - self.assertEqual(msg.data, b'OK!') + msg = await nc.request("help", b"I need help") + self.assertEqual(msg.data, b"OK!") await nc.subscribe("bar", cb=help_handler) await nc.flush() await asyncio.wait_for(future, 1) - msg = await nc.request("help", b'I need help') - self.assertEqual(msg.data, b'OK!') + msg = await nc.request("help", b"I need help") + self.assertEqual(msg.data, b"OK!") await nc.close() class ClientJWTAuthTest(TrustedServerTestCase): - @async_test async def test_nkeys_jwt_creds_user_connect(self): if not nkeys_installed: @@ -99,12 +96,12 @@ async def error_cb(e): ) async def help_handler(msg): - await nc.publish(msg.reply, b'OK!') + await nc.publish(msg.reply, b"OK!") await nc.subscribe("help", cb=help_handler) await nc.flush() - msg = await nc.request("help", b'I need help') - self.assertEqual(msg.data, b'OK!') + msg = await nc.request("help", b"I need help") + self.assertEqual(msg.data, b"OK!") await nc.close() @async_test @@ -123,18 +120,18 @@ async def error_cb(e): connect_timeout=5, user_credentials=( get_config_file("nkeys/foo-user.jwt"), - get_config_file("nkeys/foo-user.nk") + get_config_file("nkeys/foo-user.nk"), ), allow_reconnect=False, ) async def help_handler(msg): - await nc.publish(msg.reply, b'OK!') + await nc.publish(msg.reply, b"OK!") await nc.subscribe("help", cb=help_handler) await nc.flush() - msg = await nc.request("help", b'I need help') - self.assertEqual(msg.data, b'OK!') + msg = await nc.request("help", b"I need help") + self.assertEqual(msg.data, b"OK!") await nc.close() @async_test @@ -197,10 +194,10 @@ async def error_cb(e): ) async def help_handler(msg): - await nc.publish(msg.reply, b'OK!') + await nc.publish(msg.reply, b"OK!") await nc.subscribe("help", cb=help_handler) await nc.flush() - msg = await nc.request("help", b'I need help') - self.assertEqual(msg.data, b'OK!') + msg = await nc.request("help", b"I need help") + self.assertEqual(msg.data, b"OK!") await nc.close() diff --git a/tests/test_client_v2.py b/tests/test_client_v2.py index b78765ae..80bf96bd 100644 --- a/tests/test_client_v2.py +++ b/tests/test_client_v2.py @@ -15,7 +15,6 @@ class HeadersTest(SingleServerTestCase): - @async_test async def test_simple_headers(self): nc = await nats.connect() @@ -23,18 +22,15 @@ async def test_simple_headers(self): sub = await nc.subscribe("foo") await nc.flush() await nc.publish( - "foo", b'hello world', headers={ - 'foo': 'bar', - 'hello': 'world-1' - } + "foo", b"hello world", headers={"foo": "bar", "hello": "world-1"} ) msg = await sub.next_msg() self.assertTrue(msg.headers != None) self.assertEqual(len(msg.headers), 2) - self.assertEqual(msg.headers['foo'], 'bar') - self.assertEqual(msg.headers['hello'], 'world-1') + self.assertEqual(msg.headers["foo"], "bar") + self.assertEqual(msg.headers["hello"], "world-1") await nc.close() @@ -44,24 +40,21 @@ async def test_request_with_headers(self): async def service(msg): # Add another header - msg.headers['quux'] = 'quuz' - await msg.respond(b'OK!') + msg.headers["quux"] = "quuz" + await msg.respond(b"OK!") await nc.subscribe("foo", cb=service) await nc.flush() msg = await nc.request( - "foo", b'hello world', headers={ - 'foo': 'bar', - 'hello': 'world' - } + "foo", b"hello world", headers={"foo": "bar", "hello": "world"} ) self.assertTrue(msg.headers != None) self.assertEqual(len(msg.headers), 3) - self.assertEqual(msg.headers['foo'], 'bar') - self.assertEqual(msg.headers['hello'], 'world') - self.assertEqual(msg.headers['quux'], 'quuz') - self.assertEqual(msg.data, b'OK!') + self.assertEqual(msg.headers["foo"], "bar") + self.assertEqual(msg.headers["hello"], "world") + self.assertEqual(msg.headers["quux"], "quuz") + self.assertEqual(msg.data, b"OK!") await nc.close() @@ -71,38 +64,37 @@ async def test_empty_headers(self): sub = await nc.subscribe("foo") await nc.flush() - await nc.publish("foo", b'hello world', headers={'': ''}) + await nc.publish("foo", b"hello world", headers={"": ""}) msg = await sub.next_msg() self.assertTrue(msg.headers == None) # Empty long key - await nc.publish("foo", b'hello world', headers={' ': ''}) + await nc.publish("foo", b"hello world", headers={" ": ""}) msg = await sub.next_msg() self.assertTrue(msg.headers == None) # Empty long key - await nc.publish( - "foo", b'hello world', headers={'': ' '} - ) + await nc.publish("foo", b"hello world", headers={"": " "}) msg = await sub.next_msg() self.assertTrue(msg.headers == None) hdrs = { - 'timestamp': '2022-06-15T19:08:14.639020', - 'type': 'rpc', - 'command': 'publish_state', - 'trace_id': '', - 'span_id': '' + "timestamp": "2022-06-15T19:08:14.639020", + "type": "rpc", + "command": "publish_state", + "trace_id": "", + "span_id": "", } - await nc.publish("foo", b'Hello from Python!', headers=hdrs) + await nc.publish("foo", b"Hello from Python!", headers=hdrs) msg = await sub.next_msg() self.assertEqual(msg.headers, hdrs) await nc.close() -if __name__ == '__main__': +if __name__ == "__main__": import sys + runner = unittest.TextTestRunner(stream=sys.stdout) unittest.main(verbosity=2, exit=False, testRunner=runner) diff --git a/tests/test_client_websocket.py b/tests/test_client_websocket.py index 097ce0f7..6d1256b2 100644 --- a/tests/test_client_websocket.py +++ b/tests/test_client_websocket.py @@ -16,13 +16,13 @@ try: import aiohttp + aiohttp_installed = True except ModuleNotFoundError: aiohttp_installed = False class WebSocketTest(SingleWebSocketServerTestCase): - @async_test async def test_simple_headers(self): if not aiohttp_installed: @@ -33,18 +33,15 @@ async def test_simple_headers(self): sub = await nc.subscribe("foo") await nc.flush() await nc.publish( - "foo", b'hello world', headers={ - 'foo': 'bar', - 'hello': 'world-1' - } + "foo", b"hello world", headers={"foo": "bar", "hello": "world-1"} ) msg = await sub.next_msg() self.assertTrue(msg.headers != None) self.assertEqual(len(msg.headers), 2) - self.assertEqual(msg.headers['foo'], 'bar') - self.assertEqual(msg.headers['hello'], 'world-1') + self.assertEqual(msg.headers["foo"], "bar") + self.assertEqual(msg.headers["hello"], "world-1") await nc.close() @@ -57,24 +54,21 @@ async def test_request_with_headers(self): async def service(msg): # Add another header - msg.headers['quux'] = 'quuz' - await msg.respond(b'OK!') + msg.headers["quux"] = "quuz" + await msg.respond(b"OK!") await nc.subscribe("foo", cb=service) await nc.flush() msg = await nc.request( - "foo", b'hello world', headers={ - 'foo': 'bar', - 'hello': 'world' - } + "foo", b"hello world", headers={"foo": "bar", "hello": "world"} ) self.assertTrue(msg.headers != None) self.assertEqual(len(msg.headers), 3) - self.assertEqual(msg.headers['foo'], 'bar') - self.assertEqual(msg.headers['hello'], 'world') - self.assertEqual(msg.headers['quux'], 'quuz') - self.assertEqual(msg.data, b'OK!') + self.assertEqual(msg.headers["foo"], "bar") + self.assertEqual(msg.headers["hello"], "world") + self.assertEqual(msg.headers["quux"], "quuz") + self.assertEqual(msg.data, b"OK!") await nc.close() @@ -87,31 +81,29 @@ async def test_empty_headers(self): sub = await nc.subscribe("foo") await nc.flush() - await nc.publish("foo", b'hello world', headers={'': ''}) + await nc.publish("foo", b"hello world", headers={"": ""}) msg = await sub.next_msg() self.assertTrue(msg.headers == None) # Empty long key - await nc.publish("foo", b'hello world', headers={' ': ''}) + await nc.publish("foo", b"hello world", headers={" ": ""}) msg = await sub.next_msg() self.assertTrue(msg.headers == None) # Empty long key - await nc.publish( - "foo", b'hello world', headers={'': ' '} - ) + await nc.publish("foo", b"hello world", headers={"": " "}) msg = await sub.next_msg() self.assertTrue(msg.headers == None) hdrs = { - 'timestamp': '2022-06-15T19:08:14.639020', - 'type': 'rpc', - 'command': 'publish_state', - 'trace_id': '', - 'span_id': '' + "timestamp": "2022-06-15T19:08:14.639020", + "type": "rpc", + "command": "publish_state", + "trace_id": "", + "span_id": "", } - await nc.publish("foo", b'Hello from Python!', headers=hdrs) + await nc.publish("foo", b"Hello from Python!", headers=hdrs) msg = await sub.next_msg() self.assertEqual(msg.headers, hdrs) @@ -136,21 +128,19 @@ async def reconnected_cb(): sub = await nc.subscribe("foo") async def bar_cb(msg): - await msg.respond(b'OK!') + await msg.respond(b"OK!") rsub = await nc.subscribe("bar", cb=bar_cb) - await nc.publish("foo", b'First') + await nc.publish("foo", b"First") await nc.flush() msg = await sub.next_msg() - self.assertEqual(msg.data, b'First') + self.assertEqual(msg.data, b"First") - rmsg = await nc.request("bar", b'hi') - self.assertEqual(rmsg.data, b'OK!') + rmsg = await nc.request("bar", b"hi") + self.assertEqual(rmsg.data, b"OK!") # Restart the server and wait for reconnect. - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(1) await asyncio.get_running_loop().run_in_executor( None, self.server_pool[0].start @@ -158,12 +148,12 @@ async def bar_cb(msg): await asyncio.wait_for(reconnected, 2) # Get another message. - await nc.publish("foo", b'Second') + await nc.publish("foo", b"Second") await nc.flush() msg = await sub.next_msg() - self.assertEqual(msg.data, b'Second') - rmsg = await nc.request("bar", b'hi') - self.assertEqual(rmsg.data, b'OK!') + self.assertEqual(msg.data, b"Second") + rmsg = await nc.request("bar", b"hi") + self.assertEqual(rmsg.data, b"OK!") await nc.close() @@ -187,20 +177,18 @@ async def reconnected_cb(): sub = await nc.subscribe("foo") async def bar_cb(msg): - await msg.respond(b'OK!') + await msg.respond(b"OK!") rsub = await nc.subscribe("bar", cb=bar_cb) - await nc.publish("foo", b'First') + await nc.publish("foo", b"First") await nc.flush() msg = await sub.next_msg() - self.assertEqual(msg.data, b'First') - rmsg = await nc.request("bar", b'hi') - self.assertEqual(rmsg.data, b'OK!') + self.assertEqual(msg.data, b"First") + rmsg = await nc.request("bar", b"hi") + self.assertEqual(rmsg.data, b"OK!") # Restart the server and wait for reconnect. - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(1) # Should not fail closing while disconnected. @@ -208,7 +196,6 @@ async def bar_cb(msg): class WebSocketTLSTest(SingleWebSocketTLSServerTestCase): - @async_test async def test_pub_sub(self): if not aiohttp_installed: @@ -218,13 +205,13 @@ async def test_pub_sub(self): sub = await nc.subscribe("foo") await nc.flush() - await nc.publish("foo", b'hello world', headers={'foo': 'bar'}) + await nc.publish("foo", b"hello world", headers={"foo": "bar"}) msg = await sub.next_msg() self.assertTrue(msg.headers != None) self.assertEqual(len(msg.headers), 1) - self.assertEqual(msg.headers['foo'], 'bar') + self.assertEqual(msg.headers["foo"], "bar") await nc.close() @@ -240,29 +227,25 @@ async def reconnected_cb(): reconnected.set_result(True) nc = await nats.connect( - "wss://localhost:8081", - reconnected_cb=reconnected_cb, - tls=self.ssl_ctx + "wss://localhost:8081", reconnected_cb=reconnected_cb, tls=self.ssl_ctx ) sub = await nc.subscribe("foo") async def bar_cb(msg): - await msg.respond(b'OK!') + await msg.respond(b"OK!") rsub = await nc.subscribe("bar", cb=bar_cb) - await nc.publish("foo", b'First') + await nc.publish("foo", b"First") await nc.flush() msg = await sub.next_msg() - self.assertEqual(msg.data, b'First') + self.assertEqual(msg.data, b"First") - rmsg = await nc.request("bar", b'hi') - self.assertEqual(rmsg.data, b'OK!') + rmsg = await nc.request("bar", b"hi") + self.assertEqual(rmsg.data, b"OK!") # Restart the server and wait for reconnect. - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(1) await asyncio.get_running_loop().run_in_executor( None, self.server_pool[0].start @@ -270,12 +253,12 @@ async def bar_cb(msg): await asyncio.wait_for(reconnected, 2) # Get another message. - await nc.publish("foo", b'Second') + await nc.publish("foo", b"Second") await nc.flush() msg = await sub.next_msg() - self.assertEqual(msg.data, b'Second') - rmsg = await nc.request("bar", b'hi') - self.assertEqual(rmsg.data, b'OK!') + self.assertEqual(msg.data, b"Second") + rmsg = await nc.request("bar", b"hi") + self.assertEqual(rmsg.data, b"OK!") await nc.close() @@ -300,27 +283,26 @@ async def reconnected_cb(): sub = await nc.subscribe("foo") async def bar_cb(msg): - await msg.respond(b'OK!') + await msg.respond(b"OK!") rsub = await nc.subscribe("bar", cb=bar_cb) - await nc.publish("foo", b'First') + await nc.publish("foo", b"First") await nc.flush() msg = await sub.next_msg() - self.assertEqual(msg.data, b'First') - rmsg = await nc.request("bar", b'hi') - self.assertEqual(rmsg.data, b'OK!') + self.assertEqual(msg.data, b"First") + rmsg = await nc.request("bar", b"hi") + self.assertEqual(rmsg.data, b"OK!") # Restart the server and wait for reconnect. - await asyncio.get_running_loop().run_in_executor( - None, self.server_pool[0].stop - ) + await asyncio.get_running_loop().run_in_executor(None, self.server_pool[0].stop) await asyncio.sleep(1) # Should not fail closing while disconnected. await nc.close() -if __name__ == '__main__': +if __name__ == "__main__": import sys + runner = unittest.TextTestRunner(stream=sys.stdout) unittest.main(verbosity=2, exit=False, testRunner=runner) diff --git a/tests/test_js.py b/tests/test_js.py index f6d632be..f7babfce 100644 --- a/tests/test_js.py +++ b/tests/test_js.py @@ -31,7 +31,6 @@ class PublishTest(SingleJetStreamServerTestCase): - @async_test async def test_publish(self): nc = NATS() @@ -39,20 +38,20 @@ async def test_publish(self): js = nc.jetstream() with pytest.raises(NoStreamResponseError): - await js.publish("foo", b'bar') + await js.publish("foo", b"bar") await js.add_stream(name="QUUX", subjects=["quux"]) - ack = await js.publish("quux", b'bar:1', stream="QUUX") + ack = await js.publish("quux", b"bar:1", stream="QUUX") assert ack.stream == "QUUX" assert ack.seq == 1 - ack = await js.publish("quux", b'bar:2') + ack = await js.publish("quux", b"bar:2") assert ack.stream == "QUUX" assert ack.seq == 2 with pytest.raises(BadRequestError) as err: - await js.publish("quux", b'bar', stream="BAR") + await js.publish("quux", b"bar", stream="BAR") assert err.value.err_code == 10060 await nc.close() @@ -64,20 +63,20 @@ async def test_publish_verbose(self): js = nc.jetstream() with pytest.raises(NoStreamResponseError): - await js.publish("foo", b'bar') + await js.publish("foo", b"bar") await js.add_stream(name="QUUX", subjects=["quux"]) - ack = await js.publish("quux", b'bar:1', stream="QUUX") + ack = await js.publish("quux", b"bar:1", stream="QUUX") assert ack.stream == "QUUX" assert ack.seq == 1 - ack = await js.publish("quux", b'bar:2') + ack = await js.publish("quux", b"bar:2") assert ack.stream == "QUUX" assert ack.seq == 2 with pytest.raises(BadRequestError) as err: - await js.publish("quux", b'bar', stream="BAR") + await js.publish("quux", b"bar", stream="BAR") assert err.value.err_code == 10060 await nc.close() @@ -92,7 +91,7 @@ async def test_publish_async(self): await js.publish_async_completed() self.assertEqual(js.publish_async_pending(), 0) - future = await js.publish_async("foo", b'bar') + future = await js.publish_async("foo", b"bar") self.assertEqual(js.publish_async_pending(), 1) with pytest.raises(NoStreamResponseError): await future @@ -102,9 +101,7 @@ async def test_publish_async(self): await js.add_stream(name="QUUX", subjects=["quux"]) - futures = [ - await js.publish_async("quux", b'bar:1') for i in range(0, 100) - ] + futures = [await js.publish_async("quux", b"bar:1") for i in range(0, 100)] await js.publish_async_completed() results = await asyncio.gather(*futures) @@ -116,20 +113,25 @@ async def test_publish_async(self): self.assertEqual(result.seq, seq) with pytest.raises(TooManyStalledMsgsError): - publishes = [js.publish_async("quux", b'bar:1', wait_stall=0.0) for i in range(0, 100)] + publishes = [ + js.publish_async("quux", b"bar:1", wait_stall=0.0) + for i in range(0, 100) + ] futures = await asyncio.gather(*publishes) results = await asyncio.gather(*futures) self.assertEqual(len(results), 100) - publishes = [js.publish_async("quux", b'bar:1', wait_stall=1.0) for i in range(0, 1000)] + publishes = [ + js.publish_async("quux", b"bar:1", wait_stall=1.0) for i in range(0, 1000) + ] futures = await asyncio.gather(*publishes) results = await asyncio.gather(*futures) self.assertEqual(len(results), 1000) await nc.close() -class PullSubscribeTest(SingleJetStreamServerTestCase): +class PullSubscribeTest(SingleJetStreamServerTestCase): @async_test async def test_auto_create_consumer(self): nc = NATS() @@ -139,19 +141,19 @@ async def test_auto_create_consumer(self): await js.add_stream(name="TEST2", subjects=["a1", "a2", "a3", "a4"]) for i in range(1, 10): - await js.publish("a1", f'a1:{i}'.encode()) + await js.publish("a1", f"a1:{i}".encode()) # Should use a2 as the filter subject and receive a subset. sub = await js.pull_subscribe("a2", "auto") - await js.publish("a2", b'one') + await js.publish("a2", b"one") for i in range(10, 20): - await js.publish("a3", f'a3:{i}'.encode()) + await js.publish("a3", f"a3:{i}".encode()) msgs = await sub.fetch(1) msg = msgs[0] await msg.ack() - assert msg.data == b'one' + assert msg.data == b"one" # Getting another message should timeout for the a2 subject. with pytest.raises(TimeoutError): @@ -164,7 +166,7 @@ async def test_auto_create_consumer(self): msgs = await sub.fetch(1) msg = msgs[0] await msg.ack() - assert msg.data == b'one' + assert msg.data == b"one" info = await js.consumer_info("TEST2", "auto2") assert info.config.max_waiting == 10 @@ -173,14 +175,14 @@ async def test_auto_create_consumer(self): msgs = await sub.fetch(1) msg = msgs[0] await msg.ack() - assert msg.data == b'a3:10' + assert msg.data == b"a3:10" # Getting all messages from stream. sub = await js.pull_subscribe("", "all") msgs = await sub.fetch(1) msg = msgs[0] await msg.ack() - assert msg.data == b'a1:1' + assert msg.data == b"a1:1" for _ in range(2, 10): msgs = await sub.fetch(1) @@ -191,13 +193,13 @@ async def test_auto_create_consumer(self): msgs = await sub.fetch(1) msg = msgs[0] await msg.ack() - assert msg.data == b'one' + assert msg.data == b"one" # subject a3 msgs = await sub.fetch(1) msg = msgs[0] await msg.ack() - assert msg.data == b'a3:10' + assert msg.data == b"a3:10" await nc.close() await asyncio.sleep(1) @@ -211,7 +213,7 @@ async def test_fetch_one(self): await js.add_stream(name="TEST1", subjects=["foo.1", "bar"]) - ack = await js.publish("foo.1", f'Hello from NATS!'.encode()) + ack = await js.publish("foo.1", f"Hello from NATS!".encode()) assert ack.stream == "TEST1" assert ack.seq == 1 @@ -232,16 +234,14 @@ async def test_fetch_one(self): await sub.fetch(timeout=1) for i in range(0, 10): - await js.publish( - "foo.1", f"i:{i}".encode(), headers={'hello': 'world'} - ) + await js.publish("foo.1", f"i:{i}".encode(), headers={"hello": "world"}) # nak msgs = await sub.fetch() msg = msgs[0] info = await js.consumer_info("TEST1", "dur", timeout=1) - assert msg.header == {'hello': 'world'} + assert msg.header == {"hello": "world"} await msg.nak() @@ -276,7 +276,7 @@ async def test_fetch_one_wait_forever(self): await js.add_stream(name="TEST111", subjects=["foo.111"]) - ack = await js.publish("foo.111", f'Hello from NATS!'.encode()) + ack = await js.publish("foo.111", f"Hello from NATS!".encode()) assert ack.stream == "TEST111" assert ack.seq == 1 @@ -305,7 +305,7 @@ async def f(): await asyncio.sleep(1) assert received is False await asyncio.sleep(1) - await js.publish("foo.111", b'Hello from NATS!') + await js.publish("foo.111", b"Hello from NATS!") await task assert received @@ -324,9 +324,9 @@ async def test_add_pull_consumer_via_jsm(self): max_deliver=20, max_waiting=512, max_ack_pending=1024, - filter_subject="events.a" + filter_subject="events.a", ) - await js.publish("events.a", b'hello world') + await js.publish("events.a", b"hello world") sub = await js.pull_subscribe_bind("a", stream="events") msgs = await sub.fetch(1) for msg in msgs: @@ -341,13 +341,13 @@ async def test_fetch_n(self): server_version = nc.connected_server_version if server_version.major == 2 and server_version.minor < 9: - pytest.skip('needs to run at least on v2.9.0') + pytest.skip("needs to run at least on v2.9.0") js = nc.jetstream() await js.add_stream(name="TESTN", subjects=["a", "b", "c"]) for i in range(0, 10): - await js.publish("a", f'i:{i}'.encode()) + await js.publish("a", f"i:{i}".encode()) sub = await js.pull_subscribe( "a", @@ -366,7 +366,7 @@ async def test_fetch_n(self): i = 0 for msg in msgs: - assert msg.data == f'i:{i}'.encode() + assert msg.data == f"i:{i}".encode() await msg.ack() i += 1 info = await sub.consumer_info() @@ -381,7 +381,7 @@ async def test_fetch_n(self): i = 5 for msg in msgs: - assert msg.data == f'i:{i}'.encode() + assert msg.data == f"i:{i}".encode() await msg.ack_sync() i += 1 @@ -399,14 +399,14 @@ async def test_fetch_n(self): # ---------- # 0 pending # 1 ack pending - await js.publish("a", b'i:11') + await js.publish("a", b"i:11") msgs = await sub.fetch(2, timeout=0.5) # Leave this message unacked. msg = msgs[0] unacked_msg = msg - assert msg.data == b'i:11' + assert msg.data == b"i:11" info = await sub.consumer_info() assert info.num_waiting < 2 assert info.num_pending == 0 @@ -419,12 +419,12 @@ async def test_fetch_n(self): # 1 extra from before but request has expired so does not count. # +1 ack pending since previous message not acked. # +1 pending to be consumed. - await js.publish("a", b'i:12') + await js.publish("a", b"i:12") # Inspect the internal buffer which should be a 408 at this point. try: msg = await sub._sub.next_msg(timeout=0.5) - assert msg.headers['Status'] == '408' + assert msg.headers["Status"] == "408" except (nats.errors.TimeoutError, TypeError): pass @@ -436,8 +436,8 @@ async def test_fetch_n(self): # Start background task that gathers messages. fut = asyncio.create_task(sub.fetch(3, timeout=2)) await asyncio.sleep(0.5) - await js.publish("a", b'i:13') - await js.publish("a", b'i:14') + await js.publish("a", b"i:13") + await js.publish("a", b"i:14") # It should receive the first one that is available already only. msgs = await fut @@ -467,10 +467,10 @@ async def test_fetch_n(self): # No messages at this point. msgs = await sub.fetch(1, timeout=0.5) - self.assertEqual(msgs[0].data, b'i:13') + self.assertEqual(msgs[0].data, b"i:13") msgs = await sub.fetch(1, timeout=0.5) - self.assertEqual(msgs[0].data, b'i:14') + self.assertEqual(msgs[0].data, b"i:14") with pytest.raises(TimeoutError): await sub.fetch(1, timeout=0.5) @@ -517,11 +517,11 @@ async def test_fetch_max_waiting_fetch_one(self): # assert info.num_waiting == 0 for i in range(0, 10): - await js.publish("max", b'foo') + await js.publish("max", b"foo") async def pub(): while True: - await js.publish("max", b'foo') + await js.publish("max", b"foo") await asyncio.sleep(0) producer = asyncio.create_task(pub()) @@ -597,7 +597,7 @@ async def test_fetch_concurrent(self): async def go_publish(): i = 0 while True: - payload = f'{i}'.encode() + payload = f"{i}".encode() await js.publish("a", payload) i += 1 await asyncio.sleep(0.01) @@ -653,7 +653,7 @@ async def test_fetch_headers(self): js = nc.jetstream() await js.add_stream(name="test-nats", subjects=["test.nats.1"]) - await js.publish("test.nats.1", b'first_msg', headers={'': ''}) + await js.publish("test.nats.1", b"first_msg", headers={"": ""}) sub = await js.pull_subscribe("test.nats.1", "durable") msgs = await sub.fetch(1) assert msgs[0].header == None @@ -664,11 +664,11 @@ async def test_fetch_headers(self): # NOTE: Headers with empty spaces are ignored. await js.publish( "test.nats.1", - b'second_msg', + b"second_msg", headers={ - ' AAA AAA AAA ': ' ', - ' B B B ': ' ' - } + " AAA AAA AAA ": " ", + " B B B ": " ", + }, ) msgs = await sub.fetch(1) assert msgs[0].header == None @@ -679,20 +679,20 @@ async def test_fetch_headers(self): # NOTE: As soon as there is a message with empty spaces are ignored. await js.publish( "test.nats.1", - b'third_msg', + b"third_msg", headers={ - ' AAA-AAA-AAA ': ' a ', - ' AAA-BBB-AAA ': ' ', - ' B B B ': ' a ' - } + " AAA-AAA-AAA ": " a ", + " AAA-BBB-AAA ": " ", + " B B B ": " a ", + }, ) msgs = await sub.fetch(1) - assert msgs[0].header['AAA-AAA-AAA'] == 'a' - assert msgs[0].header['AAA-BBB-AAA'] == '' + assert msgs[0].header["AAA-AAA-AAA"] == "a" + assert msgs[0].header["AAA-BBB-AAA"] == "" msg = await js.get_msg("test-nats", 3) - assert msg.header['AAA-AAA-AAA'] == 'a' - assert msg.header['AAA-BBB-AAA'] == '' + assert msg.header["AAA-AAA-AAA"] == "a" + assert msg.header["AAA-BBB-AAA"] == "" if not parse_email: await nc.close() @@ -700,15 +700,15 @@ async def test_fetch_headers(self): await js.publish( "test.nats.1", - b'third_msg', + b"third_msg", headers={ - ' AAA AAA AAA ': ' a ', - ' AAA-BBB-AAA ': ' b ', - ' B B B ': ' a ' - } + " AAA AAA AAA ": " a ", + " AAA-BBB-AAA ": " b ", + " B B B ": " a ", + }, ) msgs = await sub.fetch(1) - assert msgs[0].header == {'AAA-BBB-AAA': 'b'} + assert msgs[0].header == {"AAA-BBB-AAA": "b"} msg = await js.get_msg("test-nats", 4) assert msg.header == None @@ -730,7 +730,7 @@ async def error_cb(err): await js.add_stream(name="TEST2", subjects=["a1", "a2", "a3", "a4"]) for i in range(1, 10): - await js.publish("a1", f'a1:{i}'.encode()) + await js.publish("a1", f"a1:{i}".encode()) # Shorter msgs limit, and disable bytes limit to not get slow consumers. sub = await js.pull_subscribe( @@ -740,7 +740,7 @@ async def error_cb(err): pending_bytes_limit=-1, ) for i in range(0, 100): - await js.publish("a3", b'test') + await js.publish("a3", b"test") # Internal buffer will drop some of the messages due to reaching limit. msgs = await sub.fetch(100, timeout=1) @@ -768,7 +768,7 @@ async def error_cb(err): assert sub.pending_bytes == 0 # Consumer has a single message pending but none in buffer. - await js.publish("a3", b'last message') + await js.publish("a3", b"last message") info = await sub.consumer_info() assert info.num_pending == 1 assert sub.pending_msgs == 0 @@ -792,6 +792,7 @@ async def test_fetch_cancelled_errors_raised(self): pytest.skip("skip since cannot use AsyncMock") import tracemalloc + tracemalloc.start() nc = NATS() @@ -807,7 +808,7 @@ async def test_fetch_cancelled_errors_raised(self): max_deliver=20, max_waiting=512, max_ack_pending=1024, - filter_subject="test.a" + filter_subject="test.a", ) sub = await js.pull_subscribe("test.a", "test", stream="test") @@ -816,9 +817,9 @@ async def test_fetch_cancelled_errors_raised(self): # is raised here due to the mock usage. with self.assertRaises(asyncio.CancelledError): with unittest.mock.patch( - "asyncio.wait_for", - unittest.mock.AsyncMock(side_effect=asyncio.CancelledError - )): + "asyncio.wait_for", + unittest.mock.AsyncMock(side_effect=asyncio.CancelledError), + ): await sub.fetch(batch=1, timeout=0.1) await nc.close() @@ -833,7 +834,7 @@ async def test_ephemeral_pull_subscribe(self): await js.add_stream(name=subject, subjects=["a1", "a2", "a3", "a4"]) for i in range(1, 10): - await js.publish("a1", f'a1:{i}'.encode()) + await js.publish("a1", f"a1:{i}".encode()) with pytest.raises(NotFoundError): await js.pull_subscribe("a0") @@ -856,10 +857,10 @@ async def test_consumer_with_multiple_filters(self): # Create stream. await jsm.add_stream(name="ctests", subjects=["a", "b", "c.>"]) - await js.publish("a", b'A') - await js.publish("b", b'B') - await js.publish("c.d", b'CD') - await js.publish("c.d.e", b'CDE') + await js.publish("a", b"A") + await js.publish("b", b"B") + await js.publish("c.d", b"CD") + await js.publish("c.d.e", b"CDE") # Create ephemeral pull consumer with a name. stream_name = "ctests" @@ -875,17 +876,17 @@ async def test_consumer_with_multiple_filters(self): sub = await js.pull_subscribe_bind(consumer_name, stream_name) msgs = await sub.fetch(1) - assert msgs[0].data == b'A' + assert msgs[0].data == b"A" ok = await msgs[0].ack_sync() assert ok msgs = await sub.fetch(1) - assert msgs[0].data == b'B' + assert msgs[0].data == b"B" ok = await msgs[0].ack_sync() assert ok msgs = await sub.fetch(1) - assert msgs[0].data == b'CDE' + assert msgs[0].data == b"CDE" ok = await msgs[0].ack_sync() assert ok @@ -908,7 +909,7 @@ async def test_add_consumer_with_backoff(self): filter_subject="events.>", ) for i in range(0, 3): - await js.publish("events.%d" % i, b'i:%d' % i) + await js.publish("events.%d" % i, b"i:%d" % i) sub = await js.pull_subscribe_bind("a", stream="events") events_prefix = "$JS.EVENT.ADVISORY.CONSUMER.MAX_DELIVERIES." @@ -927,7 +928,7 @@ async def cb(msg): try: msgs = await sub.fetch(1, timeout=5) for msg in msgs: - if msg.subject == 'events.0': + if msg.subject == "events.0": received.append((time.monotonic(), msg)) except TimeoutError as err: # There should be no timeout as redeliveries should happen faster. @@ -977,7 +978,7 @@ async def test_fetch_heartbeats(self): await sub.fetch(1, timeout=1, heartbeat=0.1) for i in range(0, 15): - await js.publish("events.%d" % i, b'i:%d' % i) + await js.publish("events.%d" % i, b"i:%d" % i) # Fetch(n) msgs = await sub.fetch(5, timeout=5, heartbeat=0.1) @@ -1017,7 +1018,7 @@ async def test_fetch_heartbeats(self): with pytest.raises(nats.js.errors.APIError) as err: await sub.fetch(1, timeout=1, heartbeat=0.5) - assert err.value.description == 'Bad Request - heartbeat value too large' + assert err.value.description == "Bad Request - heartbeat value too large" # Example of catching fetch timeout instead first. got_fetch_timeout = False @@ -1046,7 +1047,6 @@ async def test_fetch_heartbeats(self): class JSMTest(SingleJetStreamServerTestCase): - @async_test async def test_stream_management(self): nc = NATS() @@ -1072,9 +1072,7 @@ async def test_stream_management(self): with pytest.raises(ValueError): await jsm.add_stream(nats.js.api.StreamConfig()) # Create with config, name is provided as kwargs - stream_with_name = await jsm.add_stream( - nats.js.api.StreamConfig(), name="hi" - ) + stream_with_name = await jsm.add_stream(nats.js.api.StreamConfig(), name="hi") assert stream_with_name.config.name == "hi" # Get info @@ -1090,7 +1088,7 @@ async def test_stream_management(self): # Send messages producer = nc.jetstream() - ack = await producer.publish('world', b'Hello world!') + ack = await producer.publish("world", b"Hello world!") assert ack.stream == "hello" assert ack.seq == 1 @@ -1101,9 +1099,7 @@ async def test_stream_management(self): stream_config = current.config stream_config.subjects.append("extra") updated_stream = await jsm.update_stream(stream_config) - assert updated_stream.config.subjects == [ - 'hello', 'world', 'hello.>', 'extra' - ] + assert updated_stream.config.subjects == ["hello", "world", "hello.>", "extra"] # Purge Stream is_purged = await jsm.purge_stream("hello") @@ -1186,24 +1182,24 @@ async def test_jsm_get_delete_msg(self): # Create stream stream = await jsm.add_stream(name="foo", subjects=["foo.>"]) - await js.publish("foo.a.1", b'Hello', headers={'foo': 'bar'}) - await js.publish("foo.b.1", b'World') - await js.publish("foo.c.1", b'!!!') + await js.publish("foo.a.1", b"Hello", headers={"foo": "bar"}) + await js.publish("foo.b.1", b"World") + await js.publish("foo.c.1", b"!!!") # GetMsg msg = await jsm.get_msg("foo", 2) - assert msg.subject == 'foo.b.1' - assert msg.data == b'World' + assert msg.subject == "foo.b.1" + assert msg.data == b"World" msg = await jsm.get_msg("foo", 3) - assert msg.subject == 'foo.c.1' - assert msg.data == b'!!!' + assert msg.subject == "foo.c.1" + assert msg.data == b"!!!" msg = await jsm.get_msg("foo", 1) - assert msg.subject == 'foo.a.1' - assert msg.data == b'Hello' + assert msg.subject == "foo.a.1" + assert msg.data == b"Hello" assert msg.headers["foo"] == "bar" - assert msg.hdrs == 'TkFUUy8xLjANCmZvbzogYmFyDQoNCg==' + assert msg.hdrs == "TkFUUy8xLjANCmZvbzogYmFyDQoNCg==" with pytest.raises(BadRequestError): await jsm.get_msg("foo", 0) @@ -1307,12 +1303,10 @@ async def test_number_of_consumer_replicas(self): js = nc.jetstream() await js.add_stream(name="TESTREPLICAS", subjects=["test.replicas"]) for i in range(0, 10): - await js.publish("test.replicas", f'{i}'.encode()) + await js.publish("test.replicas", f"{i}".encode()) # Create consumer - config = nats.js.api.ConsumerConfig( - num_replicas=1, durable_name="mycons" - ) + config = nats.js.api.ConsumerConfig(num_replicas=1, durable_name="mycons") cons = await js.add_consumer(stream="TESTREPLICAS", config=config) if cons.config.num_replicas: assert cons.config.num_replicas == 1 @@ -1330,10 +1324,10 @@ async def test_consumer_with_name(self): # Create stream. await jsm.add_stream(name="ctests", subjects=["a", "b", "c.>"]) - await js.publish("a", b'hello world!') - await js.publish("b", b'hello world!!') - await js.publish("c.d", b'hello world!!!') - await js.publish("c.d.e", b'hello world!!!!') + await js.publish("a", b"hello world!") + await js.publish("b", b"hello world!!") + await js.publish("c.d", b"hello world!!!") + await js.publish("c.d.e", b"hello world!!!!") # Create ephemeral pull consumer with a name. stream_name = "ctests" @@ -1346,16 +1340,16 @@ async def test_consumer_with_name(self): assert cinfo.config.name == consumer_name msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.CREATE.ctests.ephemeral' + assert msg.subject == "$JS.API.CONSUMER.CREATE.ctests.ephemeral" sub = await js.pull_subscribe_bind(consumer_name, stream_name) msgs = await sub.fetch(1) - assert msgs[0].data == b'hello world!' + assert msgs[0].data == b"hello world!" ok = await msgs[0].ack_sync() assert ok msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.MSG.NEXT.ctests.ephemeral' + assert msg.subject == "$JS.API.CONSUMER.MSG.NEXT.ctests.ephemeral" # Create durable pull consumer with a name. consumer_name = "durable" @@ -1367,15 +1361,15 @@ async def test_consumer_with_name(self): ) assert cinfo.config.name == consumer_name msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.CREATE.ctests.durable' + assert msg.subject == "$JS.API.CONSUMER.CREATE.ctests.durable" sub = await js.pull_subscribe_bind(consumer_name, stream_name) msgs = await sub.fetch(1) - assert msgs[0].data == b'hello world!' + assert msgs[0].data == b"hello world!" ok = await msgs[0].ack_sync() assert ok msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.MSG.NEXT.ctests.durable' + assert msg.subject == "$JS.API.CONSUMER.MSG.NEXT.ctests.durable" # Create durable pull consumer with a name and a filter_subject consumer_name = "durable2" @@ -1388,15 +1382,15 @@ async def test_consumer_with_name(self): ) assert cinfo.config.name == consumer_name msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.CREATE.ctests.durable2.b' + assert msg.subject == "$JS.API.CONSUMER.CREATE.ctests.durable2.b" sub = await js.pull_subscribe_bind(consumer_name, stream_name) msgs = await sub.fetch(1) - assert msgs[0].data == b'hello world!!' + assert msgs[0].data == b"hello world!!" ok = await msgs[0].ack_sync() assert ok msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.MSG.NEXT.ctests.durable2' + assert msg.subject == "$JS.API.CONSUMER.MSG.NEXT.ctests.durable2" # Create durable pull consumer with a name and a filter_subject consumer_name = "durable3" @@ -1409,15 +1403,15 @@ async def test_consumer_with_name(self): ) assert cinfo.config.name == consumer_name msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.CREATE.ctests.durable3' + assert msg.subject == "$JS.API.CONSUMER.CREATE.ctests.durable3" sub = await js.pull_subscribe_bind(consumer_name, stream_name) msgs = await sub.fetch(1) - assert msgs[0].data == b'hello world!' + assert msgs[0].data == b"hello world!" ok = await msgs[0].ack_sync() assert ok msg = await tsub.next_msg() - assert msg.subject == '$JS.API.CONSUMER.MSG.NEXT.ctests.durable3' + assert msg.subject == "$JS.API.CONSUMER.MSG.NEXT.ctests.durable3" # name and durable must match if both present. with pytest.raises(BadRequestError) as err: @@ -1428,7 +1422,10 @@ async def test_consumer_with_name(self): ack_policy="explicit", ) assert err.value.err_code == 10017 - assert err.value.description == 'consumer name in subject does not match durable name in request' + assert ( + err.value.description + == "consumer name in subject does not match durable name in request" + ) # Create ephemeral pull consumer with a name and inactive threshold. stream_name = "ctests" @@ -1443,7 +1440,7 @@ async def test_consumer_with_name(self): sub = await js.pull_subscribe_bind(consumer_name, stream_name) msgs = await sub.fetch(1) - assert msgs[0].data == b'hello world!' + assert msgs[0].data == b"hello world!" ok = await msgs[0].ack_sync() assert ok @@ -1463,16 +1460,16 @@ async def test_jsm_stream_info_options(self): stream = await jsm.add_stream(name="foo", subjects=["foo.>"]) for i in range(0, 5): - await js.publish("foo.%d" % i, b'A') + await js.publish("foo.%d" % i, b"A") si = await jsm.stream_info("foo", subjects_filter=">") assert si.state.messages == 5 assert si.state.subjects == { - 'foo.0': 1, - 'foo.1': 1, - 'foo.2': 1, - 'foo.3': 1, - 'foo.4': 1 + "foo.0": 1, + "foo.1": 1, + "foo.2": 1, + "foo.3": 1, + "foo.4": 1, } # When nothing matches streams subject will be empty. @@ -1487,7 +1484,6 @@ async def test_jsm_stream_info_options(self): class SubscribeTest(SingleJetStreamServerTestCase): - @async_test async def test_queue_subscribe_deliver_group(self): nc = await nats.connect() @@ -1524,7 +1520,7 @@ async def cb3(msg): subs.append(sub3) for i in range(100): - await js.publish("quux", f'Hello World {i}'.encode()) + await js.publish("quux", f"Hello World {i}".encode()) delivered = [a, b, c] for _ in range(5): @@ -1545,7 +1541,7 @@ async def cb3(msg): # Confirm that no more messages are received. for _ in range(200, 210): - await js.publish("quux", f'Hello World {i}'.encode()) + await js.publish("quux", f"Hello World {i}".encode()) with pytest.raises(BadSubscriptionError): await sub1.unsubscribe() @@ -1571,7 +1567,7 @@ async def cb2(msg): b.append(msg) for i in range(10): - await js.publish("pbound", f'Hello World {i}'.encode()) + await js.publish("pbound", f"Hello World {i}".encode()) # First subscriber will create. await js.subscribe("pbound", cb=cb1, durable="singleton") @@ -1586,9 +1582,7 @@ async def cb2(msg): assert err.value.description == "consumer is already bound to a subscription" with pytest.raises(nats.js.errors.Error) as err: - await js.subscribe( - "pbound", queue="foo", cb=cb2, durable="singleton" - ) + await js.subscribe("pbound", queue="foo", cb=cb2, durable="singleton") exp = "cannot create queue subscription 'foo' to consumer 'singleton'" assert err.value.description == exp @@ -1602,7 +1596,7 @@ async def cb2(msg): # Create a sync subscriber now. sub2 = await js.subscribe("pbound", durable="two") msg = await sub2.next_msg() - assert msg.data == b'Hello World 0' + assert msg.data == b"Hello World 0" assert msg.metadata.sequence.stream == 1 assert msg.metadata.sequence.consumer == 1 assert sub2.pending_msgs == 9 @@ -1618,7 +1612,7 @@ async def test_ephemeral_subscribe(self): await js.add_stream(name=subject, subjects=[subject]) for i in range(10): - await js.publish(subject, f'Hello World {i}'.encode()) + await js.publish(subject, f"Hello World {i}".encode()) # First subscriber will create. sub1 = await js.subscribe(subject) @@ -1677,7 +1671,7 @@ async def test_subscribe_bind(self): config=consumer_info.config, ) for i in range(10): - await js.publish(subject_name, f'Hello World {i}'.encode()) + await js.publish(subject_name, f"Hello World {i}".encode()) msgs = [] for i in range(0, 10): @@ -1739,7 +1733,7 @@ async def cb3(msg): subs.append(sub3) for i in range(100): - await js.publish("quux", f'Hello World {i}'.encode()) + await js.publish("quux", f"Hello World {i}".encode()) # Only 3rd group should have ran into slow consumer errors. await asyncio.sleep(1) @@ -1761,11 +1755,18 @@ async def test_unsubscribe_coroutines(self): coroutines_before = len(asyncio.all_tasks()) a = [] + async def cb(msg): a.append(msg) # Create a PushSubscription - sub = await js.subscribe("quux", "wg", cb=cb, ordered_consumer=True, config=nats.js.api.ConsumerConfig(idle_heartbeat=5)) + sub = await js.subscribe( + "quux", + "wg", + cb=cb, + ordered_consumer=True, + config=nats.js.api.ConsumerConfig(idle_heartbeat=5), + ) # breakpoint() @@ -1788,8 +1789,8 @@ async def cb(msg): self.assertEqual(coroutines_before, coroutines_after_unsubscribe) self.assertNotEqual(coroutines_before, coroutines_after_subscribe) -class AckPolicyTest(SingleJetStreamServerTestCase): +class AckPolicyTest(SingleJetStreamServerTestCase): @async_test async def test_ack_v2_tokens(self): nc = await nats.connect() @@ -1830,17 +1831,9 @@ async def test_ack_v2_tokens(self): assert meta.sequence.consumer == consumer_sequence assert meta.num_delivered == num_delivered assert meta.num_pending == num_pending - assert meta.timestamp.astimezone(datetime.timezone.utc - ) == datetime.datetime( - 2022, - 9, - 11, - 0, - 28, - 27, - 340506, - tzinfo=datetime.timezone.utc - ) + assert meta.timestamp.astimezone(datetime.timezone.utc) == datetime.datetime( + 2022, 9, 11, 0, 28, 27, 340506, tzinfo=datetime.timezone.utc + ) @async_test async def test_double_acking_pull_subscribe(self): @@ -1849,7 +1842,7 @@ async def test_double_acking_pull_subscribe(self): js = nc.jetstream() await js.add_stream(name="TESTACKS", subjects=["test"]) for i in range(0, 10): - await js.publish("test", f'{i}'.encode()) + await js.publish("test", f"{i}".encode()) # Pull Subscriber psub = await js.pull_subscribe("test", "durable") @@ -1904,7 +1897,7 @@ async def error_handler(e): js = nc.jetstream() await js.add_stream(name="TESTACKS", subjects=["test"]) for i in range(0, 10): - await js.publish("test", f'{i}'.encode()) + await js.publish("test", f"{i}".encode()) async def ocb(msg): # Ack the first one only. @@ -1941,7 +1934,7 @@ async def test_nak_delay(self): js = nc.jetstream() await js.add_stream(name=stream, subjects=[subject]) - await js.publish(subject, b'message') + await js.publish(subject, b"message") psub = await js.pull_subscribe(subject, "durable") msgs = await psub.fetch(1) @@ -1969,109 +1962,110 @@ async def f(): assert task.done() assert received + class DiscardPolicyTest(SingleJetStreamServerTestCase): - @async_test - async def test_with_discard_new_and_discard_new_per_subject_set(self): - # Connect to NATS and create JetStream context - nc = await nats.connect() - js = nc.jetstream() - - stream_name = "FOO0" - config = nats.js.api.StreamConfig( - name=stream_name, - discard=nats.js.api.DiscardPolicy.NEW, - discard_new_per_subject=True, - max_msgs_per_subject=100 - ) - - await js.add_stream(config) - info = await js.stream_info(stream_name) - self.assertEqual(info.config.discard, nats.js.api.DiscardPolicy.NEW) - self.assertEqual(info.config.discard_new_per_subject, True) - - # Close the NATS connection after the test - await nc.close() - - @async_test - async def test_with_discard_new_and_discard_new_per_subject_not_set(self): - # Connect to NATS and create JetStream context - nc = await nats.connect() - js = nc.jetstream() - - stream_name = "FOO1" - config = nats.js.api.StreamConfig( - name=stream_name, - discard=nats.js.api.DiscardPolicy.NEW, - discard_new_per_subject=False, - max_msgs_per_subject=100 - ) - - await js.add_stream(config) - info = await js.stream_info(stream_name) - self.assertEqual(info.config.discard, nats.js.api.DiscardPolicy.NEW) - self.assertEqual(info.config.discard_new_per_subject, False) - - await nc.close() - - @async_test - async def test_with_discard_old_and_discard_new_per_subject_set(self): - # Connect to NATS and create JetStream context - nc = await nats.connect() - js = nc.jetstream() - - stream_name = "FOO2" - config = nats.js.api.StreamConfig( - name=stream_name, - discard='DiscardOld', - discard_new_per_subject=True, - max_msgs_per_subject=100 - ) - - with self.assertRaises(APIError): - await js.add_stream(config) - - # Close the NATS connection after the test - await nc.close() - - @async_test - async def test_with_discard_old_and_discard_new_per_subject_not_set(self): - # Connect to NATS and create JetStream context - nc = await nats.connect() - js = nc.jetstream() - - stream_name = "FOO3" - config = nats.js.api.StreamConfig( - name=stream_name, - discard='DiscardOld', - discard_new_per_subject=True, - max_msgs_per_subject=100 - ) - - with self.assertRaises(APIError): - await js.add_stream(config) - - await nc.close() - - @async_test - async def test_with_discard_new_and_discard_new_per_subject_set_no_max_msgs(self): - # Connect to NATS and create JetStream context - nc = await nats.connect() - js = nc.jetstream() - - stream_name = "FOO4" - config = nats.js.api.StreamConfig( - name=stream_name, - discard=nats.js.api.DiscardPolicy.NEW, - discard_new_per_subject=True - ) - - with self.assertRaises(APIError): - await js.add_stream(config) - - await nc.close() + @async_test + async def test_with_discard_new_and_discard_new_per_subject_set(self): + # Connect to NATS and create JetStream context + nc = await nats.connect() + js = nc.jetstream() -class OrderedConsumerTest(SingleJetStreamServerTestCase): + stream_name = "FOO0" + config = nats.js.api.StreamConfig( + name=stream_name, + discard=nats.js.api.DiscardPolicy.NEW, + discard_new_per_subject=True, + max_msgs_per_subject=100, + ) + + await js.add_stream(config) + info = await js.stream_info(stream_name) + self.assertEqual(info.config.discard, nats.js.api.DiscardPolicy.NEW) + self.assertEqual(info.config.discard_new_per_subject, True) + + # Close the NATS connection after the test + await nc.close() + @async_test + async def test_with_discard_new_and_discard_new_per_subject_not_set(self): + # Connect to NATS and create JetStream context + nc = await nats.connect() + js = nc.jetstream() + + stream_name = "FOO1" + config = nats.js.api.StreamConfig( + name=stream_name, + discard=nats.js.api.DiscardPolicy.NEW, + discard_new_per_subject=False, + max_msgs_per_subject=100, + ) + + await js.add_stream(config) + info = await js.stream_info(stream_name) + self.assertEqual(info.config.discard, nats.js.api.DiscardPolicy.NEW) + self.assertEqual(info.config.discard_new_per_subject, False) + + await nc.close() + + @async_test + async def test_with_discard_old_and_discard_new_per_subject_set(self): + # Connect to NATS and create JetStream context + nc = await nats.connect() + js = nc.jetstream() + + stream_name = "FOO2" + config = nats.js.api.StreamConfig( + name=stream_name, + discard="DiscardOld", + discard_new_per_subject=True, + max_msgs_per_subject=100, + ) + + with self.assertRaises(APIError): + await js.add_stream(config) + + # Close the NATS connection after the test + await nc.close() + + @async_test + async def test_with_discard_old_and_discard_new_per_subject_not_set(self): + # Connect to NATS and create JetStream context + nc = await nats.connect() + js = nc.jetstream() + + stream_name = "FOO3" + config = nats.js.api.StreamConfig( + name=stream_name, + discard="DiscardOld", + discard_new_per_subject=True, + max_msgs_per_subject=100, + ) + + with self.assertRaises(APIError): + await js.add_stream(config) + + await nc.close() + + @async_test + async def test_with_discard_new_and_discard_new_per_subject_set_no_max_msgs(self): + # Connect to NATS and create JetStream context + nc = await nats.connect() + js = nc.jetstream() + + stream_name = "FOO4" + config = nats.js.api.StreamConfig( + name=stream_name, + discard=nats.js.api.DiscardPolicy.NEW, + discard_new_per_subject=True, + ) + + with self.assertRaises(APIError): + await js.add_stream(config) + + await nc.close() + + +class OrderedConsumerTest(SingleJetStreamServerTestCase): @async_test async def test_flow_control(self): errors = [] @@ -2092,17 +2086,17 @@ async def cb(msg): with pytest.raises(nats.js.errors.APIError) as err: sub = await js.subscribe(subject, cb=cb, flow_control=True) - assert err.value.description == "consumer with flow control also needs heartbeats" - - sub = await js.subscribe( - subject, cb=cb, flow_control=True, idle_heartbeat=0.5 + assert ( + err.value.description == "consumer with flow control also needs heartbeats" ) + sub = await js.subscribe(subject, cb=cb, flow_control=True, idle_heartbeat=0.5) + tasks = [] async def producer(): mlen = 128 * 1024 - msg = b'A' * mlen + msg = b"A" * mlen # Send it in chunks i = 0 @@ -2112,7 +2106,7 @@ async def producer(): if mlen - i <= chunksize: chunk = msg[i:] else: - chunk = msg[i:i + chunksize] + chunk = msg[i : i + chunksize] i += chunksize task = asyncio.create_task(js.publish(subject, chunk)) tasks.append(task) @@ -2159,7 +2153,7 @@ async def cb(msg): async def producer(): mlen = 1024 * 1024 - msg = b'A' * mlen + msg = b"A" * mlen # Send it in chunks i = 0 @@ -2169,10 +2163,10 @@ async def producer(): if mlen - i <= chunksize: chunk = msg[i:] else: - chunk = msg[i:i + chunksize] + chunk = msg[i : i + chunksize] i += chunksize task = asyncio.create_task( - js.publish(subject, chunk, headers={'data': "true"}) + js.publish(subject, chunk, headers={"data": "true"}) ) await asyncio.sleep(0) tasks.append(task) @@ -2182,7 +2176,7 @@ async def producer(): await asyncio.gather(*tasks) assert len(msgs) == 1024 - received_payload = bytearray(b'') + received_payload = bytearray(b"") for msg in msgs: received_payload.extend(msg.data) assert len(received_payload) == expected_size @@ -2236,9 +2230,7 @@ def _build_message(sid, subject, reply, data, headers): nc._build_message = _build_message subject = "osub2" - await js2.add_stream( - name=subject, subjects=[subject], storage="memory" - ) + await js2.add_stream(name=subject, subjects=[subject], storage="memory") # Consumer callback. future = asyncio.Future() @@ -2263,7 +2255,7 @@ async def cb(msg): async def producer(): mlen = 1024 * 1024 - msg = b'A' * mlen + msg = b"A" * mlen # Send it in chunks i = 0 @@ -2273,10 +2265,10 @@ async def producer(): if mlen - i <= chunksize: chunk = msg[i:] else: - chunk = msg[i:i + chunksize] + chunk = msg[i : i + chunksize] i += chunksize task = asyncio.create_task( - nc2.publish(subject, chunk, headers={'data': "true"}) + nc2.publish(subject, chunk, headers={"data": "true"}) ) tasks.append(task) @@ -2292,7 +2284,7 @@ async def producer(): await asyncio.wait_for(future, timeout=5) assert len(msgs) == 1024 - received_payload = bytearray(b'') + received_payload = bytearray(b"") for msg in msgs: received_payload.extend(msg.data) assert len(received_payload) == expected_size @@ -2330,14 +2322,14 @@ async def error_handler(e): js = nc.jetstream() js2 = nc2.jetstream() - subject = 'ORDERS' + subject = "ORDERS" await js2.add_stream(name=subject, subjects=[subject], storage="file") tasks = [] async def producer(): mlen = 10 * 1024 * 1024 - msg = b'A' * mlen + msg = b"A" * mlen # Send it in chunks i = 0 @@ -2347,10 +2339,10 @@ async def producer(): if mlen - i <= chunksize: chunk = msg[i:] else: - chunk = msg[i:i + chunksize] + chunk = msg[i : i + chunksize] i += chunksize task = asyncio.create_task( - nc2.publish(subject, chunk, headers={'data': "true"}) + nc2.publish(subject, chunk, headers={"data": "true"}) ) tasks.append(task) @@ -2367,7 +2359,7 @@ async def producer(): async def cb(msg): nonlocal i nonlocal done - data = msg.data.decode('utf-8') + data = msg.data.decode("utf-8") i += 1 if i == stream.state.messages: if not done.done(): @@ -2379,14 +2371,12 @@ async def cb(msg): await asyncio.wait_for(done, 10) # Using only next_msg which would be slower. - sub = await js.subscribe( - subject, ordered_consumer=True, idle_heartbeat=0.5 - ) + sub = await js.subscribe(subject, ordered_consumer=True, idle_heartbeat=0.5) i = 0 while i < stream.state.messages: try: msg = await sub.next_msg() - data = msg.data.decode('utf-8') + data = msg.data.decode("utf-8") if i % 1000 == 0: await asyncio.sleep(0) i += 1 @@ -2396,9 +2386,7 @@ async def cb(msg): ###################### # Reconnecting # ###################### - sub = await js.subscribe( - subject, ordered_consumer=True, idle_heartbeat=0.5 - ) + sub = await js.subscribe(subject, ordered_consumer=True, idle_heartbeat=0.5) i = 0 while i < stream.state.messages: if i == 5000: @@ -2411,7 +2399,7 @@ async def cb(msg): ) try: msg = await sub.next_msg() - data = msg.data.decode('utf-8') + data = msg.data.decode("utf-8") if i % 1000 == 0: await asyncio.sleep(0) i += 1 @@ -2434,7 +2422,7 @@ async def cb(msg): None, self.server_pool[0].start ) - data = msg.data.decode('utf-8') + data = msg.data.decode("utf-8") i += 1 if i == stream.state.messages: if not done.done(): @@ -2457,22 +2445,16 @@ async def error_handler(e): nc = await nats.connect(error_cb=error_handler) js = nc.jetstream() - await js.add_stream( - name="MY_STREAM", subjects=["test.*"], storage="memory" - ) + await js.add_stream(name="MY_STREAM", subjects=["test.*"], storage="memory") subject = "test.1" - for m in ['1', '2', '3']: + for m in ["1", "2", "3"]: await js.publish(subject=subject, payload=m.encode()) - sub = await js.subscribe( - subject, ordered_consumer=True, idle_heartbeat=0.5 - ) + sub = await js.subscribe(subject, ordered_consumer=True, idle_heartbeat=0.5) info = await sub.consumer_info() orig_name = info.name await js.delete_consumer("MY_STREAM", info.name) - await asyncio.sleep( - 3 - ) # now the consumer should reset due to missing HB + await asyncio.sleep(3) # now the consumer should reset due to missing HB info = await sub.consumer_info() self.assertTrue(orig_name != info.name) @@ -2480,7 +2462,6 @@ async def error_handler(e): class KVTest(SingleJetStreamServerTestCase): - @async_test async def test_kv_simple(self): errors = [] @@ -2499,18 +2480,18 @@ async def error_handler(e): assert status.history == 5 assert status.ttl == 3600 - seq = await kv.put("hello", b'world') + seq = await kv.put("hello", b"world") assert seq == 1 entry = await kv.get("hello") assert "hello" == entry.key - assert b'world' == entry.value + assert b"world" == entry.value status = await kv.status() assert status.values == 1 for i in range(0, 100): - await kv.put(f"hello.{i}", b'Hello JS KV!') + await kv.put(f"hello.{i}", b"Hello JS KV!") status = await kv.status() assert status.values == 101 @@ -2524,10 +2505,10 @@ async def error_handler(e): with pytest.raises(KeyNotFoundError) as err: await kv.get("hello.1") - assert err.value.entry.key == 'hello.1' + assert err.value.entry.key == "hello.1" assert err.value.entry.revision == 102 assert err.value.entry.value == None - assert err.value.op == 'DEL' + assert err.value.op == "DEL" await kv.purge("hello.5") with pytest.raises(KeyNotFoundError) as err: @@ -2549,21 +2530,21 @@ async def error_handler(e): entry = await kv.get("hello") assert "hello" == entry.key - assert b'world' == entry.value + assert b"world" == entry.value assert 1 == entry.revision # Now get the same KV via lookup. kv = await js.key_value("TEST") entry = await kv.get("hello") assert "hello" == entry.key - assert b'world' == entry.value + assert b"world" == entry.value assert 1 == entry.revision status = await kv.status() assert status.values == 1 for i in range(100, 200): - await kv.put(f"hello.{i}", b'Hello JS KV!') + await kv.put(f"hello.{i}", b"Hello JS KV!") status = await kv.status() assert status.values == 101 @@ -2573,7 +2554,7 @@ async def error_handler(e): entry = await kv.get("hello.102") assert "hello.102" == entry.key - assert b'Hello JS KV!' == entry.value + assert b"Hello JS KV!" == entry.value assert 107 == entry.revision await nc.close() @@ -2610,24 +2591,20 @@ async def error_handler(e): bucket = "TEST" kv = await js.create_key_value( - bucket=bucket, - history=5, - ttl=3600, - description="Basic KV", - direct=False + bucket=bucket, history=5, ttl=3600, description="Basic KV", direct=False ) status = await kv.status() si = await js.stream_info("KV_TEST") config = si.config assert config.description == "Basic KV" - assert config.subjects == ['$KV.TEST.>'] + assert config.subjects == ["$KV.TEST.>"] # Check server version for some of these. assert config.allow_rollup_hdrs == True assert config.deny_delete == True assert config.deny_purge == False - assert config.discard == 'new' + assert config.discard == "new" assert config.duplicate_window == 120.0 assert config.max_age == 3600.0 assert config.max_bytes == -1 @@ -2639,10 +2616,10 @@ async def error_handler(e): assert config.no_ack == False assert config.num_replicas == 1 assert config.placement == None - assert config.retention == 'limits' + assert config.retention == "limits" assert config.sealed == False assert config.sources == None - assert config.storage == 'file' + assert config.storage == "file" assert config.template_owner == None version = nc.connected_server_version @@ -2656,13 +2633,13 @@ async def error_handler(e): await kv.get(f"name") # Simple Put - revision = await kv.put(f"name", b'alice') + revision = await kv.put(f"name", b"alice") assert revision == 1 # Simple Get result = await kv.get(f"name") assert result.revision == 1 - assert result.value == b'alice' + assert result.value == b"alice" # Delete ok = await kv.delete(f"name") @@ -2674,7 +2651,7 @@ async def error_handler(e): await kv.get(f"name") # Recreate with different name. - revision = await kv.create("name", b'bob') + revision = await kv.create("name", b"bob") assert revision == 3 # Expect last revision to be 4 @@ -2698,36 +2675,37 @@ async def error_handler(e): assert revision == 6 # Create a different key. - revision = await kv.create("age", b'2038') + revision = await kv.create("age", b"2038") assert revision == 7 # Get current. entry = await kv.get("age") - assert entry.value == b'2038' + assert entry.value == b"2038" assert entry.revision == 7 # Update the new key. - revision = await kv.update("age", b'2039', last=revision) + revision = await kv.update("age", b"2039", last=revision) assert revision == 8 # Get latest. entry = await kv.get("age") - assert entry.value == b'2039' + assert entry.value == b"2039" assert entry.revision == 8 # Internally uses get msg API instead of get last msg. entry = await kv.get("age", revision=7) - assert entry.value == b'2038' + assert entry.value == b"2038" assert entry.revision == 7 # Getting past keys with the wrong expected subject is an error. with pytest.raises(KeyNotFoundError) as err: entry = await kv.get("age", revision=6) - assert entry.value == b'fuga' + assert entry.value == b"fuga" assert entry.revision == 6 - assert str( - err.value - ) == "nats: key not found: expected '$KV.TEST.age', but got '$KV.TEST.name'" + assert ( + str(err.value) + == "nats: key not found: expected '$KV.TEST.age', but got '$KV.TEST.name'" + ) with pytest.raises(KeyNotFoundError) as err: await kv.get("age", revision=5) @@ -2736,19 +2714,21 @@ async def error_handler(e): await kv.get("age", revision=4) entry = await kv.get("name", revision=3) - assert entry.value == b'bob' + assert entry.value == b"bob" - with pytest.raises(KeyWrongLastSequenceError, - match="nats: wrong last sequence: 8"): - await kv.create("age", b'1') + with pytest.raises( + KeyWrongLastSequenceError, match="nats: wrong last sequence: 8" + ): + await kv.create("age", b"1") # Now let's delete and recreate. await kv.delete("age", last=8) - await kv.create("age", b'final') + await kv.create("age", b"final") - with pytest.raises(KeyWrongLastSequenceError, - match="nats: wrong last sequence: 10"): - await kv.create("age", b'1') + with pytest.raises( + KeyWrongLastSequenceError, match="nats: wrong last sequence: 10" + ): + await kv.create("age", b"1") entry = await kv.get("age") assert entry.revision == 10 @@ -2771,42 +2751,38 @@ async def error_handler(e): bucket = "TEST" kv = await js.create_key_value( - bucket=bucket, - history=5, - ttl=3600, - description="Direct KV", - direct=True + bucket=bucket, history=5, ttl=3600, description="Direct KV", direct=True ) si = await js.stream_info("KV_TEST") config = si.config assert config.description == "Direct KV" - assert config.subjects == ['$KV.TEST.>'] - await kv.create("A", b'1') - await kv.create("B", b'2') - await kv.create("C", b'3') - await kv.create("D", b'4') - await kv.create("E", b'5') - await kv.create("F", b'6') - - await kv.put("C", b'33') - await kv.put("D", b'44') - await kv.put("C", b'333') + assert config.subjects == ["$KV.TEST.>"] + await kv.create("A", b"1") + await kv.create("B", b"2") + await kv.create("C", b"3") + await kv.create("D", b"4") + await kv.create("E", b"5") + await kv.create("F", b"6") + + await kv.put("C", b"33") + await kv.put("D", b"44") + await kv.put("C", b"333") # Check with low level msg APIs. msg = await js.get_msg("KV_TEST", seq=1, direct=True) - assert msg.data == b'1' + assert msg.data == b"1" # last by subject msg = await js.get_msg("KV_TEST", subject="$KV.TEST.C", direct=True) - assert msg.data == b'333' + assert msg.data == b"333" # next by subject msg = await js.get_msg( "KV_TEST", seq=4, next=True, subject="$KV.TEST.C", direct=True ) - assert msg.data == b'33' + assert msg.data == b"33" @async_test async def test_kv_direct(self): @@ -2829,7 +2805,7 @@ async def error_handler(e): history=5, ttl=3600, description="Explicit Direct KV", - direct=True + direct=True, ) kv = await js.key_value(bucket=bucket) status = await kv.status() @@ -2837,14 +2813,14 @@ async def error_handler(e): si = await js.stream_info("KV_TEST") config = si.config assert config.description == "Explicit Direct KV" - assert config.subjects == ['$KV.TEST.>'] + assert config.subjects == ["$KV.TEST.>"] # Check server version for some of these. assert config.allow_rollup_hdrs == True assert config.allow_direct == True assert config.deny_delete == True assert config.deny_purge == False - assert config.discard == 'new' + assert config.discard == "new" assert config.duplicate_window == 120.0 assert config.max_age == 3600.0 assert config.max_bytes == -1 @@ -2856,10 +2832,10 @@ async def error_handler(e): assert config.no_ack == False assert config.num_replicas == 1 assert config.placement == None - assert config.retention == 'limits' + assert config.retention == "limits" assert config.sealed == False assert config.sources == None - assert config.storage == 'file' + assert config.storage == "file" assert config.template_owner == None # Nothing from start @@ -2867,13 +2843,13 @@ async def error_handler(e): await kv.get(f"name") # Simple Put - revision = await kv.put(f"name", b'alice') + revision = await kv.put(f"name", b"alice") assert revision == 1 # Simple Get result = await kv.get(f"name") assert result.revision == 1 - assert result.value == b'alice' + assert result.value == b"alice" # Delete ok = await kv.delete(f"name") @@ -2885,7 +2861,7 @@ async def error_handler(e): await kv.get(f"name") # Recreate with different name. - revision = await kv.create("name", b'bob') + revision = await kv.create("name", b"bob") assert revision == 3 # Expect last revision to be 4 @@ -2909,36 +2885,37 @@ async def error_handler(e): assert revision == 6 # Create a different key. - revision = await kv.create("age", b'2038') + revision = await kv.create("age", b"2038") assert revision == 7 # Get current. entry = await kv.get("age") - assert entry.value == b'2038' + assert entry.value == b"2038" assert entry.revision == 7 # Update the new key. - revision = await kv.update("age", b'2039', last=revision) + revision = await kv.update("age", b"2039", last=revision) assert revision == 8 # Get latest. entry = await kv.get("age") - assert entry.value == b'2039' + assert entry.value == b"2039" assert entry.revision == 8 # Internally uses get msg API instead of get last msg. entry = await kv.get("age", revision=7) - assert entry.value == b'2038' + assert entry.value == b"2038" assert entry.revision == 7 # Getting past keys with the wrong expected subject is an error. with pytest.raises(KeyNotFoundError) as err: entry = await kv.get("age", revision=6) - assert entry.value == b'fuga' + assert entry.value == b"fuga" assert entry.revision == 6 - assert str( - err.value - ) == "nats: key not found: expected '$KV.TEST.age', but got '$KV.TEST.name'" + assert ( + str(err.value) + == "nats: key not found: expected '$KV.TEST.age', but got '$KV.TEST.name'" + ) with pytest.raises(KeyNotFoundError) as err: await kv.get("age", revision=5) @@ -2947,19 +2924,21 @@ async def error_handler(e): await kv.get("age", revision=4) entry = await kv.get("name", revision=3) - assert entry.value == b'bob' + assert entry.value == b"bob" - with pytest.raises(KeyWrongLastSequenceError, - match="nats: wrong last sequence: 8"): - await kv.create("age", b'1') + with pytest.raises( + KeyWrongLastSequenceError, match="nats: wrong last sequence: 8" + ): + await kv.create("age", b"1") # Now let's delete and recreate. await kv.delete("age", last=8) - await kv.create("age", b'final') + await kv.create("age", b"final") - with pytest.raises(KeyWrongLastSequenceError, - match="nats: wrong last sequence: 10"): - await kv.create("age", b'1') + with pytest.raises( + KeyWrongLastSequenceError, match="nats: wrong last sequence: 10" + ): + await kv.create("age", b"1") entry = await kv.get("age") assert entry.revision == 10 @@ -2967,7 +2946,9 @@ async def error_handler(e): with pytest.raises(Error) as err: await js.add_stream(name="mirror", mirror_direct=True) assert err.value.err_code == 10052 - assert err.value.description == 'stream has no mirror but does have mirror direct' + assert ( + err.value.description == "stream has no mirror but does have mirror direct" + ) await nc.close() @@ -2993,51 +2974,51 @@ async def error_handler(e): e = await w.updates(timeout=1) assert e is None - await kv.create("name", b'alice:1') + await kv.create("name", b"alice:1") e = await w.updates() assert e.delta == 0 - assert e.key == 'name' - assert e.value == b'alice:1' + assert e.key == "name" + assert e.value == b"alice:1" assert e.revision == 1 - await kv.put("name", b'alice:2') + await kv.put("name", b"alice:2") e = await w.updates() - assert e.key == 'name' - assert e.value == b'alice:2' + assert e.key == "name" + assert e.value == b"alice:2" assert e.revision == 2 - await kv.put("name", b'alice:3') + await kv.put("name", b"alice:3") e = await w.updates() - assert e.key == 'name' - assert e.value == b'alice:3' + assert e.key == "name" + assert e.value == b"alice:3" assert e.revision == 3 - await kv.put("age", b'22') + await kv.put("age", b"22") e = await w.updates() - assert e.key == 'age' - assert e.value == b'22' + assert e.key == "age" + assert e.value == b"22" assert e.revision == 4 - await kv.put("age", b'33') + await kv.put("age", b"33") e = await w.updates() assert e.bucket == "WATCH" - assert e.key == 'age' - assert e.value == b'33' + assert e.key == "age" + assert e.value == b"33" assert e.revision == 5 await kv.delete("age") e = await w.updates() assert e.bucket == "WATCH" - assert e.key == 'age' - assert e.value == b'' + assert e.key == "age" + assert e.value == b"" assert e.revision == 6 assert e.operation == "DEL" await kv.purge("name") e = await w.updates() assert e.bucket == "WATCH" - assert e.key == 'name' - assert e.value == b'' + assert e.key == "name" + assert e.value == b"" assert e.revision == 7 assert e.operation == "PURGE" @@ -3049,13 +3030,13 @@ async def error_handler(e): await w.stop() # Now try wildcard matching and make sure we only get last value when starting. - await kv.create("new", b'hello world') - await kv.put("t.name", b'a') - await kv.put("t.name", b'b') - await kv.put("t.age", b'c') - await kv.put("t.age", b'd') - await kv.put("t.a", b'a') - await kv.put("t.b", b'b') + await kv.create("new", b"hello world") + await kv.put("t.name", b"a") + await kv.put("t.name", b"b") + await kv.put("t.age", b"c") + await kv.put("t.age", b"d") + await kv.put("t.a", b"a") + await kv.put("t.b", b"b") # Will only get last values of the matching keys. w = await kv.watch("t.*") @@ -3065,7 +3046,7 @@ async def error_handler(e): assert e.bucket == "WATCH" assert e.delta == 3 assert e.key == "t.name" - assert e.value == b'b' + assert e.value == b"b" assert e.revision == 10 assert e.operation == None @@ -3073,7 +3054,7 @@ async def error_handler(e): assert e.bucket == "WATCH" assert e.delta == 2 assert e.key == "t.age" - assert e.value == b'd' + assert e.value == b"d" assert e.revision == 12 assert e.operation == None @@ -3081,7 +3062,7 @@ async def error_handler(e): assert e.bucket == "WATCH" assert e.delta == 1 assert e.key == "t.a" - assert e.value == b'a' + assert e.value == b"a" assert e.revision == 13 assert e.operation == None @@ -3090,7 +3071,7 @@ async def error_handler(e): assert e.bucket == "WATCH" assert e.delta == 0 assert e.key == "t.b" - assert e.value == b'b' + assert e.value == b"b" assert e.revision == 14 assert e.operation == None @@ -3103,10 +3084,10 @@ async def error_handler(e): with pytest.raises(TimeoutError): await w.updates(timeout=1) - await kv.put("t.hello", b'hello world') + await kv.put("t.hello", b"hello world") e = await w.updates() assert e.delta == 0 - assert e.key == 't.hello' + assert e.key == "t.hello" assert e.revision == 15 # Default watch timeout should 5 minutes @@ -3135,14 +3116,14 @@ async def error_handler(e): status = await kv.status() for i in range(0, 50): - await kv.put(f"age", f'{i}'.encode()) + await kv.put(f"age", f"{i}".encode()) vl = await kv.history("age") assert len(vl) == 10 i = 0 for entry in vl: - assert entry.key == 'age' + assert entry.key == "age" assert entry.revision == i + 41 assert int(entry.value) == i + 40 i += 1 @@ -3166,12 +3147,12 @@ async def error_handler(e): with pytest.raises(NoKeysError): await kv.keys() - await kv.put("a", b'1') - await kv.put("b", b'2') - await kv.put("a", b'11') - await kv.put("b", b'22') - await kv.put("a", b'111') - await kv.put("b", b'222') + await kv.put("a", b"1") + await kv.put("b", b"2") + await kv.put("a", b"11") + await kv.put("b", b"22") + await kv.put("a", b"111") + await kv.put("b", b"222") keys = await kv.keys() assert len(keys) == 2 @@ -3190,10 +3171,10 @@ async def error_handler(e): with pytest.raises(NoKeysError): await kv.keys() - await kv.create("c", b'3') + await kv.create("c", b"3") keys = await kv.keys() assert len(keys) == 1 - assert 'c' in keys + assert "c" in keys await nc.close() @@ -3232,7 +3213,7 @@ async def error_handler(e): for i in range(0, 10): await kv.delete(f"key-{i}") - await kv.put(f"key-last", b'101') + await kv.put(f"key-last", b"101") await kv.purge_deletes(olderthan=-1) await asyncio.sleep(0.5) @@ -3281,9 +3262,9 @@ async def error_handler(e): history = await kv.history("bar") assert len(history) == 1 entry = history[0] - assert entry.key == 'bar' + assert entry.key == "bar" assert entry.revision == 6 - assert entry.operation == 'DEL' + assert entry.operation == "DEL" await nc.close() @@ -3299,10 +3280,10 @@ async def error_handler(e): js = nc.jetstream() async def pub(): - await js.publish("foo.A", b'1') - await js.publish("foo.C", b'1') - await js.publish("foo.B", b'1') - await js.publish("foo.C", b'2') + await js.publish("foo.A", b"1") + await js.publish("foo.C", b"1") + await js.publish("foo.B", b"1") + await js.publish("foo.C", b"2") await js.add_stream(name="foo", subjects=["foo.A", "foo.B", "foo.C"]) await pub() @@ -3313,12 +3294,12 @@ async def pub(): assert info.state.messages == 2 msgs = await sub.fetch(5, timeout=1) assert len(msgs) == 2 - assert msgs[0].subject == 'foo.B' - assert msgs[1].subject == 'foo.C' + assert msgs[0].subject == "foo.B" + assert msgs[1].subject == "foo.C" await js.publish( "foo.C", - b'3', - headers={nats.js.api.Header.EXPECTED_LAST_SUBJECT_SEQUENCE: "4"} + b"3", + headers={nats.js.api.Header.EXPECTED_LAST_SUBJECT_SEQUENCE: "4"}, ) await nc.close() @@ -3335,18 +3316,17 @@ async def error_handler(e): js = nc.jetstream() kv = await js.create_key_value( - bucket="TEST_UPDATE", - republish=nats.js.api.RePublish(src=">", dest="bar.>") + bucket="TEST_UPDATE", republish=nats.js.api.RePublish(src=">", dest="bar.>") ) status = await kv.status() sinfo = await js.stream_info("KV_TEST_UPDATE") assert sinfo.config.republish is not None sub = await nc.subscribe("bar.>") - await kv.put("hello.world", b'Hello World!') + await kv.put("hello.world", b"Hello World!") msg = await sub.next_msg() - assert msg.data == b'Hello World!' - assert msg.headers.get('Nats-Msg-Size', None) == None + assert msg.data == b"Hello World!" + assert msg.headers.get("Nats-Msg-Size", None) == None await sub.unsubscribe() kv = await js.create_key_value( @@ -3355,15 +3335,15 @@ async def error_handler(e): src=">", dest="quux.>", headers_only=True, - ) + ), ) sub = await nc.subscribe("quux.>") - await kv.put("hello.world", b'Hello World!') + await kv.put("hello.world", b"Hello World!") msg = await sub.next_msg() - assert msg.data == b'' - assert msg.headers['Nats-Sequence'] == '1' - assert msg.headers['Nats-Msg-Size'] == '12' - assert msg.headers['Nats-Stream'] == 'KV_TEST_UPDATE_HEADERS' + assert msg.data == b"" + assert msg.headers["Nats-Sequence"] == "1" + assert msg.headers["Nats-Msg-Size"] == "12" + assert msg.headers["Nats-Stream"] == "KV_TEST_UPDATE_HEADERS" await sub.unsubscribe() await nc.close() @@ -3382,11 +3362,11 @@ async def error_handler(e): kv = await js.create_key_value(bucket="TEST_LOGGING", history=5, ttl=3600) # Add keys to the bucket - await kv.put("hello", b'world') - await kv.put("greeting", b'hi') + await kv.put("hello", b"world") + await kv.put("greeting", b"hi") # Test with filters (fetch keys with "hello" or "greet") - filtered_keys = await kv.keys(filters=['hello', 'greet']) + filtered_keys = await kv.keys(filters=["hello", "greet"]) assert "hello" in filtered_keys assert "greeting" in filtered_keys @@ -3395,7 +3375,6 @@ async def error_handler(e): class ObjectStoreTest(SingleJetStreamServerTestCase): - @async_test async def test_object_basics(self): errors = [] @@ -3410,9 +3389,7 @@ async def error_handler(e): with pytest.raises(nats.js.errors.InvalidBucketNameError): await js.create_object_store(bucket="notok!") - obs = await js.create_object_store( - bucket="OBJS", description="testing" - ) + obs = await js.create_object_store(bucket="OBJS", description="testing") assert obs._name == "OBJS" assert obs._stream == f"OBJ_OBJS" @@ -3421,7 +3398,7 @@ async def error_handler(e): sinfo = status.stream_info assert sinfo.config.name == "OBJ_OBJS" assert sinfo.config.description == "testing" - assert sinfo.config.subjects == ['$O.OBJS.C.>', '$O.OBJS.M.>'] + assert sinfo.config.subjects == ["$O.OBJS.C.>", "$O.OBJS.M.>"] assert sinfo.config.retention == "limits" assert sinfo.config.max_consumers == -1 assert sinfo.config.max_msgs == -1 @@ -3436,9 +3413,8 @@ async def error_handler(e): assert sinfo.config.allow_direct == True assert sinfo.config.mirror_direct == False - bucketname = ''.join( - random.SystemRandom().choice(string.ascii_letters) - for _ in range(10) + bucketname = "".join( + random.SystemRandom().choice(string.ascii_letters) for _ in range(10) ) obs = await js.create_object_store(bucket=bucketname) assert obs._name == bucketname @@ -3449,13 +3425,13 @@ async def error_handler(e): assert obs._stream == f"OBJ_{bucketname}" # Simple example using bytes. - info = await obs.put("foo", b'bar') + info = await obs.put("foo", b"bar") assert info.name == "foo" assert info.size == 3 assert info.bucket == bucketname # Simple example using strings. - info = await obs.put("plaintext", 'lorem ipsum') + info = await obs.put("plaintext", "lorem ipsum") assert info.name == "plaintext" assert info.size == 11 assert info.bucket == bucketname @@ -3464,16 +3440,16 @@ async def error_handler(e): opts = api.ObjectMetaOptions(max_chunk_size=5) info = await obs.put( "filename.txt", - b'filevalue', + b"filevalue", meta=nats.js.api.ObjectMeta( description="filedescription", headers={"headername": "headerval"}, options=opts, - ) + ), ) filename = "filename.txt" - filevalue = b'filevalue' + filevalue = b"filevalue" filedesc = "filedescription" fileheaders = {"headername": "headerval"} assert info.name == filename @@ -3486,7 +3462,9 @@ async def error_handler(e): h = sha256() h.update(filevalue) h.digest() - expected_digest = f"SHA-256={base64.urlsafe_b64encode(h.digest()).decode('utf-8')}" + expected_digest = ( + f"SHA-256={base64.urlsafe_b64encode(h.digest()).decode('utf-8')}" + ) assert info.digest == expected_digest assert info.deleted == False assert info.description == filedesc @@ -3542,13 +3520,13 @@ async def error_handler(e): # Create an 8MB object. obs = await js.create_object_store(bucket="big") - ls = ''.join("A" for _ in range(0, 1 * 1024 * 1024 + 33)) + ls = "".join("A" for _ in range(0, 1 * 1024 * 1024 + 33)) w = io.BytesIO(ls.encode()) info = await obs.put("big", w) assert info.name == "big" assert info.size == 1048609 assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" # Create actual file and put it in a bucket. tmp = tempfile.NamedTemporaryFile(delete=False) @@ -3560,14 +3538,14 @@ async def error_handler(e): assert info.name == "tmp" assert info.size == 1048609 assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" with open(tmp.name) as f: info = await obs.put("tmp2", f) assert info.name == "tmp2" assert info.size == 1048609 assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" # By default this reads the complete data. obr = await obs.get("tmp") @@ -3575,7 +3553,7 @@ async def error_handler(e): assert info.name == "tmp" assert info.size == 1048609 assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" # Using a local file. with open("pyproject.toml") as f: @@ -3592,19 +3570,25 @@ async def error_handler(e): # Write into file without getting complete data. w = tempfile.NamedTemporaryFile(delete=False) w.close() - with open(w.name, 'w') as f: + with open(w.name, "w") as f: obr = await obs.get("tmp", writeinto=f) - assert obr.data == b'' + assert obr.data == b"" assert obr.info.size == 1048609 - assert obr.info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert ( + obr.info.digest + == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" + ) w2 = tempfile.NamedTemporaryFile(delete=False) w2.close() - with open(w2.name, 'w') as f: + with open(w2.name, "w") as f: obr = await obs.get("tmp", writeinto=f.buffer) - assert obr.data == b'' + assert obr.data == b"" assert obr.info.size == 1048609 - assert obr.info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert ( + obr.info.digest + == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" + ) with open(w2.name) as f: result = f.read(-1) @@ -3625,7 +3609,7 @@ async def error_handler(e): # Create an 8MB object. obs = await js.create_object_store(bucket="sample") - ls = ''.join("A" for _ in range(0, 2 * 1024 * 1024 + 33)) + ls = "".join("A" for _ in range(0, 2 * 1024 * 1024 + 33)) w = io.BytesIO(ls.encode()) info = await obs.put("sample", w) assert info.name == "sample" @@ -3680,22 +3664,24 @@ async def error_handler(e): obs = await js.create_object_store( "TEST_FILES", - config=nats.js.api.ObjectStoreConfig(description="multi_files", ) + config=nats.js.api.ObjectStoreConfig( + description="multi_files", + ), ) - await obs.put("A", b'A') - await obs.put("B", b'B') - await obs.put("C", b'C') + await obs.put("A", b"A") + await obs.put("B", b"B") + await obs.put("C", b"C") res = await obs.get("A") - assert res.data == b'A' + assert res.data == b"A" assert res.info.digest == "SHA-256=VZrq0IJk1XldOQlxjN0Fq9SVcuhP5VWQ7vMaiKCP3_0=" res = await obs.get("B") - assert res.data == b'B' + assert res.data == b"B" assert res.info.digest == "SHA-256=335w5QIVRPSDS77mSp43if68S-gUcN9inK1t2wMyClw=" res = await obs.get("C") - assert res.data == b'C' + assert res.data == b"C" assert res.info.digest == "SHA-256=ayPA1fNdGxH5toPwsKYXNV3rESd9ka4JHTmcZVuHlA0=" with open("README.md") as fp: @@ -3724,49 +3710,51 @@ async def error_handler(e): obs = await js.create_object_store( "TEST_FILES", - config=nats.js.api.ObjectStoreConfig(description="multi_files", ) + config=nats.js.api.ObjectStoreConfig( + description="multi_files", + ), ) watcher = await obs.watch() e = await watcher.updates() assert e == None - await obs.put("A", b'A') - await obs.put("B", b'B') - await obs.put("C", b'C') + await obs.put("A", b"A") + await obs.put("B", b"B") + await obs.put("C", b"C") res = await obs.get("A") - assert res.data == b'A' + assert res.data == b"A" assert res.info.digest == "SHA-256=VZrq0IJk1XldOQlxjN0Fq9SVcuhP5VWQ7vMaiKCP3_0=" res = await obs.get("B") - assert res.data == b'B' + assert res.data == b"B" assert res.info.digest == "SHA-256=335w5QIVRPSDS77mSp43if68S-gUcN9inK1t2wMyClw=" res = await obs.get("C") - assert res.data == b'C' + assert res.data == b"C" assert res.info.digest == "SHA-256=ayPA1fNdGxH5toPwsKYXNV3rESd9ka4JHTmcZVuHlA0=" e = await watcher.updates() - assert e.name == 'A' - assert e.bucket == 'TEST_FILES' + assert e.name == "A" + assert e.bucket == "TEST_FILES" assert e.size == 1 assert e.chunks == 1 - assert e.digest == 'SHA-256=VZrq0IJk1XldOQlxjN0Fq9SVcuhP5VWQ7vMaiKCP3_0=' + assert e.digest == "SHA-256=VZrq0IJk1XldOQlxjN0Fq9SVcuhP5VWQ7vMaiKCP3_0=" e = await watcher.updates() - assert e.name == 'B' - assert e.bucket == 'TEST_FILES' + assert e.name == "B" + assert e.bucket == "TEST_FILES" assert e.size == 1 assert e.chunks == 1 - assert e.digest == 'SHA-256=335w5QIVRPSDS77mSp43if68S-gUcN9inK1t2wMyClw=' + assert e.digest == "SHA-256=335w5QIVRPSDS77mSp43if68S-gUcN9inK1t2wMyClw=" e = await watcher.updates() - assert e.name == 'C' - assert e.bucket == 'TEST_FILES' + assert e.name == "C" + assert e.bucket == "TEST_FILES" assert e.size == 1 assert e.chunks == 1 - assert e.digest == 'SHA-256=ayPA1fNdGxH5toPwsKYXNV3rESd9ka4JHTmcZVuHlA0=' + assert e.digest == "SHA-256=ayPA1fNdGxH5toPwsKYXNV3rESd9ka4JHTmcZVuHlA0=" # Expect no more updates. with pytest.raises(asyncio.TimeoutError): @@ -3786,16 +3774,16 @@ async def error_handler(e): assert deleted.info.name == "B" assert deleted.info.deleted == True - info = await obs.put("C", b'CCC') - assert info.name == 'C' + info = await obs.put("C", b"CCC") + assert info.name == "C" assert info.deleted == False e = await watcher.updates() - assert e.name == 'C' + assert e.name == "C" assert e.deleted == False res = await obs.get("C") - assert res.data == b'CCC' + assert res.data == b"CCC" # Update meta to_update_meta = res.info.meta @@ -3803,8 +3791,8 @@ async def error_handler(e): await obs.update_meta("C", to_update_meta) e = await watcher.updates() - assert e.name == 'C' - assert e.description == 'changed' + assert e.name == "C" + assert e.description == "changed" # Try to update meta when it has already been deleted. deleted_meta = deleted.info.meta @@ -3818,7 +3806,7 @@ async def error_handler(e): # Update meta res = await obs.get("A") - assert res.data == b'A' + assert res.data == b"A" to_update_meta = res.info.meta to_update_meta.name = "Z" to_update_meta.description = "changed" @@ -3840,20 +3828,22 @@ async def error_handler(e): obs = await js.create_object_store( "TEST_LIST", - config=nats.js.api.ObjectStoreConfig(description="listing", ) + config=nats.js.api.ObjectStoreConfig( + description="listing", + ), ) await asyncio.gather( - obs.put("A", b'AAA'), - obs.put("B", b'BBB'), - obs.put("C", b'CCC'), - obs.put("D", b'DDD'), + obs.put("A", b"AAA"), + obs.put("B", b"BBB"), + obs.put("C", b"CCC"), + obs.put("D", b"DDD"), ) entries = await obs.list() assert len(entries) == 4 - assert entries[0].name == 'A' - assert entries[1].name == 'B' - assert entries[2].name == 'C' - assert entries[3].name == 'D' + assert entries[0].name == "A" + assert entries[1].name == "B" + assert entries[2].name == "C" + assert entries[3].name == "D" await nc.close() @@ -3875,13 +3865,13 @@ async def error_handler(e): # Create an 8MB object. obs = await js.create_object_store(bucket="big") - ls = ''.join("A" for _ in range(0, 1 * 1024 * 1024 + 33)) + ls = "".join("A" for _ in range(0, 1 * 1024 * 1024 + 33)) w = io.BytesIO(ls.encode()) info = await obs.put("big", w) assert info.name == "big" assert info.size == 1048609 assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" # Create actual file and put it in a bucket. tmp = tempfile.NamedTemporaryFile(delete=False) @@ -3893,14 +3883,14 @@ async def error_handler(e): assert info.name == "tmp" assert info.size == 1048609 assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" async with aiofiles.open(tmp.name) as f: info = await obs.put("tmp2", f) assert info.name == "tmp2" assert info.size == 1048609 assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" obr = await obs.get("tmp") info = obr.info @@ -3908,7 +3898,7 @@ async def error_handler(e): assert info.size == 1048609 assert len(obr.data) == info.size # no reader reads whole file. assert info.chunks == 9 - assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=' + assert info.digest == "SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA=" # Using a local file. async with aiofiles.open("pyproject.toml") as f: @@ -3925,7 +3915,6 @@ async def error_handler(e): class ConsumerReplicasTest(SingleJetStreamServerTestCase): - @async_test async def test_number_of_consumer_replicas(self): nc = await nats.connect() @@ -3933,12 +3922,10 @@ async def test_number_of_consumer_replicas(self): js = nc.jetstream() await js.add_stream(name="TESTREPLICAS", subjects=["test.replicas"]) for i in range(0, 10): - await js.publish("test.replicas", f'{i}'.encode()) + await js.publish("test.replicas", f"{i}".encode()) # Create consumer - config = nats.js.api.ConsumerConfig( - num_replicas=1, durable_name="mycons" - ) + config = nats.js.api.ConsumerConfig(num_replicas=1, durable_name="mycons") cons = await js.add_consumer(stream="TESTREPLICAS", config=config) assert cons.config.num_replicas == 1 @@ -3947,7 +3934,6 @@ async def test_number_of_consumer_replicas(self): class AccountLimitsTest(SingleJetStreamServerLimitsTestCase): - @async_test async def test_account_limits(self): nc = await nats.connect() @@ -3957,29 +3943,30 @@ async def test_account_limits(self): with pytest.raises(BadRequestError) as err: await js.add_stream(name="limits", subjects=["limits"]) assert err.value.err_code == 10113 - assert err.value.description == "account requires a stream config to have max bytes set" + assert ( + err.value.description + == "account requires a stream config to have max bytes set" + ) with pytest.raises(BadRequestError) as err: - await js.add_stream( - name="limits", subjects=["limits"], max_bytes=65536 - ) + await js.add_stream(name="limits", subjects=["limits"], max_bytes=65536) assert err.value.err_code == 10122 - assert err.value.description == "stream max bytes exceeds account limit max stream bytes" - - si = await js.add_stream( - name="limits", subjects=["limits"], max_bytes=128 + assert ( + err.value.description + == "stream max bytes exceeds account limit max stream bytes" ) + + si = await js.add_stream(name="limits", subjects=["limits"], max_bytes=128) assert si.config.max_bytes == 128 await js.publish( - "limits", - b'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA' + "limits", b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" ) si = await js.stream_info("limits") assert si.state.messages == 1 for i in range(0, 5): - await js.publish("limits", b'A') + await js.publish("limits", b"A") expected = nats.js.api.AccountInfo( memory=0, @@ -3994,11 +3981,11 @@ async def test_account_limits(self): max_ack_pending=100, memory_max_stream_bytes=2048, storage_max_stream_bytes=4096, - max_bytes_required=True + max_bytes_required=True, ), api=nats.js.api.APIStats(total=4, errors=2), - domain='test-domain', - tiers=None + domain="test-domain", + tiers=None, ) info = await js.account_info() assert expected == info @@ -4079,46 +4066,44 @@ async def test_account_limits(self): max_ack_pending=0, memory_max_stream_bytes=0, storage_max_stream_bytes=0, - max_bytes_required=False + max_bytes_required=False, ), api=nats.js.api.APIStats(total=6, errors=0), - domain='ngs', + domain="ngs", tiers={ - 'R1': - nats.js.api.Tier( - memory=0, - storage=6829550, - streams=1, - consumers=0, - limits=nats.js.api.AccountLimits( - max_memory=0, - max_storage=2000000000000, - max_streams=100, - max_consumers=1000, - max_ack_pending=-1, - memory_max_stream_bytes=-1, - storage_max_stream_bytes=-1, - max_bytes_required=True - ) + "R1": nats.js.api.Tier( + memory=0, + storage=6829550, + streams=1, + consumers=0, + limits=nats.js.api.AccountLimits( + max_memory=0, + max_storage=2000000000000, + max_streams=100, + max_consumers=1000, + max_ack_pending=-1, + memory_max_stream_bytes=-1, + storage_max_stream_bytes=-1, + max_bytes_required=True, + ), + ), + "R3": nats.js.api.Tier( + memory=0, + storage=0, + streams=0, + consumers=0, + limits=nats.js.api.AccountLimits( + max_memory=0, + max_storage=500000000000, + max_streams=25, + max_consumers=250, + max_ack_pending=-1, + memory_max_stream_bytes=-1, + storage_max_stream_bytes=-1, + max_bytes_required=True, ), - 'R3': - nats.js.api.Tier( - memory=0, - storage=0, - streams=0, - consumers=0, - limits=nats.js.api.AccountLimits( - max_memory=0, - max_storage=500000000000, - max_streams=25, - max_consumers=250, - max_ack_pending=-1, - memory_max_stream_bytes=-1, - storage_max_stream_bytes=-1, - max_bytes_required=True - ) - ) - } + ), + }, ) info = nats.js.api.AccountInfo.from_response(json.loads(blob)) assert expected == info @@ -4126,7 +4111,6 @@ async def test_account_limits(self): class V210FeaturesTest(SingleJetStreamServerTestCase): - @async_test async def test_subject_transforms(self): nc = await nats.connect() @@ -4142,7 +4126,7 @@ async def test_subject_transforms(self): ), ) for i in range(0, 10): - await js.publish("test", f'{i}'.encode()) + await js.publish("test", f"{i}".encode()) # Creating a filtered consumer will be successful but will not be able to match # so num_pending would remain at zeroes. @@ -4164,10 +4148,10 @@ async def test_subject_transforms(self): assert cinfo.num_pending == 10 msgs = await psub.fetch(10, timeout=1) assert len(msgs) == 10 - assert msgs[0].subject == 'transformed.test' - assert msgs[0].data == b'0' - assert msgs[5].subject == 'transformed.test' - assert msgs[5].data == b'5' + assert msgs[0].subject == "transformed.test" + assert msgs[0].data == b"0" + assert msgs[5].subject == "transformed.test" + assert msgs[5].data == b"5" # Create stream that fetches from the original stream that # already has the transformed subjects. @@ -4178,9 +4162,7 @@ async def test_subject_transforms(self): nats.js.api.SubjectTransform( src="transformed.>", dest="fromtest.transformed.>" ), - nats.js.api.SubjectTransform( - src="foo.>", dest="fromtest.foo.>" - ), + nats.js.api.SubjectTransform(src="foo.>", dest="fromtest.foo.>"), ], ) await js.add_stream( @@ -4198,18 +4180,16 @@ async def test_subject_transforms(self): assert cinfo.num_pending == 10 msgs = await psub.fetch(10, timeout=1) assert len(msgs) == 10 - assert msgs[0].subject == 'fromtest.transformed.test' - assert msgs[0].data == b'0' - assert msgs[5].subject == 'fromtest.transformed.test' - assert msgs[5].data == b'5' + assert msgs[0].subject == "fromtest.transformed.test" + assert msgs[0].data == b"0" + assert msgs[5].subject == "fromtest.transformed.test" + assert msgs[5].data == b"5" # Source is overlapping so should fail. transformed_source = nats.js.api.StreamSource( name="TRANSFORMS2", subject_transforms=[ - nats.js.api.SubjectTransform( - src=">", dest="fromtest.transformed.>" - ), + nats.js.api.SubjectTransform(src=">", dest="fromtest.transformed.>"), nats.js.api.SubjectTransform(src=">", dest="fromtest.foo.>"), ], ) @@ -4241,7 +4221,7 @@ async def test_stream_compression(self): subjects=["test", "foo"], compression="s3", ) - assert str(err.value) == 'nats: invalid store compression type: s3' + assert str(err.value) == "nats: invalid store compression type: s3" # An empty string means not setting compression, but to be explicit # can also use @@ -4272,10 +4252,10 @@ async def test_stream_consumer_metadata(self): await js.add_stream( name="META", subjects=["test", "foo"], - metadata={'foo': 'bar'}, + metadata={"foo": "bar"}, ) sinfo = await js.stream_info("META") - assert sinfo.config.metadata['foo'] == 'bar' + assert sinfo.config.metadata["foo"] == "bar" with pytest.raises(ValueError) as err: await js.add_stream( @@ -4283,16 +4263,16 @@ async def test_stream_consumer_metadata(self): subjects=["test", "foo"], metadata=["hello", "world"], ) - assert str(err.value) == 'nats: invalid metadata format' + assert str(err.value) == "nats: invalid metadata format" await js.add_consumer( "META", config=nats.js.api.ConsumerConfig( - durable_name="b", metadata={'hello': 'world'} - ) + durable_name="b", metadata={"hello": "world"} + ), ) cinfo = await js.consumer_info("META", "b") - assert cinfo.config.metadata['hello'] == 'world' + assert cinfo.config.metadata["hello"] == "world" await nc.close() @@ -4307,7 +4287,7 @@ async def test_fetch_pull_subscribe_bind(self): await js.add_stream(name=stream_name, subjects=["foo", "bar"]) for i in range(0, 5): - await js.publish("foo", b'A') + await js.publish("foo", b"A") # Fetch with multiple filters on an ephemeral consumer. cinfo = await js.add_consumer( @@ -4317,9 +4297,7 @@ async def test_fetch_pull_subscribe_bind(self): ) # Using named arguments. - psub = await js.pull_subscribe_bind( - stream=stream_name, consumer=cinfo.name - ) + psub = await js.pull_subscribe_bind(stream=stream_name, consumer=cinfo.name) msgs = await psub.fetch(1) msg = msgs[0] await msg.ack() @@ -4331,17 +4309,13 @@ async def test_fetch_pull_subscribe_bind(self): await msg.ack() # Using durable argument to refer to ephemeral is ok for backwards compatibility. - psub = await js.pull_subscribe_bind( - durable=cinfo.name, stream=stream_name - ) + psub = await js.pull_subscribe_bind(durable=cinfo.name, stream=stream_name) msgs = await psub.fetch(1) msg = msgs[0] await msg.ack() # stream, consumer name order - psub = await js.pull_subscribe_bind( - stream=stream_name, durable=cinfo.name - ) + psub = await js.pull_subscribe_bind(stream=stream_name, durable=cinfo.name) msgs = await psub.fetch(1) msg = msgs[0] await msg.ack() @@ -4349,9 +4323,7 @@ async def test_fetch_pull_subscribe_bind(self): assert msg.metadata.num_pending == 1 # name can also be used to refer to the consumer name - psub = await js.pull_subscribe_bind( - stream=stream_name, name=cinfo.name - ) + psub = await js.pull_subscribe_bind(stream=stream_name, name=cinfo.name) msgs = await psub.fetch(1) msg = msgs[0] await msg.ack() @@ -4361,11 +4333,11 @@ async def test_fetch_pull_subscribe_bind(self): with pytest.raises(ValueError) as err: await js.pull_subscribe_bind(durable=cinfo.name) - assert str(err.value) == 'nats: stream name is required' + assert str(err.value) == "nats: stream name is required" with pytest.raises(ValueError) as err: await js.pull_subscribe_bind(cinfo.name) - assert str(err.value) == 'nats: stream name is required' + assert str(err.value) == "nats: stream name is required" await nc.close() diff --git a/tests/test_micro_service.py b/tests/test_micro_service.py index c1227222..059e8399 100644 --- a/tests/test_micro_service.py +++ b/tests/test_micro_service.py @@ -17,16 +17,20 @@ def test_invalid_service_name(self): ServiceConfig(name="", version="0.1.0") self.assertEqual(str(context.exception), "Name cannot be empty.") - with self.assertRaises(ValueError) as context: ServiceConfig(name="test.service@!", version="0.1.0") - self.assertEqual(str(context.exception), "Invalid name. It must contain only alphanumeric characters, dashes, and underscores.") - + self.assertEqual( + str(context.exception), + "Invalid name. It must contain only alphanumeric characters, dashes, and underscores.", + ) def test_invalid_service_version(self): with self.assertRaises(ValueError) as context: ServiceConfig(name="test_service", version="abc") - self.assertEqual(str(context.exception), "Invalid version. It must follow semantic versioning (e.g., 1.0.0, 2.1.3-alpha.1).") + self.assertEqual( + str(context.exception), + "Invalid version. It must follow semantic versioning (e.g., 1.0.0, 2.1.3-alpha.1).", + ) def test_invalid_endpoint_subject(self): async def noop_handler(request: Request) -> None: @@ -38,7 +42,10 @@ async def noop_handler(request: Request) -> None: subject="endpoint subject", handler=noop_handler, ) - self.assertEqual(str(context.exception), "Invalid subject. Subject must not contain spaces, and can only have '>' at the end.") + self.assertEqual( + str(context.exception), + "Invalid subject. Subject must not contain spaces, and can only have '>' at the end.", + ) @async_test async def test_service_basics(self): @@ -217,7 +224,6 @@ async def noop_handler(request: Request): await svc.stop() assert svc.stopped - @async_test async def test_groups(self): sub_tests = { diff --git a/tests/test_nuid.py b/tests/test_nuid.py index 1220c8af..fd05a0f2 100644 --- a/tests/test_nuid.py +++ b/tests/test_nuid.py @@ -20,7 +20,6 @@ class NUIDTest(unittest.TestCase): - def setUp(self): super().setUp() @@ -32,18 +31,14 @@ def test_nuid_are_unique(self): nuid = NUID() entries = [nuid.next().decode() for i in range(500000)] counted_entries = Counter(entries) - repeated = [ - entry for entry, count in counted_entries.items() if count > 1 - ] + repeated = [entry for entry, count in counted_entries.items() if count > 1] self.assertEqual(len(repeated), 0) def test_nuid_are_very_unique(self): nuid = NUID() entries = [nuid.next().decode() for i in range(1000000)] counted_entries = Counter(entries) - repeated = [ - entry for entry, count in counted_entries.items() if count > 1 - ] + repeated = [entry for entry, count in counted_entries.items() if count > 1] self.assertEqual(len(repeated), 0) def test_subsequent_nuid_equal(self): diff --git a/tests/test_parser.py b/tests/test_parser.py index d5fda2d5..2873575c 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -9,7 +9,6 @@ class MockNatsClient: - def __init__(self): self._subs = {} self._pongs = [] @@ -39,14 +38,13 @@ def _process_info(self, info): class ProtocolParserTest(unittest.TestCase): - def setUp(self): self.loop = asyncio.new_event_loop() @async_test async def test_parse_ping(self): ps = Parser(MockNatsClient()) - data = b'PING\r\n' + data = b"PING\r\n" await ps.parse(data) self.assertEqual(len(ps.buf), 0) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) @@ -54,7 +52,7 @@ async def test_parse_ping(self): @async_test async def test_parse_pong(self): ps = Parser(MockNatsClient()) - data = b'PONG\r\n' + data = b"PONG\r\n" await ps.parse(data) self.assertEqual(len(ps.buf), 0) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) @@ -62,7 +60,7 @@ async def test_parse_pong(self): @async_test async def test_parse_ok(self): ps = Parser() - data = b'+OK\r\n' + data = b"+OK\r\n" await ps.parse(data) self.assertEqual(len(ps.buf), 0) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) @@ -70,7 +68,7 @@ async def test_parse_ok(self): @async_test async def test_parse_msg(self): nc = MockNatsClient() - expected = b'hello world!' + expected = b"hello world!" async def payload_test(sid, subject, reply, payload): self.assertEqual(payload, expected) @@ -84,22 +82,22 @@ async def payload_test(sid, subject, reply, payload): sub = Subscription(nc, 1, **params) nc._subs[1] = sub ps = Parser(nc) - data = b'MSG hello 1 world 12\r\n' + data = b"MSG hello 1 world 12\r\n" await ps.parse(data) self.assertEqual(len(ps.buf), 0) self.assertEqual(len(ps.msg_arg.keys()), 3) - self.assertEqual(ps.msg_arg["subject"], b'hello') - self.assertEqual(ps.msg_arg["reply"], b'world') + self.assertEqual(ps.msg_arg["subject"], b"hello") + self.assertEqual(ps.msg_arg["reply"], b"world") self.assertEqual(ps.msg_arg["sid"], 1) self.assertEqual(ps.needed, 12) self.assertEqual(ps.state, AWAITING_MSG_PAYLOAD) - await ps.parse(b'hello world!') + await ps.parse(b"hello world!") self.assertEqual(len(ps.buf), 12) self.assertEqual(ps.state, AWAITING_MSG_PAYLOAD) - data = b'\r\n' + data = b"\r\n" await ps.parse(data) self.assertEqual(len(ps.buf), 0) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) @@ -107,7 +105,7 @@ async def payload_test(sid, subject, reply, payload): @async_test async def test_parse_msg_op(self): ps = Parser() - data = b'MSG hello' + data = b"MSG hello" await ps.parse(data) self.assertEqual(len(ps.buf), 9) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) @@ -115,7 +113,7 @@ async def test_parse_msg_op(self): @async_test async def test_parse_split_msg_op(self): ps = Parser() - data = b'MSG' + data = b"MSG" await ps.parse(data) self.assertEqual(len(ps.buf), 3) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) @@ -123,7 +121,7 @@ async def test_parse_split_msg_op(self): @async_test async def test_parse_split_msg_op_space(self): ps = Parser() - data = b'MSG ' + data = b"MSG " await ps.parse(data) self.assertEqual(len(ps.buf), 4) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) @@ -131,7 +129,7 @@ async def test_parse_split_msg_op_space(self): @async_test async def test_parse_split_msg_op_wrong_args(self): ps = Parser(MockNatsClient()) - data = b'MSG PONG\r\n' + data = b"MSG PONG\r\n" with self.assertRaises(ProtocolError): await ps.parse(data) @@ -160,12 +158,12 @@ async def test_parse_err(self): async def test_parse_info(self): nc = MockNatsClient() ps = Parser(nc) - server_id = 'A' * 2048 + server_id = "A" * 2048 data = f'INFO {{"server_id": "{server_id}", "max_payload": 100, "auth_required": false, "connect_urls":["127.0.0.0.1:4223"]}}\r\n' await ps.parse(data.encode()) self.assertEqual(len(ps.buf), 0) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) - self.assertEqual(len(nc._server_info['server_id']), 2048) + self.assertEqual(len(nc._server_info["server_id"]), 2048) @async_test async def test_parse_msg_long_subject_reply(self): @@ -187,14 +185,14 @@ async def payload_test(sid, subject, reply, payload): nc._subs[1] = sub ps = Parser(nc) - reply = 'A' * 2043 - data = f'PING\r\nMSG hello 1 {reply}' + reply = "A" * 2043 + data = f"PING\r\nMSG hello 1 {reply}" await ps.parse(data.encode()) - await ps.parse(b'AAAAA 0\r\n\r\nMSG hello 1 world 0') + await ps.parse(b"AAAAA 0\r\n\r\nMSG hello 1 world 0") self.assertEqual(msgs, 1) self.assertEqual(len(ps.buf), 19) self.assertEqual(ps.state, AWAITING_CONTROL_LINE) - await ps.parse(b'\r\n\r\n') + await ps.parse(b"\r\n\r\n") self.assertEqual(msgs, 2) @async_test @@ -217,24 +215,24 @@ async def payload_test(sid, subject, reply, payload): nc._subs[1] = sub ps = Parser(nc) - reply = 'A' * 2043 * 4 + reply = "A" * 2043 * 4 # FIXME: Malformed long protocol lines will not be detected # by the client in this case so we rely on the ping/pong interval # from the client to give up instead. - data = f'PING\r\nWRONG hello 1 {reply}' + data = f"PING\r\nWRONG hello 1 {reply}" await ps.parse(data.encode()) - await ps.parse(b'AAAAA 0') + await ps.parse(b"AAAAA 0") self.assertEqual(ps.state, AWAITING_CONTROL_LINE) - await ps.parse(b'\r\n\r\n') - await ps.parse(b'\r\n\r\n') + await ps.parse(b"\r\n\r\n") + await ps.parse(b"\r\n\r\n") ps = Parser(nc) - reply = 'A' * 2043 - data = f'PING\r\nWRONG hello 1 {reply}' + reply = "A" * 2043 + data = f"PING\r\nWRONG hello 1 {reply}" with self.assertRaises(ProtocolError): await ps.parse(data.encode()) - await ps.parse(b'AAAAA 0') + await ps.parse(b"AAAAA 0") self.assertEqual(ps.state, AWAITING_CONTROL_LINE) - await ps.parse(b'\r\n\r\n') - await ps.parse(b'\r\n\r\n') + await ps.parse(b"\r\n\r\n") + await ps.parse(b"\r\n\r\n") diff --git a/tests/utils.py b/tests/utils.py index 43dd27fa..4030c599 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,6 +13,7 @@ try: import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except: pass @@ -22,7 +23,6 @@ class NATSD: - def __init__( self, port=4222, @@ -62,8 +62,7 @@ def start(self): if Path(self.bin_name).is_file(): self.bin_name = Path(self.bin_name).absolute() # Path in `../scripts/install_nats.sh` - elif Path.home().joinpath(SERVER_BIN_DIR_NAME, - self.bin_name).is_file(): + elif Path.home().joinpath(SERVER_BIN_DIR_NAME, self.bin_name).is_file(): self.bin_name = str( Path.home().joinpath(SERVER_BIN_DIR_NAME, self.bin_name) ) @@ -72,9 +71,13 @@ def start(self): self.bin_name = str(Path(THIS_DIR).joinpath(self.bin_name)) cmd = [ - self.bin_name, "-p", - "%d" % self.port, "-m", - "%d" % self.http_port, "-a", "127.0.0.1" + self.bin_name, + "-p", + "%d" % self.port, + "-m", + "%d" % self.http_port, + "-a", + "127.0.0.1", ] if self.user: cmd.append("--user") @@ -95,24 +98,24 @@ def start(self): cmd.append(f"-sd={self.store_dir}") if self.tls: - cmd.append('--tls') - cmd.append('--tlscert') - cmd.append(get_config_file('certs/server-cert.pem')) - cmd.append('--tlskey') - cmd.append(get_config_file('certs/server-key.pem')) - cmd.append('--tlsverify') - cmd.append('--tlscacert') - cmd.append(get_config_file('certs/ca.pem')) + cmd.append("--tls") + cmd.append("--tlscert") + cmd.append(get_config_file("certs/server-cert.pem")) + cmd.append("--tlskey") + cmd.append(get_config_file("certs/server-key.pem")) + cmd.append("--tlsverify") + cmd.append("--tlscacert") + cmd.append(get_config_file("certs/ca.pem")) if self.cluster_listen is not None: - cmd.append('--cluster_listen') + cmd.append("--cluster_listen") cmd.append(self.cluster_listen) if len(self.routes) > 0: - cmd.append('--routes') - cmd.append(','.join(self.routes)) - cmd.append('--cluster_name') - cmd.append('CLUSTER') + cmd.append("--routes") + cmd.append(",".join(self.routes)) + cmd.append("--cluster_name") + cmd.append("CLUSTER") if self.config_file is not None: cmd.append("--config") @@ -155,8 +158,9 @@ def stop(self): if self.proc.returncode is not None: if self.debug: print( - "[\033[0;31mDEBUG\033[0;0m] Server listening on port {port} finished running already with exit {ret}" - .format(port=self.port, ret=self.proc.returncode) + "[\033[0;31mDEBUG\033[0;0m] Server listening on port {port} finished running already with exit {ret}".format( + port=self.port, ret=self.proc.returncode + ) ) else: os.kill(self.proc.pid, signal.SIGKILL) @@ -169,7 +173,6 @@ def stop(self): class SingleServerTestCase(unittest.TestCase): - def setUp(self): self.server_pool = [] self.loop = asyncio.new_event_loop() @@ -186,7 +189,6 @@ def tearDown(self): class MultiServerAuthTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.server_pool = [] @@ -194,9 +196,7 @@ def setUp(self): server1 = NATSD(port=4223, user="foo", password="bar", http_port=8223) self.server_pool.append(server1) - server2 = NATSD( - port=4224, user="hoge", password="fuga", http_port=8224 - ) + server2 = NATSD(port=4224, user="hoge", password="fuga", http_port=8224) self.server_pool.append(server2) for natsd in self.server_pool: start_natsd(natsd) @@ -208,7 +208,6 @@ def tearDown(self): class MultiServerAuthTokenTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.server_pool = [] @@ -230,7 +229,6 @@ def tearDown(self): class TLSServerTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.loop = asyncio.new_event_loop() @@ -238,14 +236,12 @@ def setUp(self): self.natsd = NATSD(port=4224, tls=True) start_natsd(self.natsd) - self.ssl_ctx = ssl.create_default_context( - purpose=ssl.Purpose.SERVER_AUTH - ) + self.ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) # self.ssl_ctx.protocol = ssl.PROTOCOL_TLSv1_2 - self.ssl_ctx.load_verify_locations(get_config_file('certs/ca.pem')) + self.ssl_ctx.load_verify_locations(get_config_file("certs/ca.pem")) self.ssl_ctx.load_cert_chain( - certfile=get_config_file('certs/client-cert.pem'), - keyfile=get_config_file('certs/client-key.pem') + certfile=get_config_file("certs/client-cert.pem"), + keyfile=get_config_file("certs/client-key.pem"), ) def tearDown(self): @@ -254,25 +250,21 @@ def tearDown(self): class TLSServerHandshakeFirstTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.loop = asyncio.new_event_loop() self.natsd = NATSD( - port=4224, - config_file=get_config_file('conf/tls_handshake_first.conf') + port=4224, config_file=get_config_file("conf/tls_handshake_first.conf") ) start_natsd(self.natsd) - self.ssl_ctx = ssl.create_default_context( - purpose=ssl.Purpose.SERVER_AUTH - ) + self.ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) # self.ssl_ctx.protocol = ssl.PROTOCOL_TLSv1_2 - self.ssl_ctx.load_verify_locations(get_config_file('certs/ca.pem')) + self.ssl_ctx.load_verify_locations(get_config_file("certs/ca.pem")) self.ssl_ctx.load_cert_chain( - certfile=get_config_file('certs/client-cert.pem'), - keyfile=get_config_file('certs/client-key.pem') + certfile=get_config_file("certs/client-cert.pem"), + keyfile=get_config_file("certs/client-key.pem"), ) def tearDown(self): @@ -281,15 +273,12 @@ def tearDown(self): class MultiTLSServerAuthTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.server_pool = [] self.loop = asyncio.new_event_loop() - server1 = NATSD( - port=4223, user="foo", password="bar", http_port=8223, tls=True - ) + server1 = NATSD(port=4223, user="foo", password="bar", http_port=8223, tls=True) self.server_pool.append(server1) server2 = NATSD( port=4224, user="hoge", password="fuga", http_port=8224, tls=True @@ -298,14 +287,12 @@ def setUp(self): for natsd in self.server_pool: start_natsd(natsd) - self.ssl_ctx = ssl.create_default_context( - purpose=ssl.Purpose.SERVER_AUTH - ) + self.ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) # self.ssl_ctx.protocol = ssl.PROTOCOL_TLSv1_2 - self.ssl_ctx.load_verify_locations(get_config_file('certs/ca.pem')) + self.ssl_ctx.load_verify_locations(get_config_file("certs/ca.pem")) self.ssl_ctx.load_cert_chain( - certfile=get_config_file('certs/client-cert.pem'), - keyfile=get_config_file('certs/client-key.pem') + certfile=get_config_file("certs/client-cert.pem"), + keyfile=get_config_file("certs/client-key.pem"), ) def tearDown(self): @@ -315,21 +302,21 @@ def tearDown(self): class ClusteringTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.server_pool = [] self.loop = asyncio.new_event_loop() routes = [ - "nats://127.0.0.1:6223", "nats://127.0.0.1:6224", - "nats://127.0.0.1:6225" + "nats://127.0.0.1:6223", + "nats://127.0.0.1:6224", + "nats://127.0.0.1:6225", ] server1 = NATSD( port=4223, http_port=8223, cluster_listen="nats://127.0.0.1:6223", - routes=routes + routes=routes, ) self.server_pool.append(server1) @@ -337,7 +324,7 @@ def setUp(self): port=4224, http_port=8224, cluster_listen="nats://127.0.0.1:6224", - routes=routes + routes=routes, ) self.server_pool.append(server2) @@ -345,7 +332,7 @@ def setUp(self): port=4225, http_port=8225, cluster_listen="nats://127.0.0.1:6225", - routes=routes + routes=routes, ) self.server_pool.append(server3) @@ -364,7 +351,6 @@ def tearDown(self): class ClusteringDiscoveryAuthTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.server_pool = [] @@ -377,7 +363,7 @@ def setUp(self): password="bar", http_port=8223, cluster_listen="nats://127.0.0.1:6223", - routes=routes + routes=routes, ) self.server_pool.append(server1) @@ -387,7 +373,7 @@ def setUp(self): password="bar", http_port=8224, cluster_listen="nats://127.0.0.1:6224", - routes=routes + routes=routes, ) self.server_pool.append(server2) @@ -397,7 +383,7 @@ def setUp(self): password="bar", http_port=8225, cluster_listen="nats://127.0.0.1:6225", - routes=routes + routes=routes, ) self.server_pool.append(server3) @@ -414,7 +400,6 @@ def tearDown(self): class NkeysServerTestCase(unittest.TestCase): - def setUp(self): self.server_pool = [] self.loop = asyncio.new_event_loop() @@ -433,15 +418,13 @@ def tearDown(self): class TrustedServerTestCase(unittest.TestCase): - def setUp(self): super().setUp() self.server_pool = [] self.loop = asyncio.new_event_loop() server = NATSD( - port=4222, - config_file=(get_config_file("nkeys/resolver_preload.conf")) + port=4222, config_file=(get_config_file("nkeys/resolver_preload.conf")) ) self.server_pool.append(server) for natsd in self.server_pool: @@ -454,7 +437,6 @@ def tearDown(self): class SingleJetStreamServerTestCase(unittest.TestCase): - def setUp(self): self.server_pool = [] self.loop = asyncio.new_event_loop() @@ -472,7 +454,6 @@ def tearDown(self): class SingleWebSocketServerTestCase(unittest.TestCase): - def setUp(self): self.server_pool = [] self.loop = asyncio.new_event_loop() @@ -489,24 +470,19 @@ def tearDown(self): class SingleWebSocketTLSServerTestCase(unittest.TestCase): - def setUp(self): self.server_pool = [] self.loop = asyncio.new_event_loop() - self.ssl_ctx = ssl.create_default_context( - purpose=ssl.Purpose.SERVER_AUTH - ) - self.ssl_ctx.load_verify_locations(get_config_file('certs/ca.pem')) + self.ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) + self.ssl_ctx.load_verify_locations(get_config_file("certs/ca.pem")) self.ssl_ctx.load_cert_chain( - certfile=get_config_file('certs/client-cert.pem'), - keyfile=get_config_file('certs/client-key.pem') + certfile=get_config_file("certs/client-cert.pem"), + keyfile=get_config_file("certs/client-key.pem"), ) server = NATSD( - port=4222, - tls=True, - config_file=get_config_file("conf/ws_tls.conf") + port=4222, tls=True, config_file=get_config_file("conf/ws_tls.conf") ) self.server_pool.append(server) for natsd in self.server_pool: @@ -519,7 +495,6 @@ def tearDown(self): class SingleJetStreamServerLimitsTestCase(unittest.TestCase): - def setUp(self): self.server_pool = [] self.loop = asyncio.new_event_loop() @@ -527,7 +502,7 @@ def setUp(self): server = NATSD( port=4222, with_jetstream=True, - config_file=get_config_file("conf/js-limits.conf") + config_file=get_config_file("conf/js-limits.conf"), ) self.server_pool.append(server) for natsd in self.server_pool: @@ -541,14 +516,11 @@ def tearDown(self): class NoAuthUserServerTestCase(unittest.TestCase): - def setUp(self): self.server_pool = [] self.loop = asyncio.new_event_loop() - server = NATSD( - port=4555, config_file=get_config_file("conf/no_auth_user.conf") - ) + server = NATSD(port=4555, config_file=get_config_file("conf/no_auth_user.conf")) self.server_pool.append(server) for natsd in self.server_pool: start_natsd(natsd) @@ -562,7 +534,7 @@ def tearDown(self): def start_natsd(natsd: NATSD): natsd.start() - endpoint = f'127.0.0.1:{natsd.http_port}' + endpoint = f"127.0.0.1:{natsd.http_port}" retries = 0 while True: if retries > 100: @@ -570,7 +542,7 @@ def start_natsd(natsd: NATSD): try: httpclient = http.client.HTTPConnection(endpoint, timeout=5) - httpclient.request('GET', '/varz') + httpclient.request("GET", "/varz") response = httpclient.getresponse() if response.status == 200: break @@ -584,7 +556,6 @@ def get_config_file(file_path): def async_test(test_case_fun, timeout=5): - @wraps(test_case_fun) def wrapper(test_case, *args, **kw): asyncio.set_event_loop(test_case.loop) @@ -596,7 +567,6 @@ def wrapper(test_case, *args, **kw): def async_long_test(test_case_fun, timeout=20): - @wraps(test_case_fun) def wrapper(test_case, *args, **kw): asyncio.set_event_loop(test_case.loop) @@ -606,8 +576,8 @@ def wrapper(test_case, *args, **kw): return wrapper -def async_debug_test(test_case_fun, timeout=3600): +def async_debug_test(test_case_fun, timeout=3600): @wraps(test_case_fun) def wrapper(test_case, *args, **kw): asyncio.set_event_loop(test_case.loop)