# Compress create feedback (#1452)
### Purpose
Compress feedback along with runs using the same compression buffer and
background thread

### Changes
* Renamed the compressed-runs variables to compressed traces
* Compress feedback along with runs (a minimal sketch of the shared-buffer pattern follows this list)
* Bumped `num_known_refs` by one to account for the compression background thread
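
A minimal sketch of the shared-buffer pattern the change builds on (assuming the `zstandard` package; `SharedCompressedBuffer` and `add_part` are illustrative names, not the SDK's real API):

```python
import io
import threading

import zstandard


class SharedCompressedBuffer:
    """One buffer, one zstd stream, many producers (runs and feedback)."""

    def __init__(self, level: int = 3):
        self.buffer = io.BytesIO()
        self.trace_count = 0
        self.lock = threading.Lock()
        # closefd=False keeps the BytesIO usable after the writer closes.
        self.writer = zstandard.ZstdCompressor(level=level).stream_writer(
            self.buffer, closefd=False
        )

    def add_part(self, payload: bytes) -> None:
        # Run and feedback producers both funnel through here; the lock
        # serializes writes into the single compression stream.
        with self.lock:
            self.writer.write(payload)
            self.trace_count += 1
```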

### Tests
Tested locally to ensure feedback ingest works with and without
compression. Added unit tests for feedback ingest with compression.
angus-langchain authored Jan 23, 2025
1 parent 67fa09b commit c0f8ee7
Showing 5 changed files with 182 additions and 69 deletions.
**python/langsmith/_internal/_background_thread.py** (33 additions, 29 deletions)
```diff
@@ -21,7 +21,7 @@
 
 from langsmith import schemas as ls_schemas
 from langsmith import utils as ls_utils
-from langsmith._internal._compressed_runs import CompressedRuns
+from langsmith._internal._compressed_traces import CompressedTraces
 from langsmith._internal._constants import (
     _AUTO_SCALE_DOWN_NEMPTY_TRIGGER,
     _AUTO_SCALE_UP_NTHREADS_LIMIT,
@@ -102,13 +102,13 @@ def _tracing_thread_drain_queue(
 def _tracing_thread_drain_compressed_buffer(
     client: Client, size_limit: int = 100, size_limit_bytes: int | None = 20_971_520
 ) -> Tuple[Optional[io.BytesIO], Optional[Tuple[int, int]]]:
-    if client.compressed_runs is None:
+    if client.compressed_traces is None:
         return None, None
-    with client.compressed_runs.lock:
-        client.compressed_runs.compressor_writer.flush()
-        current_size = client.compressed_runs.buffer.tell()
+    with client.compressed_traces.lock:
+        client.compressed_traces.compressor_writer.flush()
+        current_size = client.compressed_traces.buffer.tell()
 
-        pre_compressed_size = client.compressed_runs.uncompressed_size
+        pre_compressed_size = client.compressed_traces.uncompressed_size
 
         if size_limit is not None and size_limit <= 0:
             raise ValueError(f"size_limit must be positive; got {size_limit}")
@@ -118,22 +118,24 @@ def _tracing_thread_drain_compressed_buffer(
         )
 
         if (size_limit_bytes is None or current_size < size_limit_bytes) and (
-            size_limit is None or client.compressed_runs.run_count < size_limit
+            size_limit is None or client.compressed_traces.trace_count < size_limit
         ):
             return None, None
 
         # Write final boundary and close compression stream
-        client.compressed_runs.compressor_writer.write(f"--{_BOUNDARY}--\r\n".encode())
-        client.compressed_runs.compressor_writer.close()
+        client.compressed_traces.compressor_writer.write(
+            f"--{_BOUNDARY}--\r\n".encode()
+        )
+        client.compressed_traces.compressor_writer.close()
 
-        filled_buffer = client.compressed_runs.buffer
+        filled_buffer = client.compressed_traces.buffer
 
-        compressed_runs_info = (pre_compressed_size, current_size)
+        compressed_traces_info = (pre_compressed_size, current_size)
 
-        client.compressed_runs.reset()
+        client.compressed_traces.reset()
 
         filled_buffer.seek(0)
-        return (filled_buffer, compressed_runs_info)
+        return (filled_buffer, compressed_traces_info)
 
 
 def _tracing_thread_handle_batch(
```
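
The drain path follows a finalize-and-swap pattern: terminate the multipart body, close the zstd frame, hand off the filled buffer, and reset the shared state so producers can keep writing. A hedged sketch, operating on a `SharedCompressedBuffer`-style object as sketched in the description above (`BOUNDARY` and `drain` are illustrative names):

```python
import io
from typing import Optional

import zstandard

BOUNDARY = "example-boundary"  # illustrative stand-in for the SDK's _BOUNDARY


def drain(shared: "SharedCompressedBuffer") -> Optional[io.BytesIO]:
    with shared.lock:
        shared.writer.flush()
        if shared.trace_count == 0:
            return None  # nothing buffered; skip the network round trip
        # Terminate the multipart body, then close the zstd frame.
        shared.writer.write(f"--{BOUNDARY}--\r\n".encode())
        shared.writer.close()
        filled = shared.buffer
        # Swap in a fresh buffer + compressor so producers can keep writing.
        shared.buffer = io.BytesIO()
        shared.writer = zstandard.ZstdCompressor(level=3).stream_writer(
            shared.buffer, closefd=False
        )
        shared.trace_count = 0
    filled.seek(0)
    return filled
```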
```diff
@@ -217,6 +219,10 @@ def tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None:
     scale_up_qsize_trigger: int = batch_ingest_config["scale_up_qsize_trigger"]
     use_multipart = batch_ingest_config.get("use_multipart_endpoint", False)
 
+    sub_threads: List[threading.Thread] = []
+    # 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached
+    num_known_refs = 3
+
     disable_compression = ls_utils.get_env_var("DISABLE_RUN_COMPRESSION")
     if not ls_utils.is_truish(disable_compression) and use_multipart:
         if not (client.info.instance_flags or {}).get(
@@ -228,16 +234,14 @@ def tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None:
             )
         else:
             client._futures = set()
-            client.compressed_runs = CompressedRuns()
+            client.compressed_traces = CompressedTraces()
             client._data_available_event = threading.Event()
             threading.Thread(
                 target=tracing_control_thread_func_compress_parallel,
                 args=(weakref.ref(client),),
             ).start()
 
-    sub_threads: List[threading.Thread] = []
-    # 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached
-    num_known_refs = 3
+            num_known_refs += 1
 
     def keep_thread_active() -> bool:
         # if `client.cleanup()` was called, stop thread
```
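
The moved `sub_threads`/`num_known_refs` block and the `num_known_refs += 1` exist so the control thread can tell real users apart from its own internal references. A hedged sketch of that refcount-based liveness check (names assumed; the real logic lives in `keep_thread_active`):

```python
import sys
import threading
import weakref
from typing import List


def keep_thread_active(
    client_ref: weakref.ref,
    num_known_refs: int,
    sub_threads: List[threading.Thread],
) -> bool:
    client = client_ref()
    if client is None:
        return False  # client was garbage-collected; shut the thread down
    # References beyond the known internal ones mean user code still holds
    # the client, so the background thread should keep draining.
    return sys.getrefcount(client) > num_known_refs + len(sub_threads)
```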
```diff
@@ -293,7 +297,7 @@ def tracing_control_thread_func_compress_parallel(
         return
 
     if (
-        client.compressed_runs is None
+        client.compressed_traces is None
         or client._data_available_event is None
         or client._futures is None
     ):
@@ -341,28 +345,28 @@ def keep_thread_active() -> bool:
         if triggered:
             client._data_available_event.clear()
 
-        data_stream, compressed_runs_info = _tracing_thread_drain_compressed_buffer(
-            client, size_limit, size_limit_bytes
-        )
+        data_stream, compressed_traces_info = (
+            _tracing_thread_drain_compressed_buffer
+        )(client, size_limit, size_limit_bytes)
         # If we have data, submit the send request
         if data_stream is not None:
             try:
                 future = HTTP_REQUEST_THREAD_POOL.submit(
                     client._send_compressed_multipart_req,
                     data_stream,
-                    compressed_runs_info,
+                    compressed_traces_info,
                 )
                 client._futures.add(future)
             except RuntimeError:
                 client._send_compressed_multipart_req(
                     data_stream,
-                    compressed_runs_info,
+                    compressed_traces_info,
                 )
             last_flush_time = time.monotonic()
 
         else:
             if (time.monotonic() - last_flush_time) >= flush_interval:
-                data_stream, compressed_runs_info = (
+                data_stream, compressed_traces_info = (
                     _tracing_thread_drain_compressed_buffer(
                         client, size_limit=1, size_limit_bytes=1
                     )
@@ -374,20 +378,20 @@ def keep_thread_active() -> bool:
                             HTTP_REQUEST_THREAD_POOL.submit(
                                 client._send_compressed_multipart_req,
                                 data_stream,
-                                compressed_runs_info,
+                                compressed_traces_info,
                             )
                         ]
                     )
                 except RuntimeError:
                     client._send_compressed_multipart_req(
                         data_stream,
-                        compressed_runs_info,
+                        compressed_traces_info,
                     )
                 last_flush_time = time.monotonic()
 
     # Drain the buffer on exit (final flush)
     try:
-        final_data_stream, compressed_runs_info = (
+        final_data_stream, compressed_traces_info = (
             _tracing_thread_drain_compressed_buffer(
                 client, size_limit=1, size_limit_bytes=1
             )
@@ -399,14 +403,14 @@ def keep_thread_active() -> bool:
                     HTTP_REQUEST_THREAD_POOL.submit(
                         client._send_compressed_multipart_req,
                         final_data_stream,
-                        compressed_runs_info,
+                        compressed_traces_info,
                     )
                 ]
             )
         except RuntimeError:
             client._send_compressed_multipart_req(
                 final_data_stream,
-                compressed_runs_info,
+                compressed_traces_info,
             )
 
     except Exception:
```
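
Both flush sites above share a submit-with-fallback idiom. A minimal sketch (`send` and `dispatch` are illustrative stand-ins, not the client's methods):

```python
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=4)


def send(data: bytes) -> None:
    print(f"sending {len(data)} bytes")  # stand-in for the HTTP request


def dispatch(data: bytes) -> None:
    try:
        # Normal path: hand the request off to the shared pool.
        pool.submit(send, data)
    except RuntimeError:
        # submit() raises RuntimeError once the pool has been shut down
        # (e.g. during interpreter exit); fall back to a synchronous send.
        send(data)
```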
**python/langsmith/_internal/_compressed_traces.py** (previously `_compressed_runs.py`)
```diff
@@ -8,10 +8,10 @@
 compression_level = ls_utils.get_env_var("RUN_COMPRESSION_LEVEL", 3)
 
 
-class CompressedRuns:
+class CompressedTraces:
     def __init__(self):
         self.buffer = io.BytesIO()
-        self.run_count = 0
+        self.trace_count = 0
         self.lock = threading.Lock()
         self.uncompressed_size = 0
 
@@ -21,7 +21,7 @@ def __init__(self):
     def reset(self):
         self.buffer = io.BytesIO()
-        self.run_count = 0
+        self.trace_count = 0
         self.uncompressed_size = 0
 
         self.compressor_writer = ZstdCompressor(
```
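
`reset()` rebuilds the stream writer because a zstd frame cannot be reopened once closed. A small demonstration of the underlying behavior, assuming the `zstandard` package and the level-3 default above:

```python
import io

import zstandard

buffer = io.BytesIO()
writer = zstandard.ZstdCompressor(level=3).stream_writer(buffer, closefd=False)
writer.write(b"part one")
writer.close()  # finalizes the zstd frame...
payload = buffer.getvalue()  # ...but closefd=False keeps the buffer readable
assert (
    zstandard.ZstdDecompressor().decompressobj().decompress(payload)
    == b"part one"
)
```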
**python/langsmith/_internal/_operations.py** (21 additions, 11 deletions)
```diff
@@ -5,11 +5,11 @@
 import os
 import uuid
 from io import BufferedReader
-from typing import Dict, Literal, Optional, Union, cast
+from typing import Dict, Iterable, Literal, Optional, Tuple, Union, cast
 
 from langsmith import schemas as ls_schemas
 from langsmith._internal import _orjson
-from langsmith._internal._compressed_runs import CompressedRuns
+from langsmith._internal._compressed_traces import CompressedTraces
 from langsmith._internal._multipart import MultipartPart, MultipartPartsAndContext
 from langsmith._internal._serde import dumps_json as _dumps_json
@@ -292,11 +292,10 @@ def serialized_run_operation_to_multipart_parts_and_context(
     )
 
 
-def compress_multipart_parts_and_context(
+def encode_multipart_parts_and_context(
     parts_and_context: MultipartPartsAndContext,
-    compressed_runs: CompressedRuns,
     boundary: str,
-) -> None:
+) -> Iterable[Tuple[bytes, Union[bytes, BufferedReader]]]:
     for part_name, (filename, data, content_type, headers) in parts_and_context.parts:
         header_parts = [
             f"--{boundary}\r\n",
@@ -314,18 +313,29 @@ def compress_multipart_parts_and_context(
             ]
         )
 
-        compressed_runs.compressor_writer.write("".join(header_parts).encode())
+        yield ("".join(header_parts).encode(), data)
+
+
+def compress_multipart_parts_and_context(
+    parts_and_context: MultipartPartsAndContext,
+    compressed_traces: CompressedTraces,
+    boundary: str,
+) -> None:
+    for headers, data in encode_multipart_parts_and_context(
+        parts_and_context, boundary
+    ):
+        compressed_traces.compressor_writer.write(headers)
 
         if isinstance(data, (bytes, bytearray)):
-            compressed_runs.uncompressed_size += len(data)
-            compressed_runs.compressor_writer.write(data)
+            compressed_traces.uncompressed_size += len(data)
+            compressed_traces.compressor_writer.write(data)
         else:
             if isinstance(data, BufferedReader):
                 encoded_data = data.read()
             else:
                 encoded_data = str(data).encode()
-            compressed_runs.uncompressed_size += len(encoded_data)
-            compressed_runs.compressor_writer.write(encoded_data)
+            compressed_traces.uncompressed_size += len(encoded_data)
+            compressed_traces.compressor_writer.write(encoded_data)
 
         # Write part terminator
-        compressed_runs.compressor_writer.write(b"\r\n")
+        compressed_traces.compressor_writer.write(b"\r\n")
```
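
Splitting the pure `encode_multipart_parts_and_context` generator from the compressing consumer is what lets feedback reuse the run path. A simplified sketch of the idea (part names and shapes are illustrative):

```python
from typing import Iterable, List, Tuple


def encode_parts(
    parts: List[Tuple[str, bytes]], boundary: str
) -> Iterable[Tuple[bytes, bytes]]:
    """Yield (header, body) pairs; no I/O, so any sink can consume them."""
    for name, data in parts:
        header = (
            f"--{boundary}\r\n"
            f'Content-Disposition: form-data; name="{name}"\r\n\r\n'
        ).encode()
        yield header, data


# One consumer compresses into the shared buffer; a test can instead
# assemble the raw multipart body directly:
body = b""
for header, data in encode_parts([("feedback.1", b"{}")], "BOUNDARY"):
    body += header + data + b"\r\n"
body += b"--BOUNDARY--\r\n"
```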