From 35094f6d27e52f8e8d401a01ebbce7113dbea9f4 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 20 Dec 2023 18:02:04 -0800 Subject: [PATCH] Implement streaming for xml output parser --- .../langchain/langchain/output_parsers/xml.py | 80 ++++++++++++++++++- .../output_parsers/test_xml_parser.py | 5 ++ 2 files changed, 82 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/output_parsers/xml.py b/libs/langchain/langchain/output_parsers/xml.py index 794d68c6b9eba..c523b4d10382a 100644 --- a/libs/langchain/langchain/output_parsers/xml.py +++ b/libs/langchain/langchain/output_parsers/xml.py @@ -1,13 +1,15 @@ import re import xml.etree.ElementTree as ET -from typing import Any, Dict, List, Optional +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union -from langchain_core.output_parsers import BaseOutputParser +from langchain_core.messages import BaseMessage +from langchain_core.output_parsers.transform import BaseTransformOutputParser +from langchain_core.runnables.utils import AddableDict from langchain.output_parsers.format_instructions import XML_FORMAT_INSTRUCTIONS -class XMLOutputParser(BaseOutputParser): +class XMLOutputParser(BaseTransformOutputParser): """Parse an output using xml format.""" tags: Optional[List[str]] = None @@ -33,6 +35,70 @@ def parse(self, text: str) -> Dict[str, List[Any]]: else: raise ValueError(f"Could not parse output: {text}") + def _transform( + self, input: Iterator[Union[str, BaseMessage]] + ) -> Iterator[AddableDict[str, Any]]: + parser = ET.XMLPullParser(["start", "end"]) + current_path: List[str] = [] + current_path_has_children = False + for chunk in input: + if isinstance(chunk, BaseMessage): + # extract text + chunk_content = chunk.content + if not isinstance(chunk_content, str): + continue + chunk = chunk_content + # pass chunk to parser + parser.feed(chunk) + # yield all events + for event, elem in parser.read_events(): + if event == "start": + # update current path + current_path.append(elem.tag) + current_path_has_children = False + elif event == "end": + # remove last element from current path + current_path.pop() + # yield element + if not current_path_has_children: + yield nested_element(current_path, elem) + # prevent yielding of parent element + current_path_has_children = True + # close parser + parser.close() + + async def _atransform( + self, input: AsyncIterator[Union[str, BaseMessage]] + ) -> AsyncIterator[AddableDict[str, Any]]: + parser = ET.XMLPullParser(["start", "end"]) + current_path: List[str] = [] + current_path_has_children = False + async for chunk in input: + if isinstance(chunk, BaseMessage): + # extract text + chunk_content = chunk.content + if not isinstance(chunk_content, str): + continue + chunk = chunk_content + # pass chunk to parser + parser.feed(chunk) + # yield all events + for event, elem in parser.read_events(): + if event == "start": + # update current path + current_path.append(elem.tag) + current_path_has_children = False + elif event == "end": + # remove last element from current path + current_path.pop() + # yield element + if not current_path_has_children: + yield nested_element(current_path, elem) + # prevent yielding of parent element + current_path_has_children = True + # close parser + parser.close() + def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]: """Converts xml tree to python dictionary.""" result: Dict[str, List[Any]] = {root.tag: []} @@ -46,3 +112,11 @@ def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]: @property def _type(self) -> str: return "xml" + + +def nested_element(path: List[str], elem: ET.Element) -> Any: + """Get nested element from path.""" + if len(path) == 0: + return AddableDict({elem.tag: elem.text}) + else: + return AddableDict({path[0]: [nested_element(path[1:], elem)]}) diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_xml_parser.py index 3830d25f8d37b..36f82746a1df9 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_xml_parser.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_xml_parser.py @@ -31,6 +31,11 @@ def test_xml_output_parser(result: str) -> None: xml_result = xml_parser.parse(result) assert DEF_RESULT_EXPECTED == xml_result + assert list(xml_parser.transform(result)) == [ + {"foo": [{"bar": [{"baz": None}]}]}, + {"foo": [{"bar": [{"baz": "slim.shady"}]}]}, + {"foo": [{"baz": "tag"}]}, + ] @pytest.mark.parametrize("result", ["foo>", "