From 26bdf400728e6d8a8f3d8798e0381c1a61586d63 Mon Sep 17 00:00:00 2001 From: Anusha Karkhanis <92644377+anushak18@users.noreply.github.com> Date: Thu, 19 Dec 2024 15:30:57 -0500 Subject: [PATCH] Langchain_Community: SQL LanguageParser (#28430) ## Description (This PR has contributions from @khushiDesai, @ashvini8, and @ssumaiyaahmed). This PR addresses **Issue #11229** which addresses the need for SQL support in document parsing. This is integrated into the generic TreeSitter parsing library, allowing LangChain users to easily load codebases in SQL into smaller, manageable "documents." This pull request adds a new ```SQLSegmenter``` class, which provides the SQL integration. ## Issue **Issue #11229**: Add support for a variety of languages to LanguageParser ## Testing We created a file ```test_sql.py``` with several tests to ensure the ```SQLSegmenter``` is functional. Below are the tests we added: - ```def test_is_valid```: Checks SQL validity. - ```def test_extract_functions_classes```: Extracts individual SQL statements. - ```def test_simplify_code```: Simplifies SQL code with comments. --------- Co-authored-by: Syeda Sumaiya Ahmed <114104419+ssumaiyaahmed@users.noreply.github.com> Co-authored-by: ashvini hunagund <97271381+ashvini8@users.noreply.github.com> Co-authored-by: Khushi Desai Co-authored-by: Khushi Desai <59741309+khushiDesai@users.noreply.github.com> Co-authored-by: ccurme --- .../parsers/language/language_parser.py | 6 +- .../document_loaders/parsers/language/sql.py | 65 +++++++++++++++++++ .../parsers/language/test_sql.py | 61 +++++++++++++++++ 3 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 libs/community/langchain_community/document_loaders/parsers/language/sql.py create mode 100644 libs/community/tests/unit_tests/document_loaders/parsers/language/test_sql.py diff --git a/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py b/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py index f44d74e690654..e1d4e5ec664b9 100644 --- a/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py +++ b/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py @@ -24,6 +24,7 @@ from langchain_community.document_loaders.parsers.language.ruby import RubySegmenter from langchain_community.document_loaders.parsers.language.rust import RustSegmenter from langchain_community.document_loaders.parsers.language.scala import ScalaSegmenter +from langchain_community.document_loaders.parsers.language.sql import SQLSegmenter from langchain_community.document_loaders.parsers.language.typescript import ( TypeScriptSegmenter, ) @@ -47,6 +48,7 @@ "php": "php", "ex": "elixir", "exs": "elixir", + "sql": "sql", } LANGUAGE_SEGMENTERS: Dict[str, Any] = { @@ -67,6 +69,7 @@ "java": JavaSegmenter, "php": PHPSegmenter, "elixir": ElixirSegmenter, + "sql": SQLSegmenter, } Language = Literal[ @@ -83,7 +86,6 @@ "ruby", "rust", "scala", - "swift", "markdown", "latex", "html", @@ -94,6 +96,7 @@ "lua", "perl", "elixir", + "sql", ] @@ -123,6 +126,7 @@ class LanguageParser(BaseBlobParser): - Ruby: "ruby" (*) - Rust: "rust" (*) - Scala: "scala" (*) + - SQL: "sql" (*) - TypeScript: "ts" (*) Items marked with (*) require the packages `tree_sitter` and diff --git a/libs/community/langchain_community/document_loaders/parsers/language/sql.py b/libs/community/langchain_community/document_loaders/parsers/language/sql.py new file mode 100644 index 0000000000000..1c11b7b363758 --- /dev/null +++ b/libs/community/langchain_community/document_loaders/parsers/language/sql.py @@ -0,0 +1,65 @@ +from typing import TYPE_CHECKING + +from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501 + TreeSitterSegmenter, +) + +if TYPE_CHECKING: + from tree_sitter import Language + +CHUNK_QUERY = """ + [ + (create_table_statement) @create + (select_statement) @select + (insert_statement) @insert + (update_statement) @update + (delete_statement) @delete + ] +""" + + +class SQLSegmenter(TreeSitterSegmenter): + """Code segmenter for SQL. + This class uses Tree-sitter to segment SQL code into its + constituent statements (e.g., SELECT, CREATE TABLE). + It also provides functionality to extract these + statements and simplify the code into commented descriptions. + """ + + def get_language(self) -> "Language": + """Return the SQL language grammar for Tree-sitter.""" + from tree_sitter_languages import get_language + + return get_language("sql") + + def get_chunk_query(self) -> str: + """Return the Tree-sitter query for SQL segmentation.""" + return CHUNK_QUERY + + def extract_functions_classes(self) -> list[str]: + """Extract SQL statements from the code. + Ensures that all SQL statements end with a semicolon + for consistency. + """ + extracted = super().extract_functions_classes() + # Ensure all statements end with a semicolon + return [ + stmt.strip() + ";" if not stmt.strip().endswith(";") else stmt.strip() + for stmt in extracted + ] + + def simplify_code(self) -> str: + """Simplify the extracted SQL code into comments. + Converts SQL statements into commented descriptions + for easy readability. + """ + return "\n".join( + [ + f"-- Code for: {stmt.strip()}" + for stmt in self.extract_functions_classes() + ] + ) + + def make_line_comment(self, text: str) -> str: + """Create a line comment in SQL style.""" + return f"-- {text}" diff --git a/libs/community/tests/unit_tests/document_loaders/parsers/language/test_sql.py b/libs/community/tests/unit_tests/document_loaders/parsers/language/test_sql.py new file mode 100644 index 0000000000000..37b22052ea243 --- /dev/null +++ b/libs/community/tests/unit_tests/document_loaders/parsers/language/test_sql.py @@ -0,0 +1,61 @@ +import unittest + +import pytest + +from langchain_community.document_loaders.parsers.language.sql import SQLSegmenter + + +@pytest.mark.requires("tree_sitter", "tree_sitter_languages") +class TestSQLSegmenter(unittest.TestCase): + """Unit tests for the SQLSegmenter class.""" + + def setUp(self) -> None: + """Set up example code and expected results for testing.""" + self.example_code = """ + CREATE TABLE users (id INT, name TEXT); + + -- A select query + SELECT id, name FROM users WHERE id = 1; + + INSERT INTO users (id, name) VALUES (2, 'Alice'); + + UPDATE users SET name = 'Bob' WHERE id = 2; + + DELETE FROM users WHERE id = 2; + """ + + self.expected_simplified_code = ( + "-- Code for: CREATE TABLE users (id INT, name TEXT);\n" + "-- Code for: SELECT id, name FROM users WHERE id = 1;\n" + "-- Code for: INSERT INTO users (id, name) VALUES (2, 'Alice');\n" + "-- Code for: UPDATE users SET name = 'Bob' WHERE id = 2;\n" + "-- Code for: DELETE FROM users WHERE id = 2;" + ) + + self.expected_extracted_code = [ + "CREATE TABLE users (id INT, name TEXT);", + "SELECT id, name FROM users WHERE id = 1;", + "INSERT INTO users (id, name) VALUES (2, 'Alice');", + "UPDATE users SET name = 'Bob' WHERE id = 2;", + "DELETE FROM users WHERE id = 2;", + ] + + def test_is_valid(self) -> None: + """Test the validity of SQL code.""" + # Valid SQL code should return True + self.assertTrue(SQLSegmenter("SELECT * FROM test").is_valid()) + # Invalid code (non-SQL text) should return False + self.assertFalse(SQLSegmenter("random text").is_valid()) + + def test_extract_functions_classes(self) -> None: + """Test extracting SQL statements from code.""" + segmenter = SQLSegmenter(self.example_code) + extracted_code = segmenter.extract_functions_classes() + # Verify the extracted code matches expected SQL statements + self.assertEqual(extracted_code, self.expected_extracted_code) + + def test_simplify_code(self) -> None: + """Test simplifying SQL code into commented descriptions.""" + segmenter = SQLSegmenter(self.example_code) + simplified_code = segmenter.simplify_code() + self.assertEqual(simplified_code, self.expected_simplified_code)