diff --git a/Cargo.lock b/Cargo.lock index c1590b4cd..9877340d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1004,6 +1004,7 @@ dependencies = [ "postgres-native-tls", "postgres-openssl", "pprof", + "prusto", "r2d2", "r2d2-oracle", "r2d2_mysql", @@ -1014,6 +1015,7 @@ dependencies = [ "rusqlite", "rust_decimal", "rust_decimal_macros", + "serde", "serde_json", "sqlparser 0.37.0", "thiserror", @@ -1061,6 +1063,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" +[[package]] +name = "convert_case" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" + [[package]] name = "core-foundation" version = "0.9.4" @@ -1493,6 +1501,19 @@ dependencies = [ "serde", ] +[[package]] +name = "derive_more" +version = "0.99.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" +dependencies = [ + "convert_case", + "proc-macro2", + "quote", + "rustc_version", + "syn 1.0.109", +] + [[package]] name = "derive_utils" version = "0.13.2" @@ -2367,6 +2388,15 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +[[package]] +name = "iterable" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c151dfd6ab7dff5ca5567d82041bb286f07469ece85c1e2444a6d26d7057a65f" +dependencies = [ + "itertools 0.10.5", +] + [[package]] name = "itertools" version = "0.10.5" @@ -3824,6 +3854,43 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prusto" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4b88a35eb608a69482012e38b818a77c23bd1f3fe952143217609ad6c43f94" +dependencies = [ + "bigdecimal", + "chrono", + "chrono-tz", + "derive_more", + "futures", + "http", + "iterable", + "lazy_static", + "log", + "prusto-macros", + "regex", + "reqwest", + "serde", + "serde_json", + "thiserror", + "tokio", + "urlencoding", + "uuid 1.7.0", +] + +[[package]] +name = "prusto-macros" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "729a73ec40e80da961c846455ec579c521346392d6f9f5a8c8aadfb5c99f9cf8" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "ptr_meta" version = "0.1.4" @@ -4457,9 +4524,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.195" +version = "1.0.198" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02" +checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc" dependencies = [ "serde_derive", ] @@ -4476,9 +4543,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.195" +version = "1.0.198" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c" +checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" dependencies = [ "proc-macro2", "quote", @@ -4536,6 +4603,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + [[package]] name = "simdutf8" version = "0.1.4" @@ -5025,6 +5101,7 @@ dependencies = [ "num_cpus", "parking_lot 0.12.1", "pin-project-lite", + "signal-hook-registry", "socket2 0.5.5", "tokio-macros", "windows-sys 0.48.0", @@ -5308,6 +5385,7 @@ checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" dependencies = [ "getrandom 0.2.12", "rand 0.8.5", + "serde", ] [[package]] diff --git a/Justfile b/Justfile index 07ac50601..c086d524a 100644 --- a/Justfile +++ b/Justfile @@ -23,6 +23,7 @@ test-feature-gate: cargo c --features src_oracle cargo c --features src_csv cargo c --features src_dummy + cargo c --features src_trino cargo c --features dst_arrow cargo c --features dst_arrow2 @@ -62,6 +63,7 @@ seed-db-more: ORACLE_URL_SCRIPT=`echo ${ORACLE_URL#oracle://} | sed "s/:/\//"` cat scripts/oracle.sql | sqlplus $ORACLE_URL_SCRIPT mysql --protocol tcp -h$MARIADB_HOST -P$MARIADB_PORT -u$MARIADB_USER -p$MARIADB_PASSWORD $MARIADB_DB < scripts/mysql.sql + trino $TRINO_URL --catalog=$TRINO_CATALOG < scripts/trino.sql # benches flame-tpch conn="POSTGRES_URL": diff --git a/README.md b/README.md index aa6539fde..a711e2130 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,7 @@ For more planned data sources, please check out our [discussion](https://github. - [x] Oracle - [x] Big Query - [ ] ODBC (WIP) +- [ ] Trino (WIP) - [ ] ... ## Destinations diff --git a/connectorx-cpp/Cargo.toml b/connectorx-cpp/Cargo.toml index a2eaa8b8c..1ecd6958b 100644 --- a/connectorx-cpp/Cargo.toml +++ b/connectorx-cpp/Cargo.toml @@ -34,4 +34,5 @@ srcs = [ "connectorx/src_mssql", "connectorx/src_oracle", "connectorx/src_bigquery", + "connectorx/src_trino", ] diff --git a/connectorx-python/Cargo.lock b/connectorx-python/Cargo.lock index d6e48d681..4303ca018 100644 --- a/connectorx-python/Cargo.lock +++ b/connectorx-python/Cargo.lock @@ -1048,6 +1048,7 @@ dependencies = [ "postgres", "postgres-native-tls", "postgres-openssl", + "prusto", "r2d2", "r2d2-oracle", "r2d2_mysql", @@ -1057,6 +1058,7 @@ dependencies = [ "rusqlite", "rust_decimal", "rust_decimal_macros", + "serde", "serde_json", "sqlparser 0.37.0", "thiserror", @@ -1139,6 +1141,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" +[[package]] +name = "convert_case" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" + [[package]] name = "core-foundation" version = "0.9.3" @@ -1594,6 +1602,19 @@ dependencies = [ "serde", ] +[[package]] +name = "derive_more" +version = "0.99.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" +dependencies = [ + "convert_case", + "proc-macro2", + "quote", + "rustc_version", + "syn 1.0.109", +] + [[package]] name = "derive_utils" version = "0.13.2" @@ -2537,6 +2558,15 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" +[[package]] +name = "iterable" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c151dfd6ab7dff5ca5567d82041bb286f07469ece85c1e2444a6d26d7057a65f" +dependencies = [ + "itertools 0.10.5", +] + [[package]] name = "itertools" version = "0.10.5" @@ -4053,6 +4083,43 @@ dependencies = [ "prost", ] +[[package]] +name = "prusto" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4b88a35eb608a69482012e38b818a77c23bd1f3fe952143217609ad6c43f94" +dependencies = [ + "bigdecimal", + "chrono", + "chrono-tz", + "derive_more", + "futures", + "http", + "iterable", + "lazy_static", + "log", + "prusto-macros", + "regex", + "reqwest", + "serde", + "serde_json", + "thiserror", + "tokio", + "urlencoding", + "uuid 1.4.1", +] + +[[package]] +name = "prusto-macros" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "729a73ec40e80da961c846455ec579c521346392d6f9f5a8c8aadfb5c99f9cf8" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "ptr_meta" version = "0.1.4" @@ -5300,6 +5367,7 @@ dependencies = [ "num_cpus", "parking_lot 0.12.1", "pin-project-lite", + "signal-hook-registry", "socket2 0.5.3", "tokio-macros", "windows-sys", @@ -5576,6 +5644,7 @@ checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" dependencies = [ "getrandom 0.2.10", "rand 0.8.5", + "serde", ] [[package]] diff --git a/connectorx-python/Cargo.toml b/connectorx-python/Cargo.toml index 4ed86b30e..37ba3394b 100644 --- a/connectorx-python/Cargo.toml +++ b/connectorx-python/Cargo.toml @@ -75,5 +75,6 @@ srcs = [ "connectorx/src_mssql", "connectorx/src_oracle", "connectorx/src_bigquery", + "connectorx/src_trino", ] integrated-auth-gssapi = ["connectorx/integrated-auth-gssapi"] diff --git a/connectorx-python/connectorx/tests/test_trino.py b/connectorx-python/connectorx/tests/test_trino.py new file mode 100644 index 000000000..783d2b99e --- /dev/null +++ b/connectorx-python/connectorx/tests/test_trino.py @@ -0,0 +1,364 @@ +import os + +import pandas as pd +import pytest +from pandas.testing import assert_frame_equal + +from .. import read_sql + + +@pytest.fixture(scope="module") # type: ignore +def trino_url() -> str: + conn = os.environ["TRINO_URL"] + return conn + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_trino_without_partition(trino_url: str) -> None: + query = "select * from test.test_table order by test_int limit 3" + df = read_sql(trino_url, query) + expected = pd.DataFrame( + index=range(3), + data={ + "test_int": pd.Series([1, 2, 3], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3], dtype="float64"), + "test_null": pd.Series([None, None, None], dtype="Int64"), + }, + ) + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_trino_with_partition(trino_url: str) -> None: + query = "select * from test.test_table order by test_int" + df = read_sql( + trino_url, + query, + partition_on="test_int", + partition_range=(0, 10), + partition_num=6, + ) + expected = pd.DataFrame( + index=range(6), + data={ + "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), + "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_trino_without_partition(trino_url: str) -> None: + query = "SELECT * FROM test.test_table order by test_int" + df = read_sql(trino_url, query) + expected = pd.DataFrame( + index=range(6), + data={ + "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), + "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), + }, + ) + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_trino_limit_without_partition(trino_url: str) -> None: + query = "SELECT * FROM test.test_table order by test_int limit 3" + df = read_sql(trino_url, query) + expected = pd.DataFrame( + index=range(3), + data={ + "test_int": pd.Series([1, 2, 3], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3], dtype="float64"), + "test_null": pd.Series([None, None, None], dtype="Int64"), + }, + ) + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_trino_limit_large_without_partition(trino_url: str) -> None: + query = "SELECT * FROM test.test_table order by test_int limit 10" + df = read_sql(trino_url, query) + expected = pd.DataFrame( + index=range(6), + data={ + "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), + "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), + }, + ) + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_trino_with_partition(trino_url: str) -> None: + query = "SELECT * FROM test.test_table order by test_int" + df = read_sql( + trino_url, + query, + partition_on="test_int", + partition_range=(0, 2000), + partition_num=3, + ) + expected = pd.DataFrame( + index=range(6), + data={ + "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), + "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_trino_limit_with_partition(trino_url: str) -> None: + query = "SELECT * FROM test.test_table order by test_int limit 3" + df = read_sql( + trino_url, + query, + partition_on="test_int", + partition_range=(0, 2000), + partition_num=3, + ) + expected = pd.DataFrame( + index=range(3), + data={ + "test_int": pd.Series([1, 2, 3], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3], dtype="float64"), + "test_null": pd.Series([None, None, None], dtype="Int64"), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_trino_limit_large_with_partition(trino_url: str) -> None: + query = "SELECT * FROM test.test_table order by test_int limit 10" + df = read_sql( + trino_url, + query, + partition_on="test_int", + partition_range=(0, 2000), + partition_num=3, + ) + expected = pd.DataFrame( + index=range(6), + data={ + "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), + "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_trino_with_partition_without_partition_range(trino_url: str) -> None: + query = "SELECT * FROM test.test_table where test_float > 3" + df = read_sql( + trino_url, + query, + partition_on="test_int", + partition_num=3, + ) + expected = pd.DataFrame( + index=range(4), + data={ + "test_int": pd.Series([3, 4, 5, 6], dtype="Int64"), + "test_float": pd.Series([3.3, 4.4, 5.5, 6.6], dtype="float64"), + "test_null": pd.Series([None, None, None, None], dtype="Int64"), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_trino_manual_partition(trino_url: str) -> None: + queries = [ + "SELECT * FROM test.test_table WHERE test_int < 2 order by test_int", + "SELECT * FROM test.test_table WHERE test_int >= 2 order by test_int", + ] + df = read_sql(trino_url, query=queries) + expected = pd.DataFrame( + index=range(6), + data={ + "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), + "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_trino_selection_and_projection(trino_url: str) -> None: + query = "SELECT test_int FROM test.test_table WHERE test_float < 5 order by test_int" + df = read_sql( + trino_url, + query, + partition_on="test_int", + partition_num=3, + ) + expected = pd.DataFrame( + index=range(4), + data={ + "test_int": pd.Series([1, 2, 3, 4], dtype="Int64"), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_trino_join(trino_url: str) -> None: + query = "SELECT T.test_int, T.test_float, S.test_str FROM test.test_table T INNER JOIN test.test_table_extra S ON T.test_int = S.test_int order by T.test_int" + df = read_sql( + trino_url, + query, + partition_on="test_int", + partition_num=3, + ) + expected = pd.DataFrame( + index=range(3), + data={ + "test_int": pd.Series([1, 2, 3], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3], dtype="float64"), + "test_str": pd.Series( + [ + "Ha好ち😁ðy̆", + "こんにちは", + "русский", + ], + dtype="object", + ), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_trino_aggregate(trino_url: str) -> None: + query = "select AVG(test_float) as avg_float, SUM(T.test_int) as sum_int, SUM(test_null) as sum_null from test.test_table as T" + df = read_sql(trino_url, query) + expected = pd.DataFrame( + index=range(1), + data={ + "avg_float": pd.Series([3.85], dtype="float64"), + "sum_int": pd.Series([21], dtype="Int64"), + "sum_null": pd.Series([None], dtype="Int64"), + }, + ) + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_trino_types_binary(trino_url: str) -> None: + query = "select test_boolean, test_int, test_bigint, test_real, test_double, test_decimal, test_date, test_time, test_timestamp, test_varchar, test_uuid from test.test_types order by test_int" + df = read_sql(trino_url, query) + expected = pd.DataFrame( + index=range(3), + data={ + "test_boolean": pd.Series([True, False, None], dtype="boolean"), + "test_int": pd.Series([123, 321, None], dtype="Int64"), + "test_bigint": pd.Series([1000, 2000, None], dtype="Int64"), + "test_real": pd.Series([123.456, 123.456, None], dtype="float64"), + "test_double": pd.Series([123.4567890123, 123.4567890123, None], dtype="float64"), + "test_decimal": pd.Series([1234567890.12, 1234567890.12, None], dtype="float64"), + "test_date": pd.Series(["2023-01-01", "2023-01-01", None], dtype="datetime64[ns]"), + "test_time": pd.Series(["12:00:00", "12:00:00", None], dtype="object"), + "test_timestamp": pd.Series(["2023-01-01 12:00:00.123456", "2023-01-01 12:00:00.123456", None], dtype="datetime64[ns]"), + "test_varchar": pd.Series(["Sample text", "Sample text", None], dtype="object"), + "test_uuid": pd.Series(["f4967dbb-33e9-4242-a13a-45b56ce60dba", "1c8b79d0-4508-4974-b728-7651bce4a5a5", None], dtype="object"), + }, + ) + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_empty_result(trino_url: str) -> None: + query = "SELECT * FROM test.test_table where test_int < -100" + df = read_sql(trino_url, query) + expected = pd.DataFrame( + data={ + "test_int": pd.Series([], dtype="Int64"), + "test_float": pd.Series([], dtype="float64"), + "test_null": pd.Series([], dtype="Int64"), + } + ) + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_empty_result_on_partition(trino_url: str) -> None: + query = "SELECT * FROM test.test_table where test_int < -100" + df = read_sql(trino_url, query, partition_on="test_int", partition_num=3) + expected = pd.DataFrame( + data={ + "test_int": pd.Series([], dtype="Int64"), + "test_float": pd.Series([], dtype="float64"), + "test_null": pd.Series([], dtype="Int64"), + } + ) + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) +def test_empty_result_on_some_partition(trino_url: str) -> None: + query = "SELECT * FROM test.test_table where test_int = 6" + df = read_sql(trino_url, query, partition_on="test_int", partition_num=3) + expected = pd.DataFrame( + index=range(1), + data={ + "test_int": pd.Series([6], dtype="Int64"), + "test_float": pd.Series([6.6], dtype="float64"), + "test_null": pd.Series([None], dtype="Int64"), + }, + ) + assert_frame_equal(df, expected, check_names=True) diff --git a/connectorx-python/src/errors.rs b/connectorx-python/src/errors.rs index a8754ef2f..929023e05 100644 --- a/connectorx-python/src/errors.rs +++ b/connectorx-python/src/errors.rs @@ -42,6 +42,9 @@ pub enum ConnectorXPythonError { #[error(transparent)] BigQuerySourceError(#[from] connectorx::sources::bigquery::BigQuerySourceError), + #[error(transparent)] + TrinoSourceError(#[from] connectorx::sources::trino::TrinoSourceError), + #[error(transparent)] ArrowDestinationError(#[from] connectorx::destinations::arrow::ArrowDestinationError), diff --git a/connectorx-python/src/pandas/get_meta.rs b/connectorx-python/src/pandas/get_meta.rs index bc5e7de95..7ee648e7d 100644 --- a/connectorx-python/src/pandas/get_meta.rs +++ b/connectorx-python/src/pandas/get_meta.rs @@ -2,7 +2,7 @@ use super::{ destination::PandasDestination, transports::{ BigQueryPandasTransport, MsSQLPandasTransport, MysqlPandasTransport, OraclePandasTransport, - PostgresPandasTransport, SqlitePandasTransport, + PostgresPandasTransport, SqlitePandasTransport, TrinoPandasTransport, }, }; use crate::errors::ConnectorXPythonError; @@ -18,6 +18,7 @@ use connectorx::{ PostgresSource, SimpleProtocol, }, sqlite::SQLiteSource, + trino::TrinoSource, }, sql::CXQuery, }; @@ -223,6 +224,17 @@ pub fn get_meta<'a>(py: Python<'a>, conn: &str, protocol: &str, query: String) - debug!("Running dispatcher"); dispatcher.get_meta()?; } + SourceType::Trino => { + let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime")); + let source = TrinoSource::new(rt, &source_conn.conn[..])?; + let dispatcher = Dispatcher::<_, _, TrinoPandasTransport>::new( + source, + &mut destination, + queries, + None, + ); + dispatcher.run()?; + } _ => unimplemented!("{:?} not implemented!", source_conn.ty), } diff --git a/connectorx-python/src/pandas/mod.rs b/connectorx-python/src/pandas/mod.rs index be2e41928..117280866 100644 --- a/connectorx-python/src/pandas/mod.rs +++ b/connectorx-python/src/pandas/mod.rs @@ -8,7 +8,7 @@ mod typesystem; pub use self::destination::{PandasBlockInfo, PandasDestination, PandasPartitionDestination}; pub use self::transports::{ BigQueryPandasTransport, MsSQLPandasTransport, MysqlPandasTransport, OraclePandasTransport, - PostgresPandasTransport, SqlitePandasTransport, + PostgresPandasTransport, SqlitePandasTransport, TrinoPandasTransport, }; pub use self::typesystem::{PandasDType, PandasTypeSystem}; use crate::errors::ConnectorXPythonError; @@ -230,6 +230,17 @@ pub fn write_pandas<'a>( ); dispatcher.run()?; } + SourceType::Trino => { + let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime")); + let source = TrinoSource::new(rt, &source_conn.conn[..])?; + let dispatcher = Dispatcher::<_, _, TrinoPandasTransport>::new( + source, + &mut destination, + queries, + origin_query, + ); + dispatcher.run()?; + } _ => unimplemented!("{:?} not implemented!", source_conn.ty), } diff --git a/connectorx-python/src/pandas/transports/mod.rs b/connectorx-python/src/pandas/transports/mod.rs index 9f03abf33..fbf7952fb 100644 --- a/connectorx-python/src/pandas/transports/mod.rs +++ b/connectorx-python/src/pandas/transports/mod.rs @@ -4,6 +4,7 @@ mod mysql; mod oracle; mod postgres; mod sqlite; +mod trino; pub use self::postgres::PostgresPandasTransport; pub use bigquery::BigQueryPandasTransport; @@ -11,3 +12,4 @@ pub use mssql::MsSQLPandasTransport; pub use mysql::MysqlPandasTransport; pub use oracle::OraclePandasTransport; pub use sqlite::SqlitePandasTransport; +pub use trino::TrinoPandasTransport; diff --git a/connectorx-python/src/pandas/transports/trino.rs b/connectorx-python/src/pandas/transports/trino.rs new file mode 100644 index 000000000..fba7a06d5 --- /dev/null +++ b/connectorx-python/src/pandas/transports/trino.rs @@ -0,0 +1,54 @@ +use crate::errors::ConnectorXPythonError; +use crate::pandas::destination::PandasDestination; +use crate::pandas::typesystem::PandasTypeSystem; +use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; +use connectorx::{ + impl_transport, + sources::trino::{TrinoSource, TrinoTypeSystem}, + typesystem::TypeConversion, +}; + +pub struct TrinoPandasTransport<'py>(&'py ()); + +impl_transport!( + name = TrinoPandasTransport<'tp>, + error = ConnectorXPythonError, + systems = TrinoTypeSystem => PandasTypeSystem, + route = TrinoSource => PandasDestination<'tp>, + mappings = { + { Date[NaiveDate] => DateTime[DateTime] | conversion option } + { Time[NaiveTime] => String[String] | conversion option } + { Timestamp[NaiveDateTime] => DateTime[DateTime] | conversion option } + { Boolean[bool] => Bool[bool] | conversion auto } + { Bigint[i32] => I64[i64] | conversion auto } + { Integer[i32] => I64[i64] | conversion none } + { Smallint[i16] => I64[i64] | conversion auto } + { Tinyint[i8] => I64[i64] | conversion auto } + { Double[f64] => F64[f64] | conversion auto } + { Real[f32] => F64[f64] | conversion auto } + { Varchar[String] => String[String] | conversion auto } + { Char[String] => String[String] | conversion none } + } +); + +impl<'py> TypeConversion> for TrinoPandasTransport<'py> { + fn convert(val: NaiveDate) -> DateTime { + DateTime::from_naive_utc_and_offset( + val.and_hms_opt(0, 0, 0) + .unwrap_or_else(|| panic!("and_hms_opt got None from {:?}", val)), + Utc, + ) + } +} + +impl<'py> TypeConversion for TrinoPandasTransport<'py> { + fn convert(val: NaiveTime) -> String { + val.to_string() + } +} + +impl<'py> TypeConversion> for TrinoPandasTransport<'py> { + fn convert(val: NaiveDateTime) -> DateTime { + DateTime::from_naive_utc_and_offset(val, Utc) + } +} diff --git a/connectorx/Cargo.toml b/connectorx/Cargo.toml index 7918981bd..f7b200e39 100644 --- a/connectorx/Cargo.toml +++ b/connectorx/Cargo.toml @@ -57,6 +57,8 @@ urlencoding = {version = "2.1", optional = true} uuid = {version = "0.8", optional = true} j4rs = {version = "0.15", optional = true} datafusion = {version = "31", optional = true} +prusto = {version = "0.5.1", optional = true} +serde = {optional = true} [lib] crate-type = ["cdylib", "rlib"] @@ -69,7 +71,7 @@ iai = "0.1" pprof = {version = "0.5", features = ["flamegraph"]} [features] -all = ["src_sqlite", "src_postgres", "src_mysql", "src_mssql", "src_oracle", "src_bigquery", "src_csv", "src_dummy", "dst_arrow", "dst_arrow2", "federation", "fed_exec"] +all = ["src_sqlite", "src_postgres", "src_mysql", "src_mssql", "src_oracle", "src_bigquery", "src_csv", "src_dummy", "src_trino", "dst_arrow", "dst_arrow2", "federation", "fed_exec"] branch = [] default = ["fptr"] dst_arrow = ["arrow"] @@ -97,6 +99,7 @@ src_postgres = [ "postgres-openssl", ] src_sqlite = ["rusqlite", "r2d2_sqlite", "fallible-streaming-iterator", "r2d2", "urlencoding"] +src_trino = ["prusto", "uuid", "urlencoding", "rust_decimal", "tokio", "num-traits", "serde"] federation = ["j4rs"] fed_exec = ["datafusion", "tokio"] integrated-auth-gssapi = ["tiberius/integrated-auth-gssapi"] diff --git a/connectorx/src/lib.rs b/connectorx/src/lib.rs index 84b043be8..5b4ce7386 100644 --- a/connectorx/src/lib.rs +++ b/connectorx/src/lib.rs @@ -208,6 +208,8 @@ pub mod prelude { pub use crate::sources::postgres::PostgresSource; #[cfg(feature = "src_sqlite")] pub use crate::sources::sqlite::SQLiteSource; + #[cfg(feature = "src_trino")] + pub use crate::sources::trino::TrinoSource; pub use crate::sources::{PartitionParser, Produce, Source, SourcePartition}; pub use crate::sql::CXQuery; pub use crate::transports::*; diff --git a/connectorx/src/partition.rs b/connectorx/src/partition.rs index 370120e2a..fedd34fe7 100644 --- a/connectorx/src/partition.rs +++ b/connectorx/src/partition.rs @@ -10,6 +10,7 @@ use crate::sources::mysql::{MySQLSourceError, MySQLTypeSystem}; use crate::sources::oracle::{connect_oracle, OracleDialect}; #[cfg(feature = "src_postgres")] use crate::sources::postgres::{rewrite_tls_args, PostgresTypeSystem}; +use crate::sources::trino::TrinoDialect; #[cfg(feature = "src_sqlite")] use crate::sql::get_partition_range_query_sep; use crate::sql::{get_partition_range_query, single_col_partition_query, CXQuery}; @@ -35,7 +36,7 @@ use sqlparser::dialect::PostgreSqlDialect; use sqlparser::dialect::SQLiteDialect; #[cfg(feature = "src_mssql")] use tiberius::Client; -#[cfg(any(feature = "src_bigquery", feature = "src_mssql"))] +#[cfg(any(feature = "src_bigquery", feature = "src_mssql", feature = "src_trino"))] use tokio::{net::TcpStream, runtime::Runtime}; #[cfg(feature = "src_mssql")] use tokio_util::compat::TokioAsyncWriteCompatExt; @@ -100,6 +101,8 @@ pub fn get_col_range(source_conn: &SourceConn, query: &str, col: &str) -> OutRes SourceType::Oracle => oracle_get_partition_range(&source_conn.conn, query, col), #[cfg(feature = "src_bigquery")] SourceType::BigQuery => bigquery_get_partition_range(&source_conn.conn, query, col), + #[cfg(feature = "src_trino")] + SourceType::Trino => trino_get_partition_range(&source_conn.conn, query, col), _ => unimplemented!("{:?} not implemented!", source_conn.ty), } } @@ -137,6 +140,10 @@ pub fn get_part_query( SourceType::BigQuery => { single_col_partition_query(query, col, lower, upper, &BigQueryDialect {})? } + #[cfg(feature = "src_trino")] + SourceType::Trino => { + single_col_partition_query(query, col, lower, upper, &TrinoDialect {})? + } _ => unimplemented!("{:?} not implemented!", source_conn.ty), }; CXQuery::Wrapped(query) @@ -481,3 +488,52 @@ fn bigquery_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64 (min_v, max_v) } + +#[cfg(feature = "src_trino")] +#[throws(ConnectorXOutError)] +fn trino_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) { + use prusto::{auth::Auth, ClientBuilder}; + + use crate::sources::trino::{TrinoDialect, TrinoPartitionQueryResult}; + + let rt = Runtime::new().expect("Failed to create runtime"); + + let username = match conn.username() { + "" => "connectorx", + username => username, + }; + + let builder = ClientBuilder::new(username, conn.host().unwrap().to_owned()) + .port(conn.port().unwrap_or(8080)) + .ssl(prusto::ssl::Ssl { root_cert: None }) + .secure(conn.scheme() == "trino+https") + .catalog(conn.path_segments().unwrap().last().unwrap_or("hive")); + + let builder = match conn.password() { + None => builder, + Some(password) => builder.auth(Auth::Basic(username.to_owned(), Some(password.to_owned()))), + }; + + let client = builder + .build() + .map_err(|e| anyhow!("Failed to build client: {}", e))?; + + let range_query = get_partition_range_query(query, col, &TrinoDialect {})?; + let query_result = rt.block_on(client.get_all::(range_query)); + + let query_result = match query_result { + Ok(query_result) => Ok(query_result.into_vec()), + Err(e) => match e { + prusto::error::Error::EmptyData => { + Ok(vec![TrinoPartitionQueryResult { _col0: 0, _col1: 0 }]) + } + _ => Err(anyhow!("Failed to get query result: {}", e)), + }, + }?; + + let result = query_result + .first() + .unwrap_or(&TrinoPartitionQueryResult { _col0: 0, _col1: 0 }); + + (result._col0, result._col1) +} diff --git a/connectorx/src/source_router.rs b/connectorx/src/source_router.rs index d30796766..ad3aec489 100644 --- a/connectorx/src/source_router.rs +++ b/connectorx/src/source_router.rs @@ -14,6 +14,7 @@ pub enum SourceType { Oracle, BigQuery, DuckDB, + Trino, Unknown, } @@ -58,6 +59,7 @@ impl TryFrom<&str> for SourceConn { "oracle" => Ok(SourceConn::new(SourceType::Oracle, url, proto)), "bigquery" => Ok(SourceConn::new(SourceType::BigQuery, url, proto)), "duckdb" => Ok(SourceConn::new(SourceType::DuckDB, url, proto)), + "trino" => Ok(SourceConn::new(SourceType::Trino, url, proto)), _ => Ok(SourceConn::new(SourceType::Unknown, url, proto)), } } diff --git a/connectorx/src/sources/mod.rs b/connectorx/src/sources/mod.rs index 0afb6416c..dd86b524d 100644 --- a/connectorx/src/sources/mod.rs +++ b/connectorx/src/sources/mod.rs @@ -17,6 +17,8 @@ pub mod oracle; pub mod postgres; #[cfg(feature = "src_sqlite")] pub mod sqlite; +#[cfg(feature = "src_trino")] +pub mod trino; use crate::data_order::DataOrder; use crate::errors::ConnectorXError; diff --git a/connectorx/src/sources/trino/errors.rs b/connectorx/src/sources/trino/errors.rs new file mode 100644 index 000000000..8b46eff19 --- /dev/null +++ b/connectorx/src/sources/trino/errors.rs @@ -0,0 +1,25 @@ +use std::string::FromUtf8Error; + +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum TrinoSourceError { + #[error("Cannot infer type from null for Trino")] + InferTypeFromNull, + + #[error(transparent)] + ConnectorXError(#[from] crate::errors::ConnectorXError), + + #[error(transparent)] + PrustoError(prusto::error::Error), + + #[error(transparent)] + UrlParseError(#[from] url::ParseError), + + #[error(transparent)] + TrinoUrlDecodeError(#[from] FromUtf8Error), + + /// Any other errors that are too trivial to be put here explicitly. + #[error(transparent)] + Other(#[from] anyhow::Error), +} diff --git a/connectorx/src/sources/trino/mod.rs b/connectorx/src/sources/trino/mod.rs new file mode 100644 index 000000000..072b0ed6f --- /dev/null +++ b/connectorx/src/sources/trino/mod.rs @@ -0,0 +1,663 @@ +use std::{marker::PhantomData, sync::Arc}; + +use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; +use fehler::{throw, throws}; +use prusto::{auth::Auth, Client, ClientBuilder, DataSet, Presto, Row}; +use serde_json::Value; +use sqlparser::dialect::{Dialect, GenericDialect}; +use std::convert::TryFrom; +use tokio::runtime::Runtime; + +use crate::{ + data_order::DataOrder, + errors::ConnectorXError, + sources::Produce, + sql::{count_query, limit1_query, CXQuery}, +}; + +pub use self::{errors::TrinoSourceError, typesystem::TrinoTypeSystem}; +use urlencoding::decode; + +use super::{PartitionParser, Source, SourcePartition}; + +use anyhow::anyhow; + +pub mod errors; +pub mod typesystem; + +#[throws(TrinoSourceError)] +fn get_total_rows(rt: Arc, client: Arc, query: &CXQuery) -> usize { + let cquery = count_query(query, &TrinoDialect {})?; + + let row = rt + .block_on(client.get_all::(cquery.to_string())) + .map_err(TrinoSourceError::PrustoError)? + .split() + .1[0] + .clone(); + + let value = row + .value() + .first() + .ok_or_else(|| anyhow!("Trino count dataset is empty"))?; + + value + .as_i64() + .ok_or_else(|| anyhow!("Trino cannot parse i64"))? as usize +} + +#[derive(Presto, Debug)] +pub struct TrinoPartitionQueryResult { + pub _col0: i64, + pub _col1: i64, +} + +#[derive(Debug)] +pub struct TrinoDialect {} + +// implementation copy from AnsiDialect +impl Dialect for TrinoDialect { + fn is_identifier_start(&self, ch: char) -> bool { + ch.is_ascii_lowercase() || ch.is_ascii_uppercase() + } + + fn is_identifier_part(&self, ch: char) -> bool { + ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '_' + } +} + +pub struct TrinoSource { + client: Arc, + rt: Arc, + origin_query: Option, + queries: Vec>, + names: Vec, + schema: Vec, +} + +impl TrinoSource { + #[throws(TrinoSourceError)] + pub fn new(rt: Arc, conn: &str) -> Self { + let decoded_conn = decode(conn)?.into_owned(); + + let url = decoded_conn + .parse::() + .map_err(TrinoSourceError::UrlParseError)?; + + let username = match url.username() { + "" => "connectorx", + username => username, + }; + + let builder = ClientBuilder::new(username, url.host().unwrap().to_owned()) + .port(url.port().unwrap_or(8080)) + .ssl(prusto::ssl::Ssl { root_cert: None }) + .secure(url.scheme() == "trino+https") + .catalog(url.path_segments().unwrap().last().unwrap_or("hive")); + + let builder = match url.password() { + None => builder, + Some(password) => { + builder.auth(Auth::Basic(username.to_owned(), Some(password.to_owned()))) + } + }; + + let client = builder.build().map_err(TrinoSourceError::PrustoError)?; + + Self { + client: Arc::new(client), + rt, + origin_query: None, + queries: vec![], + names: vec![], + schema: vec![], + } + } +} + +impl Source for TrinoSource +where + TrinoSourcePartition: SourcePartition, +{ + const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor]; + type TypeSystem = TrinoTypeSystem; + type Partition = TrinoSourcePartition; + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn set_data_order(&mut self, data_order: DataOrder) { + if !matches!(data_order, DataOrder::RowMajor) { + throw!(ConnectorXError::UnsupportedDataOrder(data_order)); + } + } + + fn set_queries(&mut self, queries: &[CXQuery]) { + self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect(); + } + + fn set_origin_query(&mut self, query: Option) { + self.origin_query = query; + } + + #[throws(TrinoSourceError)] + fn fetch_metadata(&mut self) { + assert!(!self.queries.is_empty()); + + let first_query = &self.queries[0]; + let cxq = limit1_query(first_query, &GenericDialect {})?; + + let dataset: DataSet = self + .rt + .block_on(self.client.get_all::(cxq.to_string())) + .map_err(TrinoSourceError::PrustoError)?; + + let schema = dataset.split().0; + + for (name, t) in schema { + self.names.push(name.clone()); + self.schema.push(TrinoTypeSystem::try_from(t.clone())?); + } + } + + #[throws(TrinoSourceError)] + fn result_rows(&mut self) -> Option { + match &self.origin_query { + Some(q) => { + let cxq = CXQuery::Naked(q.clone()); + let nrows = get_total_rows(self.rt.clone(), self.client.clone(), &cxq)?; + Some(nrows) + } + None => None, + } + } + + fn names(&self) -> Vec { + self.names.clone() + } + + fn schema(&self) -> Vec { + self.schema.clone() + } + + #[throws(TrinoSourceError)] + fn partition(self) -> Vec { + let mut ret = vec![]; + + for query in self.queries { + ret.push(TrinoSourcePartition::new( + self.client.clone(), + query, + self.schema.clone(), + self.rt.clone(), + )?); + } + ret + } +} + +pub struct TrinoSourcePartition { + client: Arc, + query: CXQuery, + schema: Vec, + rt: Arc, + nrows: usize, +} + +impl TrinoSourcePartition { + #[throws(TrinoSourceError)] + pub fn new( + client: Arc, + query: CXQuery, + schema: Vec, + rt: Arc, + ) -> Self { + Self { + client, + query: query.clone(), + schema: schema.to_vec(), + rt, + nrows: 0, + } + } +} + +impl SourcePartition for TrinoSourcePartition { + type TypeSystem = TrinoTypeSystem; + type Parser<'a> = TrinoSourcePartitionParser<'a>; + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn result_rows(&mut self) { + self.nrows = get_total_rows(self.rt.clone(), self.client.clone(), &self.query)?; + } + + #[throws(TrinoSourceError)] + fn parser(&mut self) -> Self::Parser<'_> { + TrinoSourcePartitionParser::new( + self.rt.clone(), + self.client.clone(), + self.query.clone(), + &self.schema, + )? + } + + fn nrows(&self) -> usize { + self.nrows + } + + fn ncols(&self) -> usize { + self.schema.len() + } +} + +pub struct TrinoSourcePartitionParser<'a> { + rt: Arc, + client: Arc, + next_uri: Option, + rows: Vec, + ncols: usize, + current_col: usize, + current_row: usize, + _phantom: &'a PhantomData>, +} + +impl<'a> TrinoSourcePartitionParser<'a> { + #[throws(TrinoSourceError)] + pub fn new( + rt: Arc, + client: Arc, + query: CXQuery, + schema: &[TrinoTypeSystem], + ) -> Self { + let results = rt + .block_on(client.get::(query.to_string())) + .map_err(TrinoSourceError::PrustoError)?; + + let rows = match results.data_set { + Some(x) => x.into_vec(), + _ => vec![], + }; + + Self { + rt, + client, + next_uri: results.next_uri, + rows, + ncols: schema.len(), + current_row: 0, + current_col: 0, + _phantom: &PhantomData, + } + } + + #[throws(TrinoSourceError)] + fn next_loc(&mut self) -> (usize, usize) { + let ret = (self.current_row, self.current_col); + self.current_row += (self.current_col + 1) / self.ncols; + self.current_col = (self.current_col + 1) % self.ncols; + ret + } +} + +impl<'a> PartitionParser<'a> for TrinoSourcePartitionParser<'a> { + type TypeSystem = TrinoTypeSystem; + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn fetch_next(&mut self) -> (usize, bool) { + assert!(self.current_col == 0); + + match self.next_uri.clone() { + Some(uri) => { + let results = self + .rt + .block_on(self.client.get_next::(&uri)) + .map_err(TrinoSourceError::PrustoError)?; + + self.rows = match results.data_set { + Some(x) => x.into_vec(), + _ => vec![], + }; + + self.current_row = 0; + self.next_uri = results.next_uri; + + (self.rows.len(), false) + } + None => return (self.rows.len(), true), + } + } +} + +macro_rules! impl_produce_int { + ($($t: ty,)+) => { + $( + impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> $t { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Number(x) => { + if (x.is_i64()) { + <$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse i64 at position: ({}, {}) {:?}", ridx, cidx, value))? + } else { + throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)) + } + } + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)) + } + } + } + + impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> Option<$t> { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Null => None, + Value::Number(x) => { + if (x.is_i64()) { + Some(<$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse i64 at position: ({}, {}) {:?}", ridx, cidx, value))?) + } else { + throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)) + } + } + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)) + } + } + } + )+ + }; +} + +macro_rules! impl_produce_float { + ($($t: ty,)+) => { + $( + impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> $t { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Number(x) => { + if (x.is_f64()) { + x.as_f64().unwrap() as $t + } else { + throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)) + } + } + Value::String(x) => x.parse::<$t>().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}) {:?}", ridx, cidx, value))?, + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)) + } + } + } + + impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> Option<$t> { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Null => None, + Value::Number(x) => { + if (x.is_f64()) { + Some(x.as_f64().unwrap() as $t) + } else { + throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)) + } + } + Value::String(x) => Some(x.parse::<$t>().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}) {:?}", ridx, cidx, value))?), + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)) + } + } + } + )+ + }; +} + +macro_rules! impl_produce_text { + ($($t: ty,)+) => { + $( + impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> $t { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::String(x) => { + x.parse().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}): {:?}", ridx, cidx, value))? + } + _ => throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value)) + } + } + } + + impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> Option<$t> { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Null => None, + Value::String(x) => { + Some(x.parse().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}): {:?}", ridx, cidx, value))?) + } + _ => throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value)) + } + } + } + )+ + }; +} + +macro_rules! impl_produce_timestamp { + ($($t: ty,)+) => { + $( + impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> $t { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::String(x) => NaiveDateTime::parse_from_str(x, "%Y-%m-%d %H:%M:%S%.f").map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}): {:?}", ridx, cidx, value))?, + _ => throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value)) + } + } + } + + impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> Option<$t> { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Null => None, + Value::String(x) => Some(NaiveDateTime::parse_from_str(x, "%Y-%m-%d %H:%M:%S%.f").map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}): {:?}", ridx, cidx, value))?), + _ => throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value)) + } + } + } + )+ + }; +} + +macro_rules! impl_produce_bool { + ($($t: ty,)+) => { + $( + impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> $t { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Bool(x) => *x, + _ => throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value)) + } + } + } + + impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> Option<$t> { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Null => None, + Value::Bool(x) => Some(*x), + _ => throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value)) + } + } + } + )+ + }; +} + +impl_produce_bool!(bool,); +impl_produce_int!(i8, i16, i32, i64,); +impl_produce_float!(f32, f64,); +impl_produce_timestamp!(NaiveDateTime,); +impl_produce_text!(String, char,); + +impl<'r, 'a> Produce<'r, NaiveTime> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> NaiveTime { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::String(x) => NaiveTime::parse_from_str(x, "%H:%M:%S%.f").map_err(|_| { + anyhow!( + "Trino cannot parse String at position: ({}, {}): {:?}", + ridx, + cidx, + value + ) + })?, + _ => throw!(anyhow!( + "Trino unknown value at position: ({}, {}): {:?}", + ridx, + cidx, + value + )), + } + } +} + +impl<'r, 'a> Produce<'r, Option> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> Option { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Null => None, + Value::String(x) => { + Some(NaiveTime::parse_from_str(x, "%H:%M:%S%.f").map_err(|_| { + anyhow!( + "Trino cannot parse Time at position: ({}, {}): {:?}", + ridx, + cidx, + value + ) + })?) + } + _ => throw!(anyhow!( + "Trino unknown value at position: ({}, {}): {:?}", + ridx, + cidx, + value + )), + } + } +} + +impl<'r, 'a> Produce<'r, NaiveDate> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> NaiveDate { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::String(x) => NaiveDate::parse_from_str(x, "%Y-%m-%d").map_err(|_| { + anyhow!( + "Trino cannot parse Date at position: ({}, {}): {:?}", + ridx, + cidx, + value + ) + })?, + _ => throw!(anyhow!( + "Trino unknown value at position: ({}, {}): {:?}", + ridx, + cidx, + value + )), + } + } +} + +impl<'r, 'a> Produce<'r, Option> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> Option { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Null => None, + Value::String(x) => Some(NaiveDate::parse_from_str(x, "%Y-%m-%d").map_err(|_| { + anyhow!( + "Trino cannot parse Date at position: ({}, {}): {:?}", + ridx, + cidx, + value + ) + })?), + _ => throw!(anyhow!( + "Trino unknown value at position: ({}, {}): {:?}", + ridx, + cidx, + value + )), + } + } +} diff --git a/connectorx/src/sources/trino/typesystem.rs b/connectorx/src/sources/trino/typesystem.rs new file mode 100644 index 000000000..c21c8cf45 --- /dev/null +++ b/connectorx/src/sources/trino/typesystem.rs @@ -0,0 +1,109 @@ +use super::errors::TrinoSourceError; +use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; +use fehler::{throw, throws}; +use prusto::{PrestoFloat, PrestoInt, PrestoTy}; +use std::convert::TryFrom; + +// TODO: implement Tuple, Row, Array and Map +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum TrinoTypeSystem { + Date(bool), + Time(bool), + Timestamp(bool), + Boolean(bool), + Bigint(bool), + Integer(bool), + Smallint(bool), + Tinyint(bool), + Double(bool), + Real(bool), + Varchar(bool), + Char(bool), +} + +impl_typesystem! { + system = TrinoTypeSystem, + mappings = { + { Date => NaiveDate } + { Time => NaiveTime } + { Timestamp => NaiveDateTime } + { Boolean => bool } + { Bigint => i64 } + { Integer => i32 } + { Smallint => i16 } + { Tinyint => i8 } + { Double => f64 } + { Real => f32 } + { Varchar => String } + { Char => char } + } +} + +impl TryFrom for TrinoTypeSystem { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn try_from(ty: PrestoTy) -> Self { + use TrinoTypeSystem::*; + match ty { + PrestoTy::Date => Date(true), + PrestoTy::Time => Time(true), + PrestoTy::Timestamp => Timestamp(true), + PrestoTy::Boolean => Boolean(true), + PrestoTy::PrestoInt(PrestoInt::I64) => Bigint(true), + PrestoTy::PrestoInt(PrestoInt::I32) => Integer(true), + PrestoTy::PrestoInt(PrestoInt::I16) => Smallint(true), + PrestoTy::PrestoInt(PrestoInt::I8) => Tinyint(true), + PrestoTy::PrestoFloat(PrestoFloat::F64) => Double(true), + PrestoTy::PrestoFloat(PrestoFloat::F32) => Real(true), + PrestoTy::Varchar => Varchar(true), + PrestoTy::Char(_) => Char(true), + PrestoTy::Tuple(_) => Varchar(true), + PrestoTy::Row(_) => Varchar(true), + PrestoTy::Array(_) => Varchar(true), + PrestoTy::Map(_, _) => Varchar(true), + PrestoTy::Decimal(_, _) => Double(true), + PrestoTy::IpAddress => Varchar(true), + PrestoTy::Uuid => Varchar(true), + _ => throw!(TrinoSourceError::InferTypeFromNull), + } + } +} + +impl TryFrom<(Option<&str>, PrestoTy)> for TrinoTypeSystem { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn try_from(types: (Option<&str>, PrestoTy)) -> Self { + use TrinoTypeSystem::*; + match types { + (Some(decl_type), ty) => { + let decl_type = decl_type.to_lowercase(); + match decl_type.as_str() { + "date" => Date(true), + "time" => Time(true), + "timestamp" => Timestamp(true), + "boolean" => Boolean(true), + "bigint" => Bigint(true), + "int" | "integer" => Integer(true), + "smallint" => Smallint(true), + "tinyint" => Tinyint(true), + "double" => Double(true), + "real" | "float" => Real(true), + "varchar" | "varbinary" | "json" => Varchar(true), + "char" => Char(true), + "tuple" => Varchar(true), + "row" => Varchar(true), + "array" => Varchar(true), + "map" => Varchar(true), + "decimal" => Double(true), + "ipaddress" => Varchar(true), + "uuid" => Varchar(true), + _ => TrinoTypeSystem::try_from(ty)?, + } + } + // derive from value type directly if no declare type available + (None, ty) => TrinoTypeSystem::try_from(ty)?, + } + } +} diff --git a/connectorx/src/transports/mod.rs b/connectorx/src/transports/mod.rs index 8be61dc2c..96f90db44 100644 --- a/connectorx/src/transports/mod.rs +++ b/connectorx/src/transports/mod.rs @@ -44,7 +44,12 @@ mod sqlite_arrow; mod sqlite_arrow2; #[cfg(all(feature = "src_sqlite", feature = "dst_arrow"))] mod sqlite_arrowstream; - +#[cfg(all(feature = "src_trino", feature = "dst_arrow"))] +mod trino_arrow; +#[cfg(all(feature = "src_trino", feature = "dst_arrow2"))] +mod trino_arrow2; +#[cfg(all(feature = "src_trino", feature = "dst_arrow"))] +mod trino_arrowstream; #[cfg(all(feature = "src_bigquery", feature = "dst_arrow"))] pub use bigquery_arrow::{BigQueryArrowTransport, BigQueryArrowTransportError}; #[cfg(all(feature = "src_bigquery", feature = "dst_arrow2"))] @@ -105,3 +110,12 @@ pub use sqlite_arrowstream::{ SQLiteArrowTransport as SQLiteArrowStreamTransport, SQLiteArrowTransportError as SQLiteArrowStreamTransportError, }; +#[cfg(all(feature = "src_trino", feature = "dst_arrow"))] +pub use trino_arrow::{TrinoArrowTransport, TrinoArrowTransportError}; +#[cfg(all(feature = "src_trino", feature = "dst_arrow2"))] +pub use trino_arrow2::{TrinoArrow2Transport, TrinoArrow2TransportError}; +#[cfg(all(feature = "src_trino", feature = "dst_arrow"))] +pub use trino_arrowstream::{ + TrinoArrowTransport as TrinoArrowStreamTransport, + TrinoArrowTransportError as TrinoArrowStreamTransportError, +}; diff --git a/connectorx/src/transports/trino_arrow.rs b/connectorx/src/transports/trino_arrow.rs new file mode 100644 index 000000000..d498fb615 --- /dev/null +++ b/connectorx/src/transports/trino_arrow.rs @@ -0,0 +1,64 @@ +//! Transport from Trino Source to Arrow Destination. + +use crate::{ + destinations::arrow::{ + typesystem::ArrowTypeSystem, ArrowDestination, ArrowDestinationError, + }, + impl_transport, + sources::trino::{TrinoSource, TrinoSourceError, TrinoTypeSystem}, + typesystem::TypeConversion, +}; +use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; +use num_traits::ToPrimitive; +use rust_decimal::Decimal; +use serde_json::{to_string, Value}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum TrinoArrowTransportError { + #[error(transparent)] + Source(#[from] TrinoSourceError), + + #[error(transparent)] + Destination(#[from] ArrowDestinationError), + + #[error(transparent)] + ConnectorX(#[from] crate::errors::ConnectorXError), +} + +/// Convert Trino data types to Arrow data types. +pub struct TrinoArrowTransport(); + +impl_transport!( + name = TrinoArrowTransport, + error = TrinoArrowTransportError, + systems = TrinoTypeSystem => ArrowTypeSystem, + route = TrinoSource => ArrowDestination, + mappings = { + { Date[NaiveDate] => Date32[NaiveDate] | conversion auto } + { Time[NaiveTime] => Time64[NaiveTime] | conversion auto } + { Timestamp[NaiveDateTime] => Date64[NaiveDateTime] | conversion auto } + { Boolean[bool] => Boolean[bool] | conversion auto } + { Bigint[i32] => Int64[i64] | conversion auto } + { Integer[i32] => Int64[i64] | conversion none } + { Smallint[i16] => Int64[i64] | conversion auto } + { Tinyint[i8] => Int64[i64] | conversion auto } + { Double[f64] => Float64[f64] | conversion auto } + { Real[f32] => Float64[f64] | conversion auto } + { Varchar[String] => LargeUtf8[String] | conversion auto } + { Char[String] => LargeUtf8[String] | conversion none } + } +); + +impl TypeConversion for TrinoArrowTransport { + fn convert(val: Decimal) -> f64 { + val.to_f64() + .unwrap_or_else(|| panic!("cannot convert decimal {:?} to float64", val)) + } +} + +impl TypeConversion for TrinoArrowTransport { + fn convert(val: Value) -> String { + to_string(&val).unwrap() + } +} diff --git a/connectorx/src/transports/trino_arrow2.rs b/connectorx/src/transports/trino_arrow2.rs new file mode 100644 index 000000000..bc31fe646 --- /dev/null +++ b/connectorx/src/transports/trino_arrow2.rs @@ -0,0 +1,64 @@ +//! Transport from Trino Source to Arrow2 Destination. + +use crate::{ + destinations::arrow2::{ + typesystem::Arrow2TypeSystem, Arrow2Destination, Arrow2DestinationError, + }, + impl_transport, + sources::trino::{TrinoSource, TrinoSourceError, TrinoTypeSystem}, + typesystem::TypeConversion, +}; +use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; +use num_traits::ToPrimitive; +use rust_decimal::Decimal; +use serde_json::{to_string, Value}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum TrinoArrow2TransportError { + #[error(transparent)] + Source(#[from] TrinoSourceError), + + #[error(transparent)] + Destination(#[from] Arrow2DestinationError), + + #[error(transparent)] + ConnectorX(#[from] crate::errors::ConnectorXError), +} + +/// Convert Trino data types to Arrow2 data types. +pub struct TrinoArrow2Transport(); + +impl_transport!( + name = TrinoArrow2Transport, + error = TrinoArrow2TransportError, + systems = TrinoTypeSystem => Arrow2TypeSystem, + route = TrinoSource => Arrow2Destination, + mappings = { + { Date[NaiveDate] => Date32[NaiveDate] | conversion auto } + { Time[NaiveTime] => Time64[NaiveTime] | conversion auto } + { Timestamp[NaiveDateTime] => Date64[NaiveDateTime] | conversion auto } + { Boolean[bool] => Boolean[bool] | conversion auto } + { Bigint[i32] => Int64[i64] | conversion auto } + { Integer[i32] => Int64[i64] | conversion none } + { Smallint[i16] => Int64[i64] | conversion auto } + { Tinyint[i8] => Int64[i64] | conversion auto } + { Double[f64] => Float64[f64] | conversion auto } + { Real[f32] => Float64[f64] | conversion auto } + { Varchar[String] => LargeUtf8[String] | conversion auto } + { Char[String] => LargeUtf8[String] | conversion none } + } +); + +impl TypeConversion for TrinoArrow2Transport { + fn convert(val: Decimal) -> f64 { + val.to_f64() + .unwrap_or_else(|| panic!("cannot convert decimal {:?} to float64", val)) + } +} + +impl TypeConversion for TrinoArrow2Transport { + fn convert(val: Value) -> String { + to_string(&val).unwrap() + } +} diff --git a/connectorx/src/transports/trino_arrowstream.rs b/connectorx/src/transports/trino_arrowstream.rs new file mode 100644 index 000000000..f2b9e220c --- /dev/null +++ b/connectorx/src/transports/trino_arrowstream.rs @@ -0,0 +1,64 @@ +//! Transport from Trino Source to Arrow Destination. + +use crate::{ + destinations::arrowstream::{ + typesystem::ArrowTypeSystem, ArrowDestination, ArrowDestinationError, + }, + impl_transport, + sources::trino::{TrinoSource, TrinoSourceError, TrinoTypeSystem}, + typesystem::TypeConversion, +}; +use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; +use num_traits::ToPrimitive; +use rust_decimal::Decimal; +use serde_json::{to_string, Value}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum TrinoArrowTransportError { + #[error(transparent)] + Source(#[from] TrinoSourceError), + + #[error(transparent)] + Destination(#[from] ArrowDestinationError), + + #[error(transparent)] + ConnectorX(#[from] crate::errors::ConnectorXError), +} + +/// Convert Trino data types to Arrow data types. +pub struct TrinoArrowTransport(); + +impl_transport!( + name = TrinoArrowTransport, + error = TrinoArrowTransportError, + systems = TrinoTypeSystem => ArrowTypeSystem, + route = TrinoSource => ArrowDestination, + mappings = { + { Date[NaiveDate] => Date32[NaiveDate] | conversion auto } + { Time[NaiveTime] => Time64[NaiveTime] | conversion auto } + { Timestamp[NaiveDateTime] => Date64[NaiveDateTime] | conversion auto } + { Boolean[bool] => Boolean[bool] | conversion auto } + { Bigint[i32] => Int64[i64] | conversion auto } + { Integer[i32] => Int64[i64] | conversion none } + { Smallint[i16] => Int64[i64] | conversion auto } + { Tinyint[i8] => Int64[i64] | conversion auto } + { Double[f64] => Float64[f64] | conversion auto } + { Real[f32] => Float64[f64] | conversion auto } + { Varchar[String] => LargeUtf8[String] | conversion auto } + { Char[String] => LargeUtf8[String] | conversion none } + } +); + +impl TypeConversion for TrinoArrowTransport { + fn convert(val: Decimal) -> f64 { + val.to_f64() + .unwrap_or_else(|| panic!("cannot convert decimal {:?} to float64", val)) + } +} + +impl TypeConversion for TrinoArrowTransport { + fn convert(val: Value) -> String { + to_string(&val).unwrap() + } +} diff --git a/connectorx/tests/test_trino.rs b/connectorx/tests/test_trino.rs new file mode 100644 index 000000000..8aa7d5f1c --- /dev/null +++ b/connectorx/tests/test_trino.rs @@ -0,0 +1,112 @@ +use arrow::{ + array::{Float64Array, Int64Array}, + record_batch::RecordBatch, +}; +use connectorx::{ + destinations::arrow::ArrowDestination, prelude::*, sources::trino::TrinoSource, sql::CXQuery, + transports::TrinoArrowTransport, +}; +use std::{env, sync::Arc}; + +#[test] +fn test_trino() { + let _ = env_logger::builder().is_test(true).try_init(); + + let dburl = env::var("TRINO_URL").unwrap(); + + let queries = [ + CXQuery::naked("select * from test.test_table where test_int <= 2 order by test_int"), + CXQuery::naked("select * from test.test_table where test_int > 2 order by test_int"), + ]; + + let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime")); + let builder = TrinoSource::new(rt, &dburl).unwrap(); + let mut destination = ArrowDestination::new(); + let dispatcher = Dispatcher::<_, _, TrinoArrowTransport>::new( + builder, + &mut destination, + &queries, + Some(String::from( + "select * from test.test_table order by test_int", + )), + ); + dispatcher.run().unwrap(); + + let result = destination.arrow().unwrap(); + verify_arrow_results(result); +} + +#[test] +fn test_trino_text() { + let _ = env_logger::builder().is_test(true).try_init(); + + let dburl = env::var("TRINO_URL").unwrap(); + + let queries = [ + CXQuery::naked("select * from test.test_table where test_int <= 2 order by test_int"), + CXQuery::naked("select * from test.test_table where test_int > 2 order by test_int"), + ]; + + let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime")); + let builder = TrinoSource::new(rt, &dburl).unwrap(); + let mut destination = ArrowDestination::new(); + let dispatcher = + Dispatcher::<_, _, TrinoArrowTransport>::new(builder, &mut destination, &queries, None); + dispatcher.run().unwrap(); + + let result = destination.arrow().unwrap(); + verify_arrow_results(result); +} + +pub fn verify_arrow_results(result: Vec) { + assert!(result.len() == 2); + + for r in result { + match r.num_rows() { + 2 => { + assert!(r + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .eq(&Int64Array::from(vec![1, 2]))); + assert!(r + .column(1) + .as_any() + .downcast_ref::() + .unwrap() + .eq(&Float64Array::from(vec![1.1, 2.2]))); + assert!(r + .column(2) + .as_any() + .downcast_ref::() + .unwrap() + .eq(&Int64Array::from(vec![None, None]))); + } + 4 => { + assert!(r + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .eq(&Int64Array::from(vec![3, 4, 5, 6]))); + assert!(r + .column(1) + .as_any() + .downcast_ref::() + .unwrap() + .eq(&Float64Array::from(vec![3.3, 4.4, 5.5, 6.6]))); + assert!(r + .column(2) + .as_any() + .downcast_ref::() + .unwrap() + .eq(&Int64Array::from(vec![None, None, None, None]))); + } + _ => { + println!("got {} rows in a record batch!", r.num_rows()); + unreachable!() + } + } + } +} diff --git a/docs/databases.md b/docs/databases.md index 6eb371fe9..077f99f2a 100644 --- a/docs/databases.md +++ b/docs/databases.md @@ -7,4 +7,6 @@ ConnectorX supports retrieving data from Postgres, MsSQL, MySQL, Oracle, SQLite, * [MySQL](./databases/mysql.md) * [Oracle](./databases/oracle.md) * [Postgres](./databases/postgres.md) -* [SQLite](./databases/sqlite.md) \ No newline at end of file +* [SQLite](./databases/sqlite.md) +* [Trino](./databases/trino.md) + diff --git a/docs/databases/trino.md b/docs/databases/trino.md new file mode 100644 index 000000000..8ea640d38 --- /dev/null +++ b/docs/databases/trino.md @@ -0,0 +1,36 @@ +# Trino + +## Postgres Connection + +```{hint} +Using `trino+http` as connection protocol disables SSL for the connection. Example: `trino+http://host:port/catalog +Notice that basic auth requires SSL for Trino. +``` + +```py +import connectorx as cx +conn = 'trino+https://username:password@server:port/catalog' # connection token +query = "SELECT * FROM table" # query string +cx.read_sql(conn, query) # read data from Trino +``` + +## Trino-Pandas Type Mapping + +| Trino Type | Pandas Type | Comment | +| :--------: | :---------------------: | :-----: | +| BOOLEAN | bool, boolean(nullable) | | +| TINYINT | int64, Int64(nullable) | | +| SMALLINT | int64, Int64(nullable) | | +| INT | int64, Int64(nullable) | | +| BIGINT | int64, Int64(nullable) | | +| REAL | float64 | | +| DOUBLE | float64 | | +| DECIMAL | float64 | | +| VARCHAR | object | | +| CHAR | object | | +| DATE | datetime64[ns] | | +| TIME | object | | +| TIMESTAMP | datetime64[ns] | | +| UUID | object | | +| JSON | object | | +| IPADDRESS | object | | diff --git a/scripts/trino.sql b/scripts/trino.sql new file mode 100644 index 000000000..643984dec --- /dev/null +++ b/scripts/trino.sql @@ -0,0 +1,49 @@ +CREATE SCHEMA IF NOT EXISTS test; + +CREATE TABLE IF NOT EXISTS test.test_table( + test_int INTEGER, + test_float DOUBLE, + test_null INTEGER +); + +DELETE FROM test.test_table; +INSERT INTO test.test_table VALUES (1, 1.1, NULL); +INSERT INTO test.test_table VALUES (2, 2.2, NULL); +INSERT INTO test.test_table VALUES (3, 3.3, NULL); +INSERT INTO test.test_table VALUES (4, 4.4, NULL); +INSERT INTO test.test_table VALUES (5, 5.5, NULL); +INSERT INTO test.test_table VALUES (6, 6.6, NULL); + +DROP TABLE IF EXISTS test.test_table_extra; + +CREATE TABLE IF NOT EXISTS test.test_table_extra( + test_int INTEGER, + test_str VARCHAR(30) +); + +DELETE FROM test.test_table_extra; +INSERT INTO test.test_table_extra VALUES (1, 'Ha好ち😁ðy̆'); +INSERT INTO test.test_table_extra VALUES (2, 'こんにちは'); +INSERT INTO test.test_table_extra VALUES (3, 'русский'); + +DROP TABLE IF EXISTS test.test_types; + +CREATE TABLE IF NOT EXISTS test.test_types( + test_boolean BOOLEAN, + test_int INT, + test_bigint BIGINT, + test_real REAL, + test_double DOUBLE, + test_decimal DECIMAL(15,2), + test_date DATE, + test_time TIME(6), + test_timestamp TIMESTAMP(6), + test_varchar VARCHAR(15), + test_uuid UUID -- TODO: VARBINARY, ROW, ARRAY, MAP +); + +DELETE FROM test.test_types; +INSERT INTO test.test_types (test_boolean, test_int, test_bigint, test_real, test_double, test_decimal, test_date, test_time, test_timestamp, test_varchar, test_uuid) VALUES +(TRUE, 123, 1000, CAST(123.456 AS REAL), CAST(123.4567890123 AS DOUBLE), 1234567890.12, date('2023-01-01'), time '12:00:00', cast(timestamp '2023-01-01 12:00:00.123456' AS timestamp(6)), 'Sample text', CAST('f4967dbb-33e9-4242-a13a-45b56ce60dba' AS UUID)), +(FALSE, 321, 2000, CAST(123.456 AS REAL), CAST(123.4567890123 AS DOUBLE), 1234567890.12, date('2023-01-01'), time '12:00:00', cast(timestamp '2023-01-01 12:00:00.123456' AS timestamp(6)), 'Sample text', CAST('1c8b79d0-4508-4974-b728-7651bce4a5a5' AS UUID)), +(NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL);