-
Notifications
You must be signed in to change notification settings - Fork 16.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move json and xml parsers to core (#15026)
<!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
- Loading branch information
Showing
10 changed files
with
841 additions
and
421 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
from __future__ import annotations | ||
|
||
import json | ||
import re | ||
from json import JSONDecodeError | ||
from typing import Any, Callable, List, Optional | ||
|
||
import jsonpatch # type: ignore[import] | ||
|
||
from langchain_core.exceptions import OutputParserException | ||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser | ||
|
||
|
||
def _replace_new_line(match: re.Match[str]) -> str: | ||
value = match.group(2) | ||
value = re.sub(r"\n", r"\\n", value) | ||
value = re.sub(r"\r", r"\\r", value) | ||
value = re.sub(r"\t", r"\\t", value) | ||
value = re.sub(r'(?<!\\)"', r"\"", value) | ||
|
||
return match.group(1) + value + match.group(3) | ||
|
||
|
||
def _custom_parser(multiline_string: str) -> str:
    """Escape unescaped control characters in an `action_input` JSON value.

    The LLM response for `action_input` may be a multiline string containing
    raw newlines, tabs or quotes; this replaces them with their escaped
    counterparts (newlines in JSON must be double-escaped: `\\n`).
    """
    # Defensively accept bytes-like input and normalize to str first.
    if isinstance(multiline_string, (bytes, bytearray)):
        multiline_string = multiline_string.decode()

    return re.sub(
        r'("action_input"\:\s*")(.*)(")',
        _replace_new_line,
        multiline_string,
        flags=re.DOTALL,
    )
|
||
|
||
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
# MIT License
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
    """Parse a JSON string that may be missing closing braces/quotes.

    Args:
        s: The (possibly truncated) JSON string to parse.
        strict: Whether to use strict parsing (disallow control characters
            inside strings). Defaults to False.

    Returns:
        The parsed JSON object, or None if the string is malformed beyond
        repair (e.g. mismatched closing brackets).
    """
    # Fast path: the string may already be complete, valid JSON.
    try:
        return json.loads(s, strict=strict)
    except json.JSONDecodeError:
        pass

    # Walk the string once, tracking open containers and string state so the
    # missing closers can be appended at the end.
    new_s = ""
    stack = []  # closing characters still owed, innermost last
    is_inside_string = False
    escaped = False

    for char in s:
        if is_inside_string:
            if char == '"' and not escaped:
                is_inside_string = False
            elif char == "\n" and not escaped:
                # Raw newlines are illegal inside JSON strings; escape them.
                char = "\\n"
            elif char == "\\":
                escaped = not escaped
            else:
                escaped = False
        else:
            if char == '"':
                is_inside_string = True
                escaped = False
            elif char == "{":
                stack.append("}")
            elif char == "[":
                stack.append("]")
            elif char == "}" or char == "]":
                if stack and stack[-1] == char:
                    stack.pop()
                else:
                    # Mismatched closing character; the input is malformed.
                    return None

        # Append the processed character to the new string.
        new_s += char

    # If the input was cut off inside a string, terminate it. A dangling
    # escape character would swallow the quote we add, so drop it first
    # (fixes inputs like '{"a": "b\\' which previously returned None).
    if is_inside_string:
        if escaped:
            new_s = new_s[:-1]
        new_s += '"'

    # Close any remaining open structures in the reverse order that they
    # were opened.
    for closing_char in reversed(stack):
        new_s += closing_char

    # Attempt to parse the repaired string; give up with None on failure.
    try:
        return json.loads(new_s, strict=strict)
    except json.JSONDecodeError:
        return None
|
||
|
||
def parse_json_markdown(
    json_string: str, *, parser: Callable[[str], Any] = json.loads
) -> dict:
    """Extract and parse a JSON object embedded in a Markdown string.

    Args:
        json_string: The Markdown text, optionally containing a fenced
            ```json ... ``` block.
        parser: Callable that turns the extracted text into an object;
            defaults to ``json.loads``.

    Returns:
        The parsed JSON object as a Python dictionary.
    """
    # Prefer the content of a triple-backtick fence; if none is found,
    # treat the entire string as JSON.
    fenced = re.search(r"```(json)?(.*)```", json_string, re.DOTALL)
    candidate = json_string if fenced is None else fenced.group(2)

    # Trim surrounding whitespace, then escape raw control characters inside
    # the returned value before handing off to the parser.
    candidate = _custom_parser(candidate.strip())

    return parser(candidate)
|
||
|
||
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
    """Parse a JSON string from a Markdown string and verify expected keys.

    Args:
        text: The Markdown string.
        expected_keys: Keys that must be present in the parsed object.

    Returns:
        The parsed JSON object as a Python dictionary.

    Raises:
        OutputParserException: If the text is not valid JSON or an expected
            key is missing.
    """
    try:
        json_obj = parse_json_markdown(text)
    except json.JSONDecodeError as e:
        # Chain the decode error so the root cause survives (was dropped).
        raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e
    for key in expected_keys:
        if key not in json_obj:
            raise OutputParserException(
                f"Got invalid return object. Expected key `{key}` "
                f"to be present, but got {json_obj}"
            )
    return json_obj
|
||
|
||
class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
    """Parse the output of an LLM call to a JSON object.

    When used in streaming mode, it will yield partial JSON objects containing
    all the keys that have been returned so far.

    In streaming, if `diff` is set to `True`, yields JSONPatch operations
    describing the difference between the previous and the current object.
    """

    def _diff(self, prev: Optional[Any], next: Any) -> Any:
        """Return the JSONPatch operations transforming ``prev`` into ``next``."""
        return jsonpatch.make_patch(prev, next).patch

    def parse(self, text: str) -> Any:
        """Parse ``text`` (possibly fenced and/or partial JSON) into an object.

        Raises:
            OutputParserException: If the text cannot be parsed as JSON.
        """
        text = text.strip()
        try:
            # `text` is already stripped above (was redundantly stripped twice).
            return parse_json_markdown(text, parser=parse_partial_json)
        except JSONDecodeError as e:
            raise OutputParserException(f"Invalid json output: {text}") from e

    @property
    def _type(self) -> str:
        return "simple_json_output_parser"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import re | ||
import xml.etree.ElementTree as ET | ||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union | ||
|
||
from langchain_core.messages import BaseMessage | ||
from langchain_core.output_parsers.transform import BaseTransformOutputParser | ||
from langchain_core.runnables.utils import AddableDict | ||
|
||
XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file. | ||
1. Output should conform to the tags below. | ||
2. If tags are not given, make them on your own. | ||
3. Remember to always open and close all the tags. | ||
As an example, for the tags ["foo", "bar", "baz"]: | ||
1. String "<foo>\n <bar>\n <baz></baz>\n </bar>\n</foo>" is a well-formatted instance of the schema. | ||
2. String "<foo>\n <bar>\n </foo>" is a badly-formatted instance. | ||
3. String "<foo>\n <tag>\n </tag>\n</foo>" is a badly-formatted instance. | ||
Here are the output tags: | ||
``` | ||
{tags} | ||
```""" # noqa: E501 | ||
|
||
|
||
class XMLOutputParser(BaseTransformOutputParser):
    """Parse an output using xml format."""

    # Expected tag names; only used to render the format instructions.
    tags: Optional[List[str]] = None
    # Matches a leading element mentioning "encoding" (e.g. an
    # <?xml ... encoding="..."?> declaration) followed by the payload.
    encoding_matcher: re.Pattern = re.compile(
        r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
    )

    def get_format_instructions(self) -> str:
        """Return prompt instructions describing the expected XML output."""
        return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)

    def parse(self, text: str) -> Dict[str, List[Any]]:
        """Parse a complete XML string into a nested dict-of-lists.

        Raises:
            ValueError: If the cleaned text does not start with "<" and
                end with ">".
        """
        # NOTE(review): str.strip strips a *character set*, so .strip("xml")
        # removes any mix of 'x'/'m'/'l' characters from both ends rather
        # than the literal prefix "xml" — works for ```xml fences but is
        # fragile; confirm intent.
        text = text.strip("`").strip("xml")
        encoding_match = self.encoding_matcher.search(text)
        if encoding_match:
            # Keep only the payload after the encoding declaration line.
            text = encoding_match.group(2)

        text = text.strip()
        # The "\n<" / ">\n" alternatives are redundant after strip(), but
        # kept as written.
        if (text.startswith("<") or text.startswith("\n<")) and (
            text.endswith(">") or text.endswith(">\n")
        ):
            # NOTE(review): ET.fromstring on untrusted LLM output — stdlib
            # ElementTree is vulnerable to some crafted-XML attacks; confirm
            # this is acceptable here.
            root = ET.fromstring(text)
            return self._root_to_dict(root)
        else:
            raise ValueError(f"Could not parse output: {text}")

    def _transform(
        self, input: Iterator[Union[str, BaseMessage]]
    ) -> Iterator[AddableDict]:
        # Incrementally parse streamed chunks, yielding one nested dict per
        # deepest-level element as soon as that element closes.
        parser = ET.XMLPullParser(["start", "end"])
        current_path: List[str] = []  # tags of currently open ancestors
        current_path_has_children = False
        for chunk in input:
            if isinstance(chunk, BaseMessage):
                # extract text; skip non-string message content
                chunk_content = chunk.content
                if not isinstance(chunk_content, str):
                    continue
                chunk = chunk_content
            # pass chunk to parser
            parser.feed(chunk)
            # yield all events accumulated so far
            for event, elem in parser.read_events():
                if event == "start":
                    # descend: update current path
                    current_path.append(elem.tag)
                    current_path_has_children = False
                elif event == "end":
                    # ascend: remove last element from current path
                    current_path.pop()
                    # yield only leaf elements (no children already yielded)
                    if not current_path_has_children:
                        yield nested_element(current_path, elem)
                    # prevent yielding of parent element
                    current_path_has_children = True
        # close parser
        parser.close()

    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
    ) -> AsyncIterator[AddableDict]:
        # Async twin of _transform; identical event-handling logic driven by
        # an async chunk iterator.
        parser = ET.XMLPullParser(["start", "end"])
        current_path: List[str] = []  # tags of currently open ancestors
        current_path_has_children = False
        async for chunk in input:
            if isinstance(chunk, BaseMessage):
                # extract text; skip non-string message content
                chunk_content = chunk.content
                if not isinstance(chunk_content, str):
                    continue
                chunk = chunk_content
            # pass chunk to parser
            parser.feed(chunk)
            # yield all events accumulated so far
            for event, elem in parser.read_events():
                if event == "start":
                    # descend: update current path
                    current_path.append(elem.tag)
                    current_path_has_children = False
                elif event == "end":
                    # ascend: remove last element from current path
                    current_path.pop()
                    # yield only leaf elements (no children already yielded)
                    if not current_path_has_children:
                        yield nested_element(current_path, elem)
                    # prevent yielding of parent element
                    current_path_has_children = True
        # close parser
        parser.close()

    def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
        """Converts xml tree to python dictionary.

        Each element maps its tag to a list of children; leaves map tag to
        the element's text.
        """
        result: Dict[str, List[Any]] = {root.tag: []}
        for child in root:
            if len(child) == 0:
                # Leaf: record tag -> text.
                result[root.tag].append({child.tag: child.text})
            else:
                # Recurse into nested elements.
                result[root.tag].append(self._root_to_dict(child))
        return result

    @property
    def _type(self) -> str:
        return "xml"
|
||
|
||
def nested_element(path: List[str], elem: ET.Element) -> Any:
    """Wrap ``elem`` in AddableDicts nested along ``path`` (outermost first)."""
    if not path:
        return AddableDict({elem.tag: elem.text})
    head, *rest = path
    return AddableDict({head: [nested_element(rest, elem)]})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.