Skip to content

Commit

Permalink
Implement streaming for xml output parser
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Dec 21, 2023
1 parent 320c3ae commit 35094f6
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 3 deletions.
80 changes: 77 additions & 3 deletions libs/langchain/langchain/output_parsers/xml.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import re
import xml.etree.ElementTree as ET
from typing import Any, Dict, List, Optional
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

from langchain_core.output_parsers import BaseOutputParser
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
from langchain_core.runnables.utils import AddableDict

from langchain.output_parsers.format_instructions import XML_FORMAT_INSTRUCTIONS


class XMLOutputParser(BaseOutputParser):
class XMLOutputParser(BaseTransformOutputParser):
"""Parse an output using xml format."""

tags: Optional[List[str]] = None
Expand All @@ -33,6 +35,70 @@ def parse(self, text: str) -> Dict[str, List[Any]]:
else:
raise ValueError(f"Could not parse output: {text}")

def _transform(
    self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict[str, Any]]:
    """Stream-parse XML chunks, yielding elements as their end tags arrive.

    Each incoming chunk is fed to an ``ET.XMLPullParser``. When an element
    closes and no child of it has closed before it, the element is wrapped
    by ``nested_element`` under the tags of its still-open ancestors and
    yielded. Elements that had a closed child are not yielded themselves.

    Args:
        input: Iterator of text chunks or messages. A message whose
            ``content`` is not a ``str`` is skipped.

    Yields:
        The value returned by ``nested_element`` for each yielded element.
    """
    pull_parser = ET.XMLPullParser(["start", "end"])
    open_tags: List[str] = []
    # Becomes True once any element has closed; reset whenever a new
    # element opens, so it is only False for elements without children.
    child_closed = False
    for piece in input:
        if isinstance(piece, BaseMessage):
            content = piece.content
            if not isinstance(content, str):
                # nothing textual to feed for this message
                continue
            text = content
        else:
            text = piece
        pull_parser.feed(text)
        # drain every event the parser produced for this chunk
        for kind, node in pull_parser.read_events():
            if kind == "start":
                open_tags.append(node.tag)
                child_closed = False
                continue
            if kind != "end":
                continue
            # drop the closing element's own tag; what remains are ancestors
            open_tags.pop()
            if not child_closed:
                yield nested_element(open_tags, node)
            # the enclosing element now has a closed child: do not yield it
            child_closed = True
    # input exhausted: finalize the parser
    pull_parser.close()

async def _atransform(
    self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict[str, Any]]:
    """Asynchronously stream-parse XML chunks, yielding closed elements.

    Each chunk from the async iterator is fed to an ``ET.XMLPullParser``.
    When an element closes and no child of it has closed before it, the
    element is wrapped by ``nested_element`` under the tags of its
    still-open ancestors and yielded. Elements that had a closed child are
    not yielded themselves.

    Args:
        input: Async iterator of text chunks or messages. A message whose
            ``content`` is not a ``str`` is skipped.

    Yields:
        The value returned by ``nested_element`` for each yielded element.
    """
    pull_parser = ET.XMLPullParser(["start", "end"])
    open_tags: List[str] = []
    # Becomes True once any element has closed; reset whenever a new
    # element opens, so it is only False for elements without children.
    child_closed = False
    async for piece in input:
        if isinstance(piece, BaseMessage):
            content = piece.content
            if not isinstance(content, str):
                # nothing textual to feed for this message
                continue
            text = content
        else:
            text = piece
        pull_parser.feed(text)
        # drain every event the parser produced for this chunk
        for kind, node in pull_parser.read_events():
            if kind == "start":
                open_tags.append(node.tag)
                child_closed = False
                continue
            if kind != "end":
                continue
            # drop the closing element's own tag; what remains are ancestors
            open_tags.pop()
            if not child_closed:
                yield nested_element(open_tags, node)
            # the enclosing element now has a closed child: do not yield it
            child_closed = True
    # input exhausted: finalize the parser
    pull_parser.close()

def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
"""Converts xml tree to python dictionary."""
result: Dict[str, List[Any]] = {root.tag: []}
Expand All @@ -46,3 +112,11 @@ def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
@property
def _type(self) -> str:
return "xml"


def nested_element(path: List[str], elem: ET.Element) -> Any:
    """Wrap an element's tag/text pair inside one dict level per path entry.

    With an empty ``path`` the result is ``AddableDict({elem.tag: elem.text})``.
    Otherwise each entry of ``path``, from last to first, wraps the result so
    far as ``AddableDict({entry: [inner]})``, so ``path[0]`` is the outermost
    key.

    Args:
        path: Tags to nest the element under, outermost first.
        elem: Element whose ``tag`` and ``text`` form the innermost mapping.

    Returns:
        The nested ``AddableDict`` structure.
    """
    wrapped = AddableDict({elem.tag: elem.text})
    # build from the inside out: the last path entry is the closest wrapper
    for tag in reversed(path):
        wrapped = AddableDict({tag: [wrapped]})
    return wrapped
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def test_xml_output_parser(result: str) -> None:

xml_result = xml_parser.parse(result)
assert DEF_RESULT_EXPECTED == xml_result
assert list(xml_parser.transform(result)) == [
{"foo": [{"bar": [{"baz": None}]}]},
{"foo": [{"bar": [{"baz": "slim.shady"}]}]},
{"foo": [{"baz": "tag"}]},
]


@pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo", "foo></foo", "foofoo"])
Expand Down

0 comments on commit 35094f6

Please sign in to comment.