Merge pull request #453 from nats-io/aiofiles
Use executor when writing/reading files from object store
wallyqs authored Jun 3, 2023
2 parents b51f391 + 899d715 commit 533bd8f
Showing 7 changed files with 604 additions and 421 deletions.
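
The change in this commit routes blocking file reads and writes through the event loop's default thread-pool executor so the object store no longer stalls the asyncio loop on disk I/O. A rough standalone sketch of that pattern (not the object store code itself; the file paths and helper name are made up for illustration):

import asyncio

async def copy_file(src_path: str, dst_path: str, chunk_size: int = 128 * 1024) -> None:
    # run_in_executor(None, fn, *args) runs the blocking call in the default
    # ThreadPoolExecutor and suspends this coroutine until it completes.
    loop = asyncio.get_running_loop()
    with open(src_path, "rb") as src, open(dst_path, "wb") as dst:
        while True:
            chunk = await loop.run_in_executor(None, src.read, chunk_size)
            if not chunk:
                break
            await loop.run_in_executor(None, dst.write, chunk)

asyncio.run(copy_file("in.bin", "out.bin"))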
2 changes: 2 additions & 0 deletions Pipfile
@@ -10,7 +10,9 @@ pytest = "*"
pytest-cov = "*"
yapf = "*"
toml = "*" # see https://github.com/google/yapf/issues/936
exceptiongroup = "*"

[packages]
nkeys = "*"
aiohttp = "*"
fast-mail-parser = "*"
827 changes: 455 additions & 372 deletions Pipfile.lock

Large diffs are not rendered by default.

12 changes: 11 additions & 1 deletion nats/aio/client.py
@@ -1,4 +1,4 @@
-# Copyright 2016-2022 The NATS Authors
+# Copyright 2016-2023 The NATS Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -21,6 +21,7 @@
import logging
import ssl
import time
import string
from dataclasses import dataclass
from email.parser import BytesParser
from random import shuffle
@@ -1636,6 +1637,15 @@ async def _process_headers(self, headers) -> dict[str, str] | None:
hdr.update(parsed_hdr)
else:
hdr = parsed_hdr

if parse_email:
to_delete = []
for k in hdr.keys():
if any(c in k for c in string.whitespace):
to_delete.append(k)
for k in to_delete:
del hdr[k]

except Exception as e:
await self._error_cb(e)
return hdr
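
The new block in _process_headers drops only the header keys that still contain whitespace after parsing, instead of the previous behavior where one unprocessable key invalidated the whole header map (compare the updated assertion in tests/test_js.py further down). A minimal sketch of the same filtering step, using a hypothetical hdr dict:

import string

hdr = {"AAA-BBB-AAA": "b", "BAD KEY": "x"}  # hypothetical parsed headers
to_delete = [k for k in hdr if any(c in k for c in string.whitespace)]
for k in to_delete:
    del hdr[k]
print(hdr)  # {'AAA-BBB-AAA': 'b'}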
71 changes: 34 additions & 37 deletions nats/js/object_store.py
@@ -25,22 +25,14 @@
import nats.errors
from nats.js import api
from nats.js.errors import (
-BadObjectMetaError, DigestMismatchError, InvalidObjectNameError,
-ObjectAlreadyExists, ObjectDeletedError, ObjectNotFoundError,
-NotFoundError, LinkIsABucketError
+BadObjectMetaError, DigestMismatchError, ObjectAlreadyExists,
+ObjectDeletedError, ObjectNotFoundError, NotFoundError, LinkIsABucketError
)
from nats.js.kv import MSG_ROLLUP_SUBJECT

VALID_BUCKET_RE = re.compile(r"^[a-zA-Z0-9_-]+$")
VALID_KEY_RE = re.compile(r"^[-/_=\.a-zA-Z0-9]+$")


-def key_valid(key: str) -> bool:
-    if len(key) == 0 or key[0] == '.' or key[-1] == '.':
-        return False
-    return VALID_KEY_RE.match(key) is not None


if TYPE_CHECKING:
from nats.js import JetStreamContext

@@ -136,10 +128,6 @@ def __init__(
self._stream = stream
self._js = js

-def __sanitize_name(self, name: str) -> str:
-    name = name.replace(".", "_")
-    return name.replace(" ", "_")

async def get_info(
self,
name: str,
@@ -148,10 +136,7 @@ async def get_info(
"""
get_info will retrieve the current information for the object.
"""
-obj = self.__sanitize_name(name)
-
-if not key_valid(obj):
-    raise InvalidObjectNameError
+obj = name

meta = OBJ_META_PRE_TEMPLATE.format(
bucket=self._name,
@@ -185,15 +170,13 @@ async def get_info(
async def get(
self,
name: str,
writeinto: Optional[io.BufferedIOBase] = None,
show_deleted: Optional[bool] = False,
) -> ObjectResult:
"""
get will pull the object from the underlying stream.
"""
-obj = self.__sanitize_name(name)
-
-if not key_valid(obj):
-    raise InvalidObjectNameError
+obj = name

# Grab meta info.
info = await self.get_info(obj, show_deleted)
@@ -222,10 +205,22 @@ async def get(

h = sha256()

executor = None
executor_fn = None
if writeinto:
executor = asyncio.get_running_loop().run_in_executor
if hasattr(writeinto, 'buffer'):
executor_fn = writeinto.buffer.write
else:
executor_fn = writeinto.write

async for msg in sub._message_iterator:
tokens = msg._get_metadata_fields(msg.reply)

-result.data += msg.data
+if executor:
+    await executor(None, executor_fn, msg.data)
+else:
+    result.data += msg.data
h.update(msg.data)

# Check if we are done.
@@ -262,11 +257,7 @@ async def put(
max_chunk_size=OBJ_DEFAULT_CHUNK_SIZE,
)

-obj = self.__sanitize_name(meta.name)
-
-if not key_valid(obj):
-    raise InvalidObjectNameError

+obj = meta.name
einfo = None

# Create the new nuid so chunks go on a new subject if the name is re-used.
@@ -285,14 +276,18 @@
pass

# Normalize based on type but treat all as readers.
-# FIXME: Need an async based reader as well.
+executor = None
if isinstance(data, str):
data = io.BytesIO(data.encode())
elif isinstance(data, bytes):
data = io.BytesIO(data)
-elif (not isinstance(data, io.BufferedIOBase)):
-    # Only allowing buffered readers at the moment.
-    raise TypeError("nats: subtype of io.BufferedIOBase was expected")
+elif hasattr(data, 'readinto') or isinstance(data, io.BufferedIOBase):
+    # Need to delegate to a threaded executor to avoid blocking.
+    executor = asyncio.get_running_loop().run_in_executor
+elif hasattr(data, 'buffer') or isinstance(data, io.TextIOWrapper):
+    data = data.buffer
+else:
+    raise TypeError("nats: invalid type for object store")

info = api.ObjectInfo(
name=meta.name,
@@ -312,7 +307,12 @@

while True:
try:
-n = data.readinto(chunk)
+n = None
+if executor:
+    n = await executor(None, data.readinto, chunk)
+else:
+    n = data.readinto(chunk)

if n == 0:
break
payload = chunk[:n]
@@ -506,10 +506,7 @@ async def delete(self, name: str) -> ObjectResult:
"""
delete will delete the object.
"""
-obj = self.__sanitize_name(name)
-
-if not key_valid(obj):
-    raise InvalidObjectNameError
+obj = name

# Grab meta info.
info = await self.get_info(obj)
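
Taken together, put() now accepts plain (blocking) file objects and reads their chunks through the executor, while get() can stream chunks straight into a file passed as writeinto= instead of buffering the whole object in memory. A usage sketch based on the tests below, assuming a local JetStream-enabled server and hypothetical file names:

import asyncio
import nats

async def main():
    nc = await nats.connect()
    js = nc.jetstream()
    obs = await js.create_object_store(bucket="sample")

    # put: a regular blocking file object; readinto() calls are offloaded
    # to the default executor so the event loop is not blocked.
    with open("data.bin", "rb") as f:
        info = await obs.put("data", f)
        print(info.name, info.size, info.chunks)

    # get with writeinto=: each chunk is written to the file as it arrives,
    # so obr.data stays empty rather than holding the whole object.
    with open("copy.bin", "wb") as f:
        obr = await obs.get("data", writeinto=f)
        print(obr.info.digest, len(obr.data))  # len(obr.data) == 0

    await nc.close()

asyncio.run(main())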
1 change: 1 addition & 0 deletions setup.py
@@ -6,6 +6,7 @@
name="nats-py",
extras_require={
'nkeys': ['nkeys'],
'aiohttp': ['aiohttp'],
'fast_parse': ['fast-mail-parser'],
}
)
111 changes: 101 additions & 10 deletions tests/test_js.py
@@ -643,7 +643,6 @@ async def test_fetch_headers(self):
assert msg.header['AAA-AAA-AAA'] == 'a'
assert msg.header['AAA-BBB-AAA'] == ''

-# FIXME: An unprocessable key makes the rest of the header be invalid.
await js.publish(
"test.nats.1",
b'third_msg',
@@ -654,7 +653,7 @@
}
)
msgs = await sub.fetch(1)
-assert msgs[0].header == None
+assert msgs[0].header == {'AAA-BBB-AAA': 'b'}

msg = await js.get_msg("test-nats", 4)
assert msg.header == None
@@ -2875,17 +2874,21 @@ async def error_handler(e):
tmp.write(ls.encode())
tmp.close()

-with pytest.raises(TypeError):
-    with open(tmp.name) as f:
-        info = await obs.put("tmp", f)

with open(tmp.name) as f:
info = await obs.put("tmp", f.buffer)
assert info.name == "tmp"
assert info.size == 1048609
assert info.chunks == 9
assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='

with open(tmp.name) as f:
info = await obs.put("tmp2", f)
assert info.name == "tmp2"
assert info.size == 1048609
assert info.chunks == 9
assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='

# By default this reads the complete data.
obr = await obs.get("tmp")
info = obr.info
assert info.name == "tmp"
@@ -2900,9 +2903,31 @@ async def error_handler(e):
assert info.chunks == 1

# Using a local file but not as a buffered reader.
-with pytest.raises(TypeError):
-    with open("pyproject.toml") as f:
-        await obs.put("pyproject", f)
with open("pyproject.toml") as f:
info = await obs.put("pyproject2", f)
assert info.name == "pyproject2"
assert info.chunks == 1

# Write into file without getting complete data.
w = tempfile.NamedTemporaryFile(delete=False)
w.close()
with open(w.name, 'w') as f:
obr = await obs.get("tmp", writeinto=f)
assert obr.data == b''
assert obr.info.size == 1048609
assert obr.info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='

w2 = tempfile.NamedTemporaryFile(delete=False)
w2.close()
with open(w2.name, 'w') as f:
obr = await obs.get("tmp", writeinto=f.buffer)
assert obr.data == b''
assert obr.info.size == 1048609
assert obr.info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='

with open(w2.name) as f:
result = f.read(-1)
assert len(result) == 1048609

await nc.close()

@@ -2947,7 +2972,7 @@ async def error_handler(e):
with pytest.raises(nats.js.errors.NotFoundError):
await obs.get("tmp")

-with pytest.raises(nats.js.errors.InvalidObjectNameError):
+with pytest.raises(nats.js.errors.NotFoundError):
await obs.get("")

res = await js.delete_object_store(bucket="sample")
@@ -3151,6 +3176,72 @@

await nc.close()

@async_test
async def test_object_aiofiles(self):
try:
import aiofiles
except ImportError:
pytest.skip("aiofiles not installed")

errors = []

async def error_handler(e):
print("Error:", e, type(e))
errors.append(e)

nc = await nats.connect(error_cb=error_handler)
js = nc.jetstream()

# Create a ~1MB object.
obs = await js.create_object_store(bucket="big")
ls = ''.join("A" for _ in range(0, 1 * 1024 * 1024 + 33))
w = io.BytesIO(ls.encode())
info = await obs.put("big", w)
assert info.name == "big"
assert info.size == 1048609
assert info.chunks == 9
assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='

# Create actual file and put it in a bucket.
tmp = tempfile.NamedTemporaryFile(delete=False)
tmp.write(ls.encode())
tmp.close()

async with aiofiles.open(tmp.name) as f:
info = await obs.put("tmp", f)
assert info.name == "tmp"
assert info.size == 1048609
assert info.chunks == 9
assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='

async with aiofiles.open(tmp.name) as f:
info = await obs.put("tmp2", f)
assert info.name == "tmp2"
assert info.size == 1048609
assert info.chunks == 9
assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='

obr = await obs.get("tmp")
info = obr.info
assert info.name == "tmp"
assert info.size == 1048609
assert len(obr.data) == info.size  # no writeinto given, so the whole object is read into data.
assert info.chunks == 9
assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='

# Using a local file.
async with aiofiles.open("pyproject.toml") as f:
info = await obs.put("pyproject", f.buffer)
assert info.name == "pyproject"
assert info.chunks == 1

async with aiofiles.open("pyproject.toml") as f:
info = await obs.put("pyproject2", f)
assert info.name == "pyproject2"
assert info.chunks == 1

await nc.close()


class ConsumerReplicasTest(SingleJetStreamServerTestCase):

1 change: 0 additions & 1 deletion tests/utils.py
@@ -522,7 +522,6 @@ def setUp(self):
server = NATSD(
port=4555, config_file=get_config_file("conf/no_auth_user.conf")
)
-server.debug = True
self.server_pool.append(server)
for natsd in self.server_pool:
start_natsd(natsd)
