Add unit tests for disabled compression in instance flags and env vars
angus-langchain committed Jan 13, 2025
1 parent 8c4b3c6 commit 23080d3
Showing 1 changed file with 198 additions and 0 deletions.
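For context, both new tests pin down the fallback path: the client should send plain (uncompressed) multipart bodies either when the server reports no zstd_compression_enabled instance flag, or when the user opts out via LANGSMITH_DISABLE_RUN_COMPRESSION. A minimal sketch of the gating these tests imply is below; the helper name and the exact set of accepted env-var values are assumptions for illustration, not the client's actual internals.

import os


def _compression_enabled(instance_flags: dict) -> bool:
    """Hypothetical helper (illustration only): should run payloads be zstd-compressed?"""
    # Explicit opt-out wins, regardless of what the server advertises.
    if os.getenv("LANGSMITH_DISABLE_RUN_COMPRESSION", "").lower() in ("1", "true"):
        return False
    # Otherwise, compress only if the server's instance flag says it is supported.
    return bool(instance_flags.get("zstd_compression_enabled", False))


# Scenarios exercised by the two tests in this diff:
# _compression_enabled({})                                  -> False (first test)
# With LANGSMITH_DISABLE_RUN_COMPRESSION=true set:
# _compression_enabled({"zstd_compression_enabled": True})  -> False (second test)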
python/tests/unit_tests/test_client.py (+198, -0)
@@ -26,8 +26,10 @@
import pytest
import requests
from multipart import MultipartParser, MultipartPart, parse_options_header
from requests_toolbelt import MultipartEncoder
from pydantic import BaseModel
from requests import HTTPError
import zstandard as zstd

import langsmith.env as ls_env
import langsmith.utils as ls_utils
@@ -2142,3 +2144,199 @@ def test_create_run_with_zstd_compression(mock_session_cls: mock.Mock) -> None:
"Expected the request body to start with zstd magic bytes; "
"it appears runs were not compressed."
)

@patch("langsmith.client.requests.Session")
def test_create_run_without_compression_support(mock_session_cls: mock.Mock) -> None:
"""Test that runs use regular multipart when server doesn't support compression."""
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_session.request.return_value = mock_response
mock_session_cls.return_value = mock_session

with patch.dict("os.environ", {}, clear=True):
info = ls_schemas.LangSmithInfo(
version="0.6.0",
instance_flags={}, # No compression flag
batch_ingest_config=ls_schemas.BatchIngestConfig(
use_multipart_endpoint=True,
size_limit=1,
size_limit_bytes=128,
scale_up_nthreads_limit=4,
scale_up_qsize_trigger=3,
scale_down_nempty_trigger=1,
),
)
client = Client(
api_url="http://localhost:1984",
api_key="123",
auto_batch_tracing=True,
session=mock_session,
info=info,
)

run_id = uuid.uuid4()
inputs = {"key": "there"}
client.create_run(
name="test_run",
run_type="llm",
inputs=inputs,
id=run_id,
trace_id=run_id,
dotted_order=str(run_id),
)

outputs = {"key": "hi there"}

client.update_run(
run_id,
outputs=outputs,
end_time=datetime.now(timezone.utc),
trace_id=run_id,
dotted_order=str(run_id),
)

if client.tracing_queue:
client.tracing_queue.join()

time.sleep(0.1)

post_calls = [
call_obj for call_obj in mock_session.request.mock_calls
if call_obj.args and call_obj.args[0] == "POST"
]
assert len(post_calls) >= 1

payloads = [
(call[2]["headers"], call[2]["data"])
for call in mock_session.request.mock_calls
if call.args and call.args[1].endswith("runs/multipart")
]
if not payloads:
assert False, "No payloads found"

parts: List[MultipartPart] = []
for payload in payloads:
headers, data = payload
assert headers["Content-Type"].startswith("multipart/form-data")
assert isinstance(data, bytes)
boundary = parse_options_header(headers["Content-Type"])[1]["boundary"]
parser = MultipartParser(io.BytesIO(data), boundary)
parts.extend(parser.parts())

assert [p.name for p in parts] == [
f"post.{run_id}",
f"post.{run_id}.inputs",
f"post.{run_id}.outputs",
]
assert [p.headers.get("content-type") for p in parts] == [
"application/json",
"application/json",
"application/json",
]

outputs_parsed = json.loads(parts[2].value)
assert outputs_parsed == outputs
inputs_parsed = json.loads(parts[1].value)
assert inputs_parsed == inputs
run_parsed = json.loads(parts[0].value)
assert run_parsed["trace_id"] == str(run_id)
assert run_parsed["dotted_order"] == str(run_id)

@patch("langsmith.client.requests.Session")
def test_create_run_with_disabled_compression(mock_session_cls: mock.Mock) -> None:
"""Test that runs use regular multipart when compression is explicitly disabled."""
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_session.request.return_value = mock_response
mock_session_cls.return_value = mock_session

with patch.dict("os.environ", {"LANGSMITH_DISABLE_RUN_COMPRESSION": "true"}, clear=True):
info = ls_schemas.LangSmithInfo(
version="0.6.0",
instance_flags={"zstd_compression_enabled": True}, # Enabled on server
batch_ingest_config=ls_schemas.BatchIngestConfig(
use_multipart_endpoint=True,
size_limit=1,
size_limit_bytes=128,
scale_up_nthreads_limit=4,
scale_up_qsize_trigger=3,
scale_down_nempty_trigger=1,
),
)
client = Client(
api_url="http://localhost:1984",
api_key="123",
auto_batch_tracing=True,
session=mock_session,
info=info,
)

run_id = uuid.uuid4()
inputs = {"key": "there"}
client.create_run(
name="test_run",
run_type="llm",
inputs=inputs,
id=run_id,
trace_id=run_id,
dotted_order=str(run_id),
)

outputs = {"key": "hi there"}
client.update_run(
run_id,
outputs=outputs,
end_time=datetime.now(timezone.utc),
trace_id=run_id,
dotted_order=str(run_id),
)

# Let the background threads flush
if client.tracing_queue:
client.tracing_queue.join()

time.sleep(0.1)

post_calls = [
call_obj for call_obj in mock_session.request.mock_calls
if call_obj.args and call_obj.args[0] == "POST"
]
assert len(post_calls) >= 1

payloads = [
(call[2]["headers"], call[2]["data"])
for call in mock_session.request.mock_calls
if call.args and call.args[1].endswith("runs/multipart")
]
if not payloads:
assert False, "No payloads found"

parts: List[MultipartPart] = []
for payload in payloads:
headers, data = payload
assert headers["Content-Type"].startswith("multipart/form-data")
assert isinstance(data, bytes)
boundary = parse_options_header(headers["Content-Type"])[1]["boundary"]
parser = MultipartParser(io.BytesIO(data), boundary)
parts.extend(parser.parts())

assert [p.name for p in parts] == [
f"post.{run_id}",
f"post.{run_id}.inputs",
f"post.{run_id}.outputs",
]
assert [p.headers.get("content-type") for p in parts] == [
"application/json",
"application/json",
"application/json",
]

outputs_parsed = json.loads(parts[2].value)
assert outputs_parsed == outputs
inputs_parsed = json.loads(parts[1].value)
assert inputs_parsed == inputs
run_parsed = json.loads(parts[0].value)
assert run_parsed["trace_id"] == str(run_id)
assert run_parsed["dotted_order"] == str(run_id)
