diff --git a/python/langsmith/pytest_plugin.py b/python/langsmith/pytest_plugin.py index 7ddbdd0f4..5abeadc0a 100644 --- a/python/langsmith/pytest_plugin.py +++ b/python/langsmith/pytest_plugin.py @@ -1,14 +1,23 @@ """LangSmith Pytest hooks.""" +import importlib.util +import json +import time +from threading import Lock + import pytest from langsmith.testing._internal import test as ls_test -def pytest_configure(config): - """Register the 'langsmith' marker.""" - config.addinivalue_line( - "markers", "langsmith: mark test to be tracked in LangSmith" +def pytest_addoption(parser): + """Set CLI options.""" + group = parser.getgroup("custom-output") + group.addoption( + "--use-custom-output", + action="store_true", + default=False, + help="Enable custom JSON output instead of the default terminal output.", ) @@ -23,5 +32,179 @@ def pytest_runtest_call(item): # Wrap the test function with our test decorator original_func = item.obj item.obj = ls_test(**kwargs)(original_func) - + request_obj = getattr(item, "_request", None) + if request_obj is not None and "request" not in item.funcargs: + item.funcargs["request"] = request_obj + item._fixtureinfo.argnames += ("request",) yield + + +def pytest_sessionstart(session): + """Conditionally remove the terminalreporter plugin.""" + if session.config.getoption("--use-custom-output"): + tr = session.config.pluginmanager.get_plugin("terminalreporter") + if tr: + session.config.pluginmanager.unregister(tr) + + +class CustomOutputPlugin: + """Plugin for rendering LangSmith results.""" + + def __init__(self): + """Initialize.""" + from rich.console import Console + + self.process_status = {} # Track process status + self.status_lock = Lock() # Thread-safe updates + self.console = Console() + self.live = None + self.test_count = 0 + + def update_process_status(self, process_id, status): + """Update test results.""" + with self.status_lock: + current_status = self.process_status.get(process_id, {}) + if status.get("feedback"): + current_status["feedback"] = { + **current_status.get("feedback", {}), + **status.pop("feedback"), + } + if status.get("inputs"): + current_status["inputs"] = { + **current_status.get("inputs", {}), + **status.pop("inputs"), + } + if status.get("reference_outputs"): + current_status["reference_outputs"] = { + **current_status.get("reference_outputs", {}), + **status.pop("reference_outputs"), + } + if status.get("outputs"): + current_status["outputs"] = { + **current_status.get("outputs", {}), + **status.pop("outputs"), + } + self.process_status[process_id] = {**current_status, **status} + if self.live: + self.live.update(self.generate_table()) + + def pytest_collection_modifyitems(self, items): + """Get total test count for progress tracking.""" + self.test_count = len(items) + + def pytest_runtest_logstart(self, nodeid): + """Initialize live display when first test starts.""" + from rich.live import Live + + if not self.live: + self.live = Live(self.generate_table(), refresh_per_second=4) + self.live.start() + self.update_process_status( + nodeid, {"status": "running", "start_time": time.time()} + ) + + def generate_table(self): + """Generate results table.""" + from rich.table import Table + + table = Table() + table.add_column("Test") + table.add_column("Inputs") + table.add_column("Ref outputs") + table.add_column("Outputs") + table.add_column("Status") + table.add_column("Feedback") + table.add_column("Duration") + + # Test, inputs, ref outputs, outputs col width + max_status = len("status") + max_feedback = len("feedback") + 
max_duration = len("duration") + for pid, status in self.process_status.items(): + duration = status.get("end_time", time.time()) - status["start_time"] + feedback = "\n".join( + f"{k}: {v}" for k, v in status.get("feedback", {}).items() + ) + max_duration = max(len(f"{duration:.2f}s"), max_duration) + max_status = max(len(status["status"]), max_status) + max_feedback = max(len(feedback), max_feedback) + + max_dynamic_col_width = ( + self.console.width - (max_status + max_feedback + max_duration) + ) // 4 + + for pid, status in self.process_status.items(): + status_color = { + "running": "yellow", + "passed": "green", + "failed": "red", + "skipped": "cyan", + }.get(status["status"], "white") + + duration = status.get("end_time", time.time()) - status["start_time"] + feedback = "\n".join( + f"{k}: {v}" for k, v in status.get("feedback", {}).items() + ) + inputs = json.dumps(status.get("inputs", {})) + reference_outputs = json.dumps(status.get("reference_outputs", {})) + outputs = json.dumps(status.get("outputs", {})) + table.add_row( + _abbreviate_test_name(str(pid), max_len=max_dynamic_col_width), + _abbreviate(inputs, max_len=max_dynamic_col_width), + _abbreviate(reference_outputs, max_len=max_dynamic_col_width), + _abbreviate(outputs, max_len=max_dynamic_col_width), + f"[{status_color}]{status['status']}[/{status_color}]", + feedback, + f"{duration:.2f}s", + ) + return table + + def pytest_configure(self, config): + """Disable warning reporting and show no warnings in output.""" + # Disable general warning reporting + config.option.showwarnings = False + + # Disable warning summary + reporter = config.pluginmanager.get_plugin("warnings-plugin") + if reporter: + reporter.warning_summary = lambda *args, **kwargs: None + + # def pytest_runtest_logreport(self, report): + # if hasattr(report, "warnings"): + # report.warnings = [] + # pass + + +def pytest_configure(config): + """Register the 'langsmith' marker.""" + config.addinivalue_line( + "markers", "langsmith: mark test to be tracked in LangSmith" + ) + if config.getoption("--use-custom-output"): + if not importlib.util.find_spec("rich"): + msg = ( + "Must have 'rich' installed to use --use-custom-output. Please install " + "with: `pip install -U 'langsmith[pytest]'`" + ) + raise ValueError(msg) + config.pluginmanager.register(CustomOutputPlugin(), "custom_output_plugin") + # Suppress warnings summary + config.option.showwarnings = False + + +def _abbreviate(x: str, max_len: int) -> str: + if len(x) > max_len: + return x[: max_len - 3] + "..." + else: + return x + + +def _abbreviate_test_name(test_name: str, max_len: int) -> str: + if len(test_name) > max_len: + file, test = test_name.split("::") + if len(".py::" + test) > max_len: + return "..." + test[-(max_len - 3) :] + file_len = max_len - len("...::" + test) + return "..." 
+ file[-file_len:] + "::" + test + else: + return test_name diff --git a/python/langsmith/testing/_internal.py b/python/langsmith/testing/_internal.py index 777d4d9ec..269f3e1fa 100644 --- a/python/langsmith/testing/_internal.py +++ b/python/langsmith/testing/_internal.py @@ -10,6 +10,7 @@ import logging import os import threading +import time import uuid import warnings from collections import defaultdict @@ -296,10 +297,16 @@ async def async_wrapper(*test_args: Any, **test_kwargs: Any): return async_wrapper @functools.wraps(func) - def wrapper(*test_args: Any, **test_kwargs: Any): + def wrapper(*test_args: Any, request: Any, **test_kwargs: Any): if disable_tracking: return func(*test_args, **test_kwargs) - _run_test(func, *test_args, **test_kwargs, langtest_extra=langtest_extra) + _run_test( + func, + *test_args, + request=request, + **test_kwargs, + langtest_extra=langtest_extra, + ) return wrapper @@ -383,16 +390,9 @@ def _start_experiment( return client.read_project(project_name=experiment_name) -# Track the number of times a parameter has been used in a test -# This is to ensure that we can uniquely identify each test case -# defined using pytest.mark.parametrize -_param_dict: dict = defaultdict(lambda: defaultdict(int)) - - def _get_example_id( func: Callable, inputs: dict, suite_id: uuid.UUID ) -> Tuple[uuid.UUID, str]: - # global _param_dict try: file_path = str(Path(inspect.getfile(func)).relative_to(Path.cwd())) except ValueError: @@ -407,9 +407,7 @@ def _get_example_id( return uuid.uuid5(uuid.NAMESPACE_DNS, identifier), identifier[len(str(suite_id)) :] -def _end_tests( - test_suite: _LangSmithTestSuite, -): +def _end_tests(test_suite: _LangSmithTestSuite): git_info = ls_env.get_git_info() or {} test_suite.client.update_project( test_suite.experiment_id, @@ -495,12 +493,29 @@ def get_version(self) -> Optional[datetime.datetime]: return self._version def submit_result( - self, run_id: uuid.UUID, error: Optional[str] = None, skipped: bool = False + self, + run_id: uuid.UUID, + error: Optional[str] = None, + skipped: bool = False, + plugin=None, + nodeid=None, ) -> None: - self._executor.submit(self._submit_result, run_id, error, skipped=skipped) + self._executor.submit( + self._submit_result, + run_id, + error, + skipped=skipped, + plugin=plugin, + nodeid=nodeid, + ) def _submit_result( - self, run_id: uuid.UUID, error: Optional[str] = None, skipped: bool = False + self, + run_id: uuid.UUID, + error: Optional[str] = None, + skipped: bool = False, + plugin=None, + nodeid=None, ) -> None: if error: if skipped: @@ -511,16 +526,21 @@ def _submit_result( score=None, comment=f"Skipped: {repr(error)}", ) + status = "skipped" else: self.client.create_feedback( run_id, key="pass", score=0, comment=f"Error: {repr(error)}" ) + status = "failed" else: self.client.create_feedback( run_id, key="pass", score=1, ) + status = "passed" + if plugin and nodeid: + plugin.update_process_status(nodeid, {"status": status}) def sync_example( self, @@ -529,6 +549,8 @@ def sync_example( inputs: Optional[dict] = None, outputs: Optional[dict] = None, metadata: Optional[dict] = None, + plugin=None, + nodeid=None, ) -> None: future = self._executor.submit( self._sync_example, @@ -536,6 +558,8 @@ def sync_example( inputs, outputs, metadata.copy() if metadata else metadata, + plugin, + nodeid, ) with self._lock: self._example_futures[example_id].append(future) @@ -546,6 +570,8 @@ def _sync_example( inputs: Optional[dict], outputs: Optional[dict], metadata: Optional[dict], + plugin: Any, + nodeid: Any, ) -> None: 
inputs_ = _serde_example_values(inputs) if inputs else inputs outputs_ = _serde_example_values(outputs) if outputs else outputs @@ -576,6 +602,11 @@ def _sync_example( if example.modified_at: self.update_version(example.modified_at) + if plugin and nodeid: + update = {"inputs": inputs, "reference_outputs": outputs} + update = {k: v for k, v in update.items() if v is not None} + plugin.update_process_status(nodeid, update) + def _submit_feedback( self, run_id: ID_TYPE, feedback: Union[dict, list], **kwargs: Any ): @@ -585,9 +616,14 @@ def _submit_feedback( self._create_feedback, run_id=run_id, feedback=fb, **kwargs ) - def _create_feedback(self, run_id: ID_TYPE, feedback: dict, **kwargs: Any) -> None: + def _create_feedback( + self, run_id: ID_TYPE, feedback: dict, plugin=None, nodeid=None, **kwargs: Any + ) -> None: trace_id = self.client.read_run(run_id).trace_id self.client.create_feedback(trace_id, **feedback, **kwargs) + if plugin and nodeid: + val = feedback["score"] if "score" in feedback else feedback["val"] + plugin.update_process_status(nodeid, {"feedback": {feedback["key"]: val}}) def shutdown(self): self._executor.shutdown(wait=True) @@ -614,14 +650,37 @@ def _end_run(self, run_tree, example_id, outputs) -> None: class _TestCase: - def __init__(self, test_suite: _LangSmithTestSuite, example_id: uuid.UUID) -> None: + def __init__( + self, + test_suite: _LangSmithTestSuite, + example_id: uuid.UUID, + plugin=None, + nodeid=None, + ) -> None: self.test_suite = test_suite self.example_id = example_id + self.plugin = plugin + self.nodeid = nodeid def sync_example( self, *, inputs: Optional[dict] = None, outputs: Optional[dict] = None ) -> None: - self.test_suite.sync_example(self.example_id, inputs=inputs, outputs=outputs) + self.test_suite.sync_example( + self.example_id, + inputs=inputs, + outputs=outputs, + plugin=self.plugin, + nodeid=self.nodeid, + ) + + def submit_feedback(self, *args, **kwargs: Any): + self.test_suite._submit_feedback( + *args, **kwargs, plugin=self.plugin, nodeid=self.nodeid + ) + + def log_outputs(self, outputs: dict) -> None: + if self.plugin and self.nodeid: + self.plugin.update_process_status(self.nodeid, {"outputs": outputs}) _TEST_CASE = contextvars.ContextVar[Optional[_TestCase]]("_TEST_CASE", default=None) @@ -669,13 +728,21 @@ def _ensure_example( def _run_test( - func: Callable, *test_args: Any, langtest_extra: _UTExtra, **test_kwargs: Any + func: Callable, + *test_args: Any, + request: Any, + langtest_extra: _UTExtra, + **test_kwargs: Any, ) -> None: test_suite, example_id = _ensure_example( - func, *test_args, **test_kwargs, langtest_extra=langtest_extra + func, *test_args, **test_kwargs, langtest_extra=langtest_extra, request=request + ) + plugin = request.config.pluginmanager.get_plugin("custom_output_plugin") + _TEST_CASE.set( + _TestCase(test_suite, example_id, plugin=plugin, nodeid=request.node.nodeid) ) - _TEST_CASE.set(_TestCase(test_suite, example_id)) run_id = uuid.uuid4() + # Get the plugin instance def _test(): func_inputs = rh._get_inputs_safe( @@ -690,15 +757,27 @@ def _test(): project_name=test_suite.name, exceptions_to_handle=(SkipException,), ) as run_tree: + if plugin and request.node.nodeid: + plugin.update_process_status( + request.node.nodeid, {"start_time": time.time()} + ) try: result = func(*test_args, **test_kwargs) except SkipException as e: - test_suite.submit_result(run_id, error=repr(e), skipped=True) + test_suite.submit_result( + run_id, + error=repr(e), + skipped=True, + plugin=plugin, + nodeid=request.node.nodeid, + ) 
outputs = {"skipped_reason": repr(e)} test_suite.end_run(run_tree, example_id, outputs) raise e except BaseException as e: - test_suite.submit_result(run_id, error=repr(e)) + test_suite.submit_result( + run_id, error=repr(e), plugin=plugin, nodeid=request.node.nodeid + ) raise e else: outputs = ( @@ -707,8 +786,15 @@ def _test(): else {"output": result} ) test_suite.end_run(run_tree, example_id, outputs) + finally: + if plugin and request.node.nodeid: + plugin.update_process_status( + request.node.nodeid, {"end_time": time.time()} + ) try: - test_suite.submit_result(run_id, error=None) + test_suite.submit_result( + run_id, error=None, plugin=plugin, nodeid=request.node.nodeid + ) except BaseException as e: logger.warning(f"Failed to create feedback for run_id {run_id}: {e}") @@ -854,7 +940,8 @@ def log_outputs(outputs: dict, /) -> None: ... assert result == 2 """ run_tree = rh.get_current_run_tree() - if not run_tree: + test_case = _TEST_CASE.get() + if not run_tree or not test_case: msg = ( "log_outputs should only be called within a pytest test decorated with " "@pytest.mark.langsmith, and with tracing enabled (by setting the " @@ -862,6 +949,7 @@ def log_outputs(outputs: dict, /) -> None: ) raise ValueError(msg) run_tree.add_outputs(outputs) + test_case.log_outputs(outputs) @warn_beta @@ -961,9 +1049,7 @@ def log_feedback( kwargs["source_run_id"] = run_tree.id else: run_id = run_tree.trace_id - test_case.test_suite._submit_feedback( - run_id, cast(Union[list, dict], feedback), **kwargs - ) + test_case.submit_feedback(run_id, cast(Union[list, dict], feedback), **kwargs) @warn_beta diff --git a/python/poetry.lock b/python/poetry.lock index a31c97af7..beb04add0 100644 --- a/python/poetry.lock +++ b/python/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "annotated-types" @@ -711,6 +711,30 @@ files = [ {file = "langsmith_pyo3-0.1.0rc2.tar.gz", hash = "sha256:30eb26aa33deca44eb9210b77d478ec2157a0cb51f96da30f87072dd5912e3ed"}, ] +[[package]] +name = "markdown-it-py" +version = "3.0.0" +description = "Python port of markdown-it. Markdown parsing, done right!" 
+optional = true +python-versions = ">=3.8" +files = [ + {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, + {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, +] + +[package.dependencies] +mdurl = ">=0.1,<1.0" + +[package.extras] +benchmarking = ["psutil", "pytest", "pytest-benchmark"] +code-style = ["pre-commit (>=3.0,<4.0)"] +compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"] +linkify = ["linkify-it-py (>=1,<3)"] +plugins = ["mdit-py-plugins"] +profiling = ["gprof2dot"] +rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] + [[package]] name = "marshmallow" version = "3.22.0" @@ -730,6 +754,17 @@ dev = ["marshmallow[tests]", "pre-commit (>=3.5,<4.0)", "tox"] docs = ["alabaster (==1.0.0)", "autodocsumm (==0.2.13)", "sphinx (==8.0.2)", "sphinx-issues (==4.1.0)", "sphinx-version-warning (==1.1.2)"] tests = ["pytest", "pytz", "simplejson"] +[[package]] +name = "mdurl" +version = "0.1.2" +description = "Markdown URL utilities" +optional = true +python-versions = ">=3.7" +files = [ + {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, + {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, +] + [[package]] name = "multidict" version = "6.1.0" @@ -1411,6 +1446,20 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pygments" +version = "2.19.1" +description = "Pygments is a syntax highlighting package written in Python." 
+optional = true +python-versions = ">=3.8" +files = [ + {file = "pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c"}, + {file = "pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f"}, +] + +[package.extras] +windows-terminal = ["colorama (>=0.4.6)"] + [[package]] name = "pyperf" version = "2.8.0" @@ -1676,6 +1725,25 @@ files = [ [package.dependencies] requests = ">=2.0.1,<3.0.0" +[[package]] +name = "rich" +version = "13.9.4" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90"}, + {file = "rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098"}, +] + +[package.dependencies] +markdown-it-py = ">=2.2.0" +pygments = ">=2.13.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""} + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<9)"] + [[package]] name = "ruff" version = "0.6.9" @@ -2311,9 +2379,10 @@ cffi = ["cffi (>=1.11)"] [extras] compression = ["zstandard"] langsmith-pyo3 = ["langsmith-pyo3"] +pytest = ["rich"] vcr = [] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "19b288e10f9d6c040798efb74ae2778103abdc87c298b8c3eb21f7685db56642" +content-hash = "266ff8b75a6f5faa9e334766de1fc130ad66bbae65d2fcda1dd346d73850c9b5" diff --git a/python/pyproject.toml b/python/pyproject.toml index b6e1a2dc3..d59066833 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -39,6 +39,7 @@ requests-toolbelt = "^1.0.0" langsmith-pyo3 = { version = "^0.1.0rc2", optional = true } # Enabled via `compression` extra: `pip install langsmith[compression]`. zstandard = { version = "^0.23.0", optional = true } +rich = {version = "^13.9.4", optional = true} [tool.poetry.group.dev.dependencies] pytest = "^7.3.1" @@ -78,6 +79,7 @@ pytest-socket = "^0.7.0" vcr = ["vcrpy"] langsmith_pyo3 = ["langsmith-pyo3"] compression = ["zstandard"] +pytest = ["rich"] [build-system] requires = ["poetry-core"]
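
Example usage of the new output mode (a minimal sketch; the test module name, values, and feedback key below are illustrative). It assumes the extra is installed with pip install -U 'langsmith[pytest]' so that rich is available, that LangSmith tracing is enabled (e.g. LANGSMITH_TRACING=true with an API key set), and that the public langsmith.testing module exposes the log_* helpers whose internals are touched in _internal.py above.

# tests/test_addition.py -- hypothetical test file for illustration only
import pytest

from langsmith import testing as t


@pytest.mark.langsmith  # wrapped by the plugin's pytest_runtest_call hook
def test_addition():
    t.log_inputs({"a": 2, "b": 3})        # feeds the "Inputs" column of the table
    t.log_reference_outputs({"sum": 5})   # feeds the "Ref outputs" column
    result = 2 + 3
    t.log_outputs({"sum": result})        # feeds the "Outputs" column
    t.log_feedback(key="correct", score=int(result == 5))  # "Feedback" column
    assert result == 5

Running pytest --use-custom-output tests/test_addition.py unregisters the default terminalreporter and renders each collected test as a row in the live rich table, with the pass/failed/skipped status coming from submit_result and the duration computed from the start/end timestamps recorded in _run_test.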