diff --git a/poetry.lock b/poetry.lock index 0eb3b1ec4..b720978fb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5541,6 +5541,17 @@ files = [ {file = "twisted_iocpsupport-1.0.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:300437af17396a945a58dcfffd77863303a8b6d9e65c6e81f1d2eed55b50d444"}, ] +[[package]] +name = "types-psycopg2" +version = "2.9.21.20240218" +description = "Typing stubs for psycopg2" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-psycopg2-2.9.21.20240218.tar.gz", hash = "sha256:3084cd807038a62c80fb5be78b41d855b48a060316101ea59fd85c302efb57d4"}, + {file = "types_psycopg2-2.9.21.20240218-py3-none-any.whl", hash = "sha256:cac96264e063cbce28dee337a973d39e6df4ca671252343cb4f8e5ef6db5e67d"}, +] + [[package]] name = "types-pytz" version = "2023.3.1.1" @@ -6352,4 +6363,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "7f07a18a1d2b6ab04b964d6236676d60c0c0551c95f361bb9065c9b268207768" +content-hash = "ae9798483262d3ecccae313723ac6b67fc416a71be8a3a7a467b78624950e0cc" diff --git a/pyproject.toml b/pyproject.toml index de2c4f056..f6dbf3199 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ black = "^23.3.0" pypdf2 = "^3.0.1" greenlet = "<3.0.0" confluent-kafka = "^2.3.0" +types-psycopg2 = "^2.9.0" pytest-mock = "^3.12.0" twisted = "22.10.0" pytest-forked = "^1.6.0" @@ -41,6 +42,9 @@ sqlalchemy = ">=1.4" pymysql = "^1.0.3" connectorx = ">=0.3.1" +[tool.poetry.group.pg_replication.dependencies] +psycopg2-binary = ">=2.9.9" + [tool.poetry.group.google_sheets.dependencies] google-api-python-client = "^2.78.0" diff --git a/sources/.dlt/example.secrets.toml b/sources/.dlt/example.secrets.toml index 63126e597..a0e8963e0 100644 --- a/sources/.dlt/example.secrets.toml +++ b/sources/.dlt/example.secrets.toml @@ -19,4 +19,4 @@ location = "US" ## chess pipeline # the section below defines secrets for "chess_dlt_config_example" source in chess/__init__.py [sources.chess] -secret_str="secret string" # a string secret \ No newline at end of file +secret_str="secret string" # a string secret diff --git a/sources/filesystem/helpers.py b/sources/filesystem/helpers.py index 6d099599c..f241c6160 100644 --- a/sources/filesystem/helpers.py +++ b/sources/filesystem/helpers.py @@ -1,5 +1,5 @@ """Helpers for the filesystem resource.""" -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Type, Union from fsspec import AbstractFileSystem # type: ignore from dlt.common.configuration import resolve_type diff --git a/sources/pg_replication/README.md b/sources/pg_replication/README.md new file mode 100644 index 000000000..f34fcd4d6 --- /dev/null +++ b/sources/pg_replication/README.md @@ -0,0 +1,79 @@ +# Postgres replication +[Postgres](https://www.postgresql.org/) is one of the most popular relational database management systems. This verified source uses Postgres' replication functionality to efficiently process changes in tables (a process often referred to as _Change Data Capture_ or CDC). It uses [logical decoding](https://www.postgresql.org/docs/current/logicaldecoding.html) and the standard built-in `pgoutput` [output plugin](https://www.postgresql.org/docs/current/logicaldecoding-output-plugin.html). 
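+
+Logical decoding requires the source Postgres server to run with `wal_level = logical` (changing this setting requires a server restart). A quick way to verify, assuming you can connect with `psql`:
+
+```sql
+SHOW wal_level;                           -- should return 'logical'
+ALTER SYSTEM SET wal_level = 'logical';   -- if it does not; restart the server afterwards
+```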
+ +Resources that can be loaded using this verified source are: + +| Name | Description | +|----------------------|-------------------------------------------------| +| replication_resource | Load published messages from a replication slot | + +## Initialize the pipeline + +```bash +dlt init pg_replication duckdb +``` + +This uses `duckdb` as destination, but you can choose any of the supported [destinations](https://dlthub.com/docs/dlt-ecosystem/destinations/). + +## Add `sql_database` source + +```bash +dlt init sql_database duckdb +``` + +This source depends on the [sql_database](../sql_database/README.md) verified source internally to perform initial loads. This step can be skipped if you don't do initial loads. +## Set up user + +The Postgres user needs to have the `LOGIN` and `REPLICATION` attributes assigned: + +```sql +CREATE ROLE replication_user WITH LOGIN REPLICATION; +``` + +It also needs `CREATE` privilege on the database: + +```sql +GRANT CREATE ON DATABASE dlt_data TO replication_user; +``` + +### Set up RDS +1. You must enable replication for RDS Postgres instance via **Parameter Group**: https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_PostgreSQL.Replication.ReadReplicas.html +2. `WITH LOGIN REPLICATION;` does not work on RDS, instead do: +```sql +GRANT rds_replication TO replication_user; +``` +3. Do not fallback to non SSL connection by setting connection parameters: +```toml +sources.pg_replication.credentials="postgresql://loader:password@host.rds.amazonaws.com:5432/dlt_data?sslmode=require&connect_timeout=300" +``` + + +## Add credentials +1. Open `.dlt/secrets.toml`. +2. Enter your Postgres credentials: + + ```toml + [sources.pg_replication] + credentials="postgresql://replication_user:<>@localhost:5432/dlt_data" + ``` +3. Enter credentials for your chosen destination as per the [docs](https://dlthub.com/docs/dlt-ecosystem/destinations/). + +## Run the pipeline + +1. Install the necessary dependencies by running the following command: + + ```bash + pip install -r requirements.txt + ``` + +1. Now the pipeline can be run by using the command: + + ```bash + python pg_replication_pipeline.py + ``` + +1. To make sure that everything is loaded as expected, use the command: + + ```bash + dlt pipeline pg_replication_pipeline show + ``` \ No newline at end of file diff --git a/sources/pg_replication/__init__.py b/sources/pg_replication/__init__.py new file mode 100644 index 000000000..9ec0e9b7b --- /dev/null +++ b/sources/pg_replication/__init__.py @@ -0,0 +1,103 @@ +"""Replicates postgres tables in batch using logical decoding.""" + +from typing import Dict, Sequence, Optional, Iterable, Union + +import dlt + +from dlt.common.typing import TDataItem +from dlt.common.schema.typing import TTableSchemaColumns +from dlt.extract.items import DataItemWithMeta +from dlt.sources.credentials import ConnectionStringCredentials + +from .helpers import advance_slot, get_max_lsn, ItemGenerator + + +@dlt.resource( + name=lambda args: args["slot_name"] + "_" + args["pub_name"], + standalone=True, +) +def replication_resource( + slot_name: str, + pub_name: str, + credentials: ConnectionStringCredentials = dlt.secrets.value, + include_columns: Optional[Dict[str, Sequence[str]]] = None, + columns: Optional[Dict[str, TTableSchemaColumns]] = None, + target_batch_size: int = 1000, + flush_slot: bool = True, +) -> Iterable[Union[TDataItem, DataItemWithMeta]]: + """Resource yielding data items for changes in one or more postgres tables. 
+ + - Relies on a replication slot and publication that publishes DML operations + (i.e. `insert`, `update`, and/or `delete`). Helper `init_replication` can be + used to set this up. + - Maintains LSN of last consumed message in state to track progress. + - At start of the run, advances the slot upto last consumed message in previous run. + - Processes in batches to limit memory usage. + + Args: + slot_name (str): Name of the replication slot to consume replication messages from. + pub_name (str): Name of the publication that publishes DML operations for the table(s). + credentials (ConnectionStringCredentials): Postgres database credentials. + include_columns (Optional[Dict[str, Sequence[str]]]): Maps table name(s) to + sequence of names of columns to include in the generated data items. + Any column not in the sequence is excluded. If not provided, all columns + are included. For example: + ``` + include_columns={ + "table_x": ["col_a", "col_c"], + "table_y": ["col_x", "col_y", "col_z"], + } + ``` + columns (Optional[Dict[str, TTableHintTemplate[TAnySchemaColumns]]]): Maps + table name(s) to column hints to apply on the replicated table(s). For example: + ``` + columns={ + "table_x": {"col_a": {"data_type": "complex"}}, + "table_y": {"col_y": {"precision": 32}}, + } + ``` + target_batch_size (int): Desired number of data items yielded in a batch. + Can be used to limit the data items in memory. Note that the number of + data items yielded can be (far) greater than `target_batch_size`, because + all messages belonging to the same transaction are always processed in + the same batch, regardless of the number of messages in the transaction + and regardless of the value of `target_batch_size`. The number of data + items can also be smaller than `target_batch_size` when the replication + slot is exhausted before a batch is full. + flush_slot (bool): Whether processed messages are discarded from the replication + slot. Recommended value is True. Be careful when setting False—not flushing + can eventually lead to a “disk full” condition on the server, because + the server retains all the WAL segments that might be needed to stream + the changes via all of the currently open replication slots. + + Yields: + Data items for changes published in the publication. + """ + # start where we left off in previous run + start_lsn = dlt.current.resource_state().get("last_commit_lsn", 0) + if flush_slot: + advance_slot(start_lsn, slot_name, credentials) + + # continue until last message in replication slot + options = {"publication_names": pub_name, "proto_version": "1"} + upto_lsn = get_max_lsn(slot_name, options, credentials) + if upto_lsn is None: + return "Replication slot is empty." 
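+        # note: returning a value from a generator only stops iteration;
+        # the string above is purely informational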
+ + # generate items in batches + while True: + gen = ItemGenerator( + credentials=credentials, + slot_name=slot_name, + options=options, + upto_lsn=upto_lsn, + start_lsn=start_lsn, + target_batch_size=target_batch_size, + include_columns=include_columns, + columns=columns, + ) + yield from gen + if gen.generated_all: + dlt.current.resource_state()["last_commit_lsn"] = gen.last_commit_lsn + break + start_lsn = gen.last_commit_lsn diff --git a/sources/pg_replication/decoders.py b/sources/pg_replication/decoders.py new file mode 100644 index 000000000..c2707b46a --- /dev/null +++ b/sources/pg_replication/decoders.py @@ -0,0 +1,427 @@ +# flake8: noqa +# file copied from https://raw.githubusercontent.com/dgea005/pypgoutput/master/src/pypgoutput/decoders.py +# we do this instead of importing `pypgoutput` because it depends on `psycopg2`, which causes errors when installing on macOS + +import io +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import List, Optional, Union + +# integer byte lengths +INT8 = 1 +INT16 = 2 +INT32 = 4 +INT64 = 8 + + +def convert_pg_ts(_ts_in_microseconds: int) -> datetime: + ts = datetime(2000, 1, 1, 0, 0, 0, 0, tzinfo=timezone.utc) + return ts + timedelta(microseconds=_ts_in_microseconds) + + +def convert_bytes_to_int(_in_bytes: bytes) -> int: + return int.from_bytes(_in_bytes, byteorder="big", signed=True) + + +def convert_bytes_to_utf8(_in_bytes: Union[bytes, bytearray]) -> str: + return (_in_bytes).decode("utf-8") + + +@dataclass(frozen=True) +class ColumnData: + # col_data_category is NOT the type. it means null value/toasted(not sent)/text formatted + col_data_category: Optional[str] + col_data_length: Optional[int] = None + col_data: Optional[str] = None + + def __repr__(self) -> str: + return f"[col_data_category='{self.col_data_category}', col_data_length={self.col_data_length}, col_data='{self.col_data}']" + + +@dataclass(frozen=True) +class ColumnType: + """https://www.postgresql.org/docs/12/catalog-pg-attribute.html""" + + part_of_pkey: int + name: str + type_id: int + atttypmod: int + + +@dataclass(frozen=True) +class TupleData: + n_columns: int + column_data: List[ColumnData] + + def __repr__(self) -> str: + return f"n_columns: {self.n_columns}, data: {self.column_data}" + + +class PgoutputMessage(ABC): + def __init__(self, buffer: bytes): + self.buffer: io.BytesIO = io.BytesIO(buffer) + self.byte1: str = self.read_utf8(1) + self.decode_buffer() + + @abstractmethod + def decode_buffer(self) -> None: + """Decoding is implemented for each message type""" + + @abstractmethod + def __repr__(self) -> str: + """Implemented for each message type""" + + def read_int8(self) -> int: + return convert_bytes_to_int(self.buffer.read(INT8)) + + def read_int16(self) -> int: + return convert_bytes_to_int(self.buffer.read(INT16)) + + def read_int32(self) -> int: + return convert_bytes_to_int(self.buffer.read(INT32)) + + def read_int64(self) -> int: + return convert_bytes_to_int(self.buffer.read(INT64)) + + def read_utf8(self, n: int = 1) -> str: + return convert_bytes_to_utf8(self.buffer.read(n)) + + def read_timestamp(self) -> datetime: + # 8 chars -> int64 -> timestamp + return convert_pg_ts(_ts_in_microseconds=self.read_int64()) + + def read_string(self) -> str: + output = bytearray() + while (next_char := self.buffer.read(1)) != b"\x00": + output += next_char + return convert_bytes_to_utf8(output) + + def read_tuple_data(self) -> TupleData: + """ + TupleData + Int16 Number of 
columns. + Next, one of the following submessages appears for each column (except generated columns): + Byte1('n') Identifies the data as NULL value. + Or + Byte1('u') Identifies unchanged TOASTed value (the actual value is not sent). + Or + Byte1('t') Identifies the data as text formatted value. + Int32 Length of the column value. + Byten The value of the column, in text format. (A future release might support additional formats.) n is the above length. + """ + # TODO: investigate what happens with the generated columns + column_data = list() + n_columns = self.read_int16() + for column in range(n_columns): + col_data_category = self.read_utf8() + if col_data_category in ("n", "u"): + # "n"=NULL, "t"=TOASTed + column_data.append(ColumnData(col_data_category=col_data_category)) + elif col_data_category == "t": + # t = tuple + col_data_length = self.read_int32() + col_data = self.read_utf8(col_data_length) + column_data.append( + ColumnData( + col_data_category=col_data_category, + col_data_length=col_data_length, + col_data=col_data, + ) + ) + return TupleData(n_columns=n_columns, column_data=column_data) + + +class Begin(PgoutputMessage): + """ + https://pgpedia.info/x/xlogrecptr.html + https://www.postgresql.org/docs/14/datatype-pg-lsn.html + + byte1 Byte1('B') Identifies the message as a begin message. + lsn Int64 The final LSN of the transaction. + commit_tx_ts Int64 Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01). + tx_xid Int32 Xid of the transaction. + """ + + byte1: str + lsn: int + commit_ts: datetime + tx_xid: int + + def decode_buffer(self) -> None: + if self.byte1 != "B": + raise ValueError("first byte in buffer does not match Begin message") + self.lsn = self.read_int64() + self.commit_ts = self.read_timestamp() + self.tx_xid = self.read_int64() + + def __repr__(self) -> str: + return ( + f"BEGIN \n\tbyte1: '{self.byte1}', \n\tLSN: {self.lsn}, " + f"\n\tcommit_ts {self.commit_ts}, \n\ttx_xid: {self.tx_xid}" + ) + + +class Commit(PgoutputMessage): + """ + byte1: Byte1('C') Identifies the message as a commit message. + flags: Int8 Flags; currently unused (must be 0). + lsn_commit: Int64 The LSN of the commit. + lsn: Int64 The end LSN of the transaction. + Int64 Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01). + """ + + byte1: str + flags: int + lsn_commit: int + lsn: int + commit_ts: datetime + + def decode_buffer(self) -> None: + if self.byte1 != "C": + raise ValueError("first byte in buffer does not match Commit message") + self.flags = self.read_int8() + self.lsn_commit = self.read_int64() + self.lsn = self.read_int64() + self.commit_ts = self.read_timestamp() + + def __repr__(self) -> str: + return ( + f"COMMIT \n\tbyte1: {self.byte1}, \n\tflags {self.flags}, \n\tlsn_commit: {self.lsn_commit}" + f"\n\tLSN: {self.lsn}, \n\tcommit_ts {self.commit_ts}" + ) + + +class Origin: + """ + Byte1('O') Identifies the message as an origin message. + Int64 The LSN of the commit on the origin server. + String Name of the origin. + Note that there can be multiple Origin messages inside a single transaction. + This seems to be what origin means: https://www.postgresql.org/docs/12/replication-origins.html + """ + + pass + + +class Relation(PgoutputMessage): + """ + Byte1('R') Identifies the message as a relation message. + Int32 ID of the relation. + String Namespace (empty string for pg_catalog). + String Relation name. 
+ Int8 Replica identity setting for the relation (same as relreplident in pg_class). + # select relreplident from pg_class where relname = 'test_table'; + # from reading the documentation and looking at the tables this is not int8 but a single character + # background: https://www.postgresql.org/docs/10/sql-altertable.html#SQL-CREATETABLE-REPLICA-IDENTITY + Int16 Number of columns. + Next, the following message part appears for each column (except generated columns): + Int8 Flags for the column. Currently can be either 0 for no flags or 1 which marks the column as part of the key. + String Name of the column. + Int32 ID of the column's data type. + Int32 Type modifier of the column (atttypmod). + """ + + byte1: str + relation_id: int + namespace: str + relation_name: str + replica_identity_setting: str + n_columns: int + columns: List[ColumnType] + + def decode_buffer(self) -> None: + if self.byte1 != "R": + raise ValueError("first byte in buffer does not match Relation message") + self.relation_id = self.read_int32() + self.namespace = self.read_string() + self.relation_name = self.read_string() + self.replica_identity_setting = self.read_utf8() + self.n_columns = self.read_int16() + self.columns = list() + + for column in range(self.n_columns): + part_of_pkey = self.read_int8() + col_name = self.read_string() + data_type_id = self.read_int32() + # TODO: check on use of signed / unsigned + # check with select oid from pg_type where typname = ; timestamp == 1184, int4 = 23 + col_modifier = self.read_int32() + self.columns.append( + ColumnType( + part_of_pkey=part_of_pkey, + name=col_name, + type_id=data_type_id, + atttypmod=col_modifier, + ) + ) + + def __repr__(self) -> str: + return ( + f"RELATION \n\tbyte1: '{self.byte1}', \n\trelation_id: {self.relation_id}" + f",\n\tnamespace/schema: '{self.namespace}',\n\trelation_name: '{self.relation_name}'" + f",\n\treplica_identity_setting: '{self.replica_identity_setting}',\n\tn_columns: {self.n_columns} " + f",\n\tcolumns: {self.columns}" + ) + + +class PgType: + """ + Renamed to PgType not to collide with "type" + + Byte1('Y') Identifies the message as a type message. + Int32 ID of the data type. + String Namespace (empty string for pg_catalog). + String Name of the data type. + """ + + pass + + +class Insert(PgoutputMessage): + """ + Byte1('I') Identifies the message as an insert message. + Int32 ID of the relation corresponding to the ID in the relation message. + Byte1('N') Identifies the following TupleData message as a new tuple. + TupleData TupleData message part representing the contents of new tuple. + """ + + byte1: str + relation_id: int + new_tuple_byte: str + new_tuple: TupleData + + def decode_buffer(self) -> None: + if self.byte1 != "I": + raise ValueError( + f"first byte in buffer does not match Insert message (expected 'I', got '{self.byte1}'" + ) + self.relation_id = self.read_int32() + self.new_tuple_byte = self.read_utf8() + self.new_tuple = self.read_tuple_data() + + def __repr__(self) -> str: + return ( + f"INSERT \n\tbyte1: '{self.byte1}', \n\trelation_id: {self.relation_id} " + f"\n\tnew tuple byte: '{self.new_tuple_byte}', \n\tnew_tuple: {self.new_tuple}" + ) + + +class Update(PgoutputMessage): + """ + Byte1('U') Identifies the message as an update message. + Int32 ID of the relation corresponding to the ID in the relation message. + Byte1('K') Identifies the following TupleData submessage as a key. 
This field is optional and is only present if the update changed data in any of the column(s) that are part of the REPLICA IDENTITY index. + Byte1('O') Identifies the following TupleData submessage as an old tuple. This field is optional and is only present if table in which the update happened has REPLICA IDENTITY set to FULL. + TupleData TupleData message part representing the contents of the old tuple or primary key. Only present if the previous 'O' or 'K' part is present. + Byte1('N') Identifies the following TupleData message as a new tuple. + TupleData TupleData message part representing the contents of a new tuple. + + The Update message may contain either a 'K' message part or an 'O' message part or neither of them, but never both of them. + """ + + byte1: str + relation_id: int + next_byte_identifier: Optional[str] + optional_tuple_identifier: Optional[str] + old_tuple: Optional[TupleData] + new_tuple_byte: str + new_tuple: TupleData + + def decode_buffer(self) -> None: + self.optional_tuple_identifier = None + self.old_tuple = None + if self.byte1 != "U": + raise ValueError( + f"first byte in buffer does not match Update message (expected 'U', got '{self.byte1}'" + ) + self.relation_id = self.read_int32() + # TODO test update to PK, test update with REPLICA IDENTITY = FULL + self.next_byte_identifier = self.read_utf8() # one of K, O or N + if self.next_byte_identifier == "K" or self.next_byte_identifier == "O": + self.optional_tuple_identifier = self.next_byte_identifier + self.old_tuple = self.read_tuple_data() + self.new_tuple_byte = self.read_utf8() + else: + self.new_tuple_byte = self.next_byte_identifier + if self.new_tuple_byte != "N": + # TODO: test exception handling + raise ValueError( + f"did not find new_tuple_byte ('N') at position: {self.buffer.tell()}, found: '{self.new_tuple_byte}'" + ) + self.new_tuple = self.read_tuple_data() + + def __repr__(self) -> str: + return ( + f"UPDATE \n\tbyte1: '{self.byte1}', \n\trelation_id: {self.relation_id}" + f"\n\toptional_tuple_identifier: '{self.optional_tuple_identifier}', \n\toptional_old_tuple_data: {self.old_tuple}" + f"\n\tnew_tuple_byte: '{self.new_tuple_byte}', \n\tnew_tuple: {self.new_tuple}" + ) + + +class Delete(PgoutputMessage): + """ + Byte1('D') Identifies the message as a delete message. + Int32 ID of the relation corresponding to the ID in the relation message. + Byte1('K') Identifies the following TupleData submessage as a key. This field is present if the table in which the delete has happened uses an index as REPLICA IDENTITY. + Byte1('O') Identifies the following TupleData message as a old tuple. This field is present if the table in which the delete has happened has REPLICA IDENTITY set to FULL. + TupleData TupleData message part representing the contents of the old tuple or primary key, depending on the previous field. + + The Delete message may contain either a 'K' message part or an 'O' message part, but never both of them. 
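+
+    With the default REPLICA IDENTITY (primary key), the key tuple carries values
+    only for the key column(s); the remaining columns arrive as NULLs ('n' category).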
+ """ + + byte1: str + relation_id: int + message_type: str + old_tuple: TupleData + + def decode_buffer(self) -> None: + if self.byte1 != "D": + raise ValueError( + f"first byte in buffer does not match Delete message (expected 'D', got '{self.byte1}'" + ) + self.relation_id = self.read_int32() + self.message_type = self.read_utf8() + # TODO: test with replica identity full + if self.message_type not in ["K", "O"]: + raise ValueError( + f"message type byte is not 'K' or 'O', got: '{self.message_type}'" + ) + self.old_tuple = self.read_tuple_data() + + def __repr__(self) -> str: + return ( + f"DELETE \n\tbyte1: {self.byte1} \n\trelation_id: {self.relation_id} " + f"\n\tmessage_type: {self.message_type} \n\told_tuple: {self.old_tuple}" + ) + + +class Truncate(PgoutputMessage): + """ + Byte1('T') Identifies the message as a truncate message. + Int32 Number of relations + Int8 Option bits for TRUNCATE: 1 for CASCADE, 2 for RESTART IDENTITY + Int32 ID of the relation corresponding to the ID in the relation message. This field is repeated for each relation. + """ + + byte1: str + number_of_relations: int + option_bits: int + relation_ids: List[int] + + def decode_buffer(self) -> None: + if self.byte1 != "T": + raise ValueError( + f"first byte in buffer does not match Truncate message (expected 'T', got '{self.byte1}'" + ) + self.number_of_relations = self.read_int32() + self.option_bits = self.read_int8() + self.relation_ids = [] + for relation in range(self.number_of_relations): + self.relation_ids.append(self.read_int32()) + + def __repr__(self) -> str: + return ( + f"TRUNCATE \n\tbyte1: {self.byte1} \n\tn_relations: {self.number_of_relations} " + f"option_bits: {self.option_bits}, relation_ids: {self.relation_ids}" + ) diff --git a/sources/pg_replication/exceptions.py b/sources/pg_replication/exceptions.py new file mode 100644 index 000000000..2b2642777 --- /dev/null +++ b/sources/pg_replication/exceptions.py @@ -0,0 +1,6 @@ +class NoPrimaryKeyException(Exception): + pass + + +class IncompatiblePostgresVersionException(Exception): + pass diff --git a/sources/pg_replication/helpers.py b/sources/pg_replication/helpers.py new file mode 100644 index 000000000..112c0b1c6 --- /dev/null +++ b/sources/pg_replication/helpers.py @@ -0,0 +1,769 @@ +from typing import ( + Optional, + Dict, + Iterator, + Union, + List, + Sequence, + Any, +) +from dataclasses import dataclass, field + +import psycopg2 +from psycopg2.extensions import cursor +from psycopg2.extras import ( + LogicalReplicationConnection, + ReplicationCursor, + ReplicationMessage, + StopReplication, +) + +import dlt + +from dlt.common import logger +from dlt.common.typing import TDataItem +from dlt.common.pendulum import pendulum +from dlt.common.schema.typing import ( + TTableSchema, + TTableSchemaColumns, + TColumnNames, + TWriteDisposition, +) +from dlt.common.schema.utils import merge_column +from dlt.common.data_writers.escape import escape_postgres_identifier +from dlt.extract.items import DataItemWithMeta +from dlt.extract.resource import DltResource +from dlt.sources.credentials import ConnectionStringCredentials + +try: + from ..sql_database import sql_table # type: ignore[import-untyped] +except Exception: + from sql_database import sql_table + +from .schema_types import _to_dlt_column_schema, _to_dlt_val +from .exceptions import IncompatiblePostgresVersionException +from .decoders import ( + Begin, + Relation, + Insert, + Update, + Delete, + ColumnData, +) + + +@dlt.sources.config.with_config(sections=("sources", 
"pg_replication")) +def init_replication( + slot_name: str, + pub_name: str, + schema_name: str, + table_names: Optional[Union[str, Sequence[str]]] = None, + credentials: ConnectionStringCredentials = dlt.secrets.value, + publish: str = "insert, update, delete", + persist_snapshots: bool = False, + include_columns: Optional[Dict[str, Sequence[str]]] = None, + columns: Optional[Dict[str, TTableSchemaColumns]] = None, + reset: bool = False, +) -> Optional[Union[DltResource, List[DltResource]]]: + """Initializes replication for one, several, or all tables within a schema. + + Can be called repeatedly with the same `slot_name` and `pub_name`: + - creates a replication slot and publication with provided names if they do not exist yet + - skips creation of slot and publication if they already exist (unless`reset` is set to `False`) + - supports addition of new tables by extending `table_names` + - removing tables is not supported, i.e. exluding a table from `table_names` + will not remove it from the publication + - switching from a table selection to an entire schema is possible by omitting + the `table_names` argument + - changing `publish` has no effect (altering the published DML operations is not supported) + - table snapshots can only be persisted on the first call (because the snapshot + is exported when the slot is created) + + Args: + slot_name (str): Name of the replication slot to create if it does not exist yet. + pub_name (str): Name of the publication to create if it does not exist yet. + schema_name (str): Name of the schema to replicate tables from. + table_names (Optional[Union[str, Sequence[str]]]): Name(s) of the table(s) + to include in the publication. If not provided, all tables in the schema + are included (also tables added to the schema after the publication was created). + credentials (ConnectionStringCredentials): Postgres database credentials. + publish (str): Comma-separated string of DML operations. Can be used to + control which changes are included in the publication. Allowed operations + are `insert`, `update`, and `delete`. `truncate` is currently not + supported—messages of that type are ignored. + E.g. `publish="insert"` will create a publication that only publishes insert operations. + persist_snapshots (bool): Whether the table states in the snapshot exported + during replication slot creation are persisted to tables. If true, a + snapshot table is created in Postgres for all included tables, and corresponding + resources (`DltResource` objects) for these tables are created and returned. + The resources can be used to perform an initial load of all data present + in the tables at the moment the replication slot got created. + include_columns (Optional[Dict[str, Sequence[str]]]): Maps table name(s) to + sequence of names of columns to include in the snapshot table(s). + Any column not in the sequence is excluded. If not provided, all columns + are included. For example: + ``` + include_columns={ + "table_x": ["col_a", "col_c"], + "table_y": ["col_x", "col_y", "col_z"], + } + ``` + Argument is only used if `persist_snapshots` is `True`. + columns (Optional[Dict[str, TTableSchemaColumns]]): Maps + table name(s) to column hints to apply on the snapshot table resource(s). + For example: + ``` + columns={ + "table_x": {"col_a": {"data_type": "complex"}}, + "table_y": {"col_y": {"precision": 32}}, + } + ``` + Argument is only used if `persist_snapshots` is `True`. + reset (bool): If set to True, the existing slot and publication are dropped + and recreated. 
Has no effect if a slot and publication with the provided + names do not yet exist. + + Returns: + - None if `persist_snapshots` is `False` + - a `DltResource` object or a list of `DltResource` objects for the snapshot + table(s) if `persist_snapshots` is `True` and the replication slot did not yet exist + """ + if isinstance(table_names, str): + table_names = [table_names] + cur = _get_rep_conn(credentials).cursor() + if reset: + drop_replication_slot(slot_name, cur) + drop_publication(pub_name, cur) + create_publication(pub_name, cur, publish) + if table_names is None: + add_schema_to_publication(schema_name, pub_name, cur) + else: + add_tables_to_publication(table_names, schema_name, pub_name, cur) + slot = create_replication_slot(slot_name, cur) + if persist_snapshots: + if slot is None: + logger.info( + "Cannot persist snapshots because they do not exist. " + f'The replication slot "{slot_name}" already existed prior to calling this function.' + ) + else: + # need separate session to read the snapshot: https://stackoverflow.com/q/75852587 + cur_snap = _get_conn(credentials).cursor() + snapshot_table_names = [ + persist_snapshot_table( + snapshot_name=slot["snapshot_name"], + table_name=table_name, + schema_name=schema_name, + cur=cur_snap, + include_columns=None + if include_columns is None + else include_columns.get(table_name), + ) + for table_name in table_names + ] + snapshot_table_resources = [ + snapshot_table_resource( + snapshot_table_name=snapshot_table_name, + schema_name=schema_name, + table_name=table_name, + write_disposition="append" if publish == "insert" else "merge", + columns=None if columns is None else columns.get(table_name), + credentials=credentials, + ) + for table_name, snapshot_table_name in zip( + table_names, snapshot_table_names + ) + ] + if len(snapshot_table_resources) == 1: + return snapshot_table_resources[0] + return snapshot_table_resources + return None + + +@dlt.sources.config.with_config(sections=("sources", "pg_replication")) +def get_pg_version( + cur: cursor = None, + credentials: ConnectionStringCredentials = dlt.secrets.value, +) -> int: + """Returns Postgres server version as int.""" + if cur is not None: + return cur.connection.server_version + return _get_conn(credentials).server_version + + +def create_publication( + name: str, + cur: cursor, + publish: str = "insert, update, delete", +) -> None: + """Creates a publication for logical replication if it doesn't exist yet. + + Does nothing if the publication already exists. + Raises error if the user does not have the CREATE privilege for the database. + """ + esc_name = escape_postgres_identifier(name) + try: + cur.execute(f"CREATE PUBLICATION {esc_name} WITH (publish = '{publish}');") + logger.info( + f"Successfully created publication {esc_name} with publish = '{publish}'." + ) + except psycopg2.errors.DuplicateObject: # the publication already exists + logger.info(f'Publication "{name}" already exists.') + + +def add_table_to_publication( + table_name: str, + schema_name: str, + pub_name: str, + cur: cursor, +) -> None: + """Adds a table to a publication for logical replication. + + Does nothing if the table is already a member of the publication. + Raises error if the user is not owner of the table. 
+ """ + qual_name = _make_qualified_table_name(table_name, schema_name) + esc_pub_name = escape_postgres_identifier(pub_name) + try: + cur.execute(f"ALTER PUBLICATION {esc_pub_name} ADD TABLE {qual_name};") + logger.info( + f"Successfully added table {qual_name} to publication {esc_pub_name}." + ) + except psycopg2.errors.DuplicateObject: + logger.info( + f"Table {qual_name} is already a member of publication {esc_pub_name}." + ) + + +def add_tables_to_publication( + table_names: Union[str, Sequence[str]], + schema_name: str, + pub_name: str, + cur: cursor, +) -> None: + """Adds one or multiple tables to a publication for logical replication. + + Calls `add_table_to_publication` for each table in `table_names`. + """ + if isinstance(table_names, str): + table_names = table_names + for table_name in table_names: + add_table_to_publication(table_name, schema_name, pub_name, cur) + + +def add_schema_to_publication( + schema_name: str, + pub_name: str, + cur: cursor, +) -> None: + """Adds a schema to a publication for logical replication if the schema is not a member yet. + + Raises error if the user is not a superuser. + """ + if (version := get_pg_version(cur)) < 150000: + raise IncompatiblePostgresVersionException( + f"Cannot add schema to publication because the Postgres server version {version} is too low." + " Adding schemas to a publication is only supported for Postgres version 15 or higher." + " Upgrade your Postgres server version or set the `table_names` argument to explicitly specify table names." + ) + esc_schema_name = escape_postgres_identifier(schema_name) + esc_pub_name = escape_postgres_identifier(pub_name) + try: + cur.execute( + f"ALTER PUBLICATION {esc_pub_name} ADD TABLES IN SCHEMA {esc_schema_name};" + ) + logger.info( + f"Successfully added schema {esc_schema_name} to publication {esc_pub_name}." + ) + except psycopg2.errors.DuplicateObject: + logger.info( + f"Schema {esc_schema_name} is already a member of publication {esc_pub_name}." + ) + + +def create_replication_slot( # type: ignore[return] + name: str, cur: ReplicationCursor, output_plugin: str = "pgoutput" +) -> Optional[Dict[str, str]]: + """Creates a replication slot if it doesn't exist yet.""" + try: + cur.create_replication_slot(name, output_plugin=output_plugin) + logger.info(f'Successfully created replication slot "{name}".') + result = cur.fetchone() + return { + "slot_name": result[0], + "consistent_point": result[1], + "snapshot_name": result[2], + "output_plugin": result[3], + } + except psycopg2.errors.DuplicateObject: # the replication slot already exists + logger.info( + f'Replication slot "{name}" cannot be created because it already exists.' + ) + + +def drop_replication_slot(name: str, cur: ReplicationCursor) -> None: + """Drops a replication slot if it exists.""" + try: + cur.drop_replication_slot(name) + logger.info(f'Successfully dropped replication slot "{name}".') + except psycopg2.errors.UndefinedObject: # the replication slot does not exist + logger.info( + f'Replication slot "{name}" cannot be dropped because it does not exist.' + ) + + +def drop_publication(name: str, cur: ReplicationCursor) -> None: + """Drops a publication if it exists.""" + esc_name = escape_postgres_identifier(name) + try: + cur.execute(f"DROP PUBLICATION {esc_name};") + cur.connection.commit() + logger.info(f"Successfully dropped publication {esc_name}.") + except psycopg2.errors.UndefinedObject: # the publication does not exist + logger.info( + f"Publication {esc_name} cannot be dropped because it does not exist." 
+ ) + + +def persist_snapshot_table( + snapshot_name: str, + table_name: str, + schema_name: str, + cur: cursor, + include_columns: Optional[Sequence[str]] = None, +) -> str: + """Persists exported snapshot table state. + + Reads snapshot table content and copies it into new table. + """ + col_str = "*" + if include_columns is not None: + col_str = ", ".join(map(escape_postgres_identifier, include_columns)) + snapshot_table_name = f"{table_name}_snapshot_{snapshot_name}" + snapshot_qual_name = _make_qualified_table_name(snapshot_table_name, schema_name) + qual_name = _make_qualified_table_name(table_name, schema_name) + cur.execute( + f""" + START TRANSACTION ISOLATION LEVEL REPEATABLE READ; + SET TRANSACTION SNAPSHOT '{snapshot_name}'; + CREATE TABLE {snapshot_qual_name} AS SELECT {col_str} FROM {qual_name}; + """ + ) + cur.connection.commit() + logger.info(f"Successfully persisted snapshot table state in {snapshot_qual_name}.") + return snapshot_table_name + + +def snapshot_table_resource( + snapshot_table_name: str, + schema_name: str, + table_name: str, + write_disposition: TWriteDisposition, + columns: TTableSchemaColumns = None, + credentials: ConnectionStringCredentials = dlt.secrets.value, +) -> DltResource: + """Returns a resource for a persisted snapshot table. + + Can be used to perform an initial load of the table, so all data that + existed in the table prior to initializing replication is also captured. + """ + resource: DltResource = sql_table( + credentials=credentials, + table=snapshot_table_name, + schema=schema_name, + detect_precision_hints=True, + ) + primary_key = _get_pk(table_name, schema_name, credentials) + resource.apply_hints( + table_name=table_name, + write_disposition=write_disposition, + columns=columns, + primary_key=primary_key, + ) + return resource + + +def get_max_lsn( + slot_name: str, + options: Dict[str, str], + credentials: ConnectionStringCredentials, +) -> Optional[int]: + """Returns maximum Log Sequence Number (LSN) in replication slot. + + Returns None if the replication slot is empty. + Does not consume the slot, i.e. messages are not flushed. + Raises error if the replication slot or publication does not exist. + """ + # comma-separated value string + options_str = ", ".join( + f"'{x}'" for xs in list(map(list, options.items())) for x in xs # type: ignore[arg-type] + ) + cur = _get_conn(credentials).cursor() + cur.execute( + "SELECT MAX(lsn) - '0/0' AS max_lsn " # subtract '0/0' to convert pg_lsn type to int (https://stackoverflow.com/a/73738472) + f"FROM pg_logical_slot_peek_binary_changes('{slot_name}', NULL, NULL, {options_str});" + ) + lsn: int = cur.fetchone()[0] + cur.connection.close() + return lsn + + +def get_pub_ops( + pub_name: str, + credentials: ConnectionStringCredentials, +) -> Dict[str, bool]: + """Returns dictionary of DML operations and their publish status.""" + cur = _get_conn(credentials).cursor() + cur.execute( + f""" + SELECT pubinsert, pubupdate, pubdelete, pubtruncate + FROM pg_publication WHERE pubname = '{pub_name}' + """ + ) + result = cur.fetchone() + cur.connection.close() + if result is None: + raise ValueError(f'Publication "{pub_name}" does not exist.') + return { + "insert": result[0], + "update": result[1], + "delete": result[2], + "truncate": result[3], + } + + +def lsn_int_to_hex(lsn: int) -> str: + """Convert integer LSN to postgres' hexadecimal representation.""" + # https://stackoverflow.com/questions/66797767/lsn-external-representation. 
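+    # an LSN is a 64-bit integer rendered as two 32-bit halves in hex separated by '/',
+    # e.g. 124011 -> "0/0001E46B" (this function zero-pads the low half to 8 hex digits)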
+ return f"{lsn >> 32 & 4294967295:X}/{lsn & 4294967295:08X}" + + +def advance_slot( + upto_lsn: int, + slot_name: str, + credentials: ConnectionStringCredentials, +) -> None: + """Advances position in the replication slot. + + Flushes all messages upto (and including) the message with LSN = `upto_lsn`. + This function is used as alternative to psycopg2's `send_feedback` method, because + the behavior of that method seems odd when used outside of `consume_stream`. + """ + if upto_lsn != 0: + cur = _get_conn(credentials).cursor() + cur.execute( + f"SELECT * FROM pg_replication_slot_advance('{slot_name}', '{lsn_int_to_hex(upto_lsn)}');" + ) + cur.connection.close() + + +def _get_conn( + credentials: ConnectionStringCredentials, + connection_factory: Optional[Any] = None, +) -> Union[psycopg2.extensions.connection, LogicalReplicationConnection]: + """Returns a psycopg2 connection to interact with postgres.""" + return psycopg2.connect( # type: ignore[call-overload,no-any-return] + database=credentials.database, + user=credentials.username, + password=credentials.password, + host=credentials.host, + port=credentials.port, + connection_factory=connection_factory, + **credentials.query, + ) + + +def _get_rep_conn( + credentials: ConnectionStringCredentials, +) -> LogicalReplicationConnection: + """Returns a psycopg2 LogicalReplicationConnection to interact with postgres replication functionality. + + Raises error if the user does not have the REPLICATION attribute assigned. + """ + return _get_conn(credentials, LogicalReplicationConnection) # type: ignore[return-value] + + +def _make_qualified_table_name(table_name: str, schema_name: str) -> str: + """Escapes and combines a schema and table name.""" + return ( + escape_postgres_identifier(schema_name) + + "." + + escape_postgres_identifier(table_name) + ) + + +def _get_pk( + table_name: str, + schema_name: str, + credentials: ConnectionStringCredentials, +) -> Optional[TColumnNames]: + """Returns primary key column(s) for postgres table. + + Returns None if no primary key columns exist. + """ + qual_name = _make_qualified_table_name(table_name, schema_name) + cur = _get_conn(credentials).cursor() + # https://wiki.postgresql.org/wiki/Retrieve_primary_key_columns + cur.execute( + f""" + SELECT a.attname + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = '{qual_name}'::regclass + AND i.indisprimary; + """ + ) + result = [tup[0] for tup in cur.fetchall()] + cur.connection.close() + if len(result) == 0: + return None + elif len(result) == 1: + return result[0] # type: ignore[no-any-return] + return result + + +@dataclass +class ItemGenerator: + credentials: ConnectionStringCredentials + slot_name: str + options: Dict[str, str] + upto_lsn: int + start_lsn: int = 0 + target_batch_size: int = 1000 + include_columns: Optional[Dict[str, Sequence[str]]] = (None,) # type: ignore[assignment] + columns: Optional[Dict[str, TTableSchemaColumns]] = (None,) # type: ignore[assignment] + last_commit_lsn: Optional[int] = field(default=None, init=False) + generated_all: bool = False + + def __iter__(self) -> Iterator[Union[TDataItem, DataItemWithMeta]]: + """Yields replication messages from MessageConsumer. + + Starts replication of messages published by the publication from the replication slot. + Maintains LSN of last consumed Commit message in object state. + Does not advance the slot. 
+ """ + try: + cur = _get_rep_conn(self.credentials).cursor() + cur.start_replication( + slot_name=self.slot_name, + start_lsn=self.start_lsn, + decode=False, + options=self.options, + ) + consumer = MessageConsumer( + upto_lsn=self.upto_lsn, + pub_ops=get_pub_ops( + self.options["publication_names"], self.credentials + ), + target_batch_size=self.target_batch_size, + include_columns=self.include_columns, + columns=self.columns, + ) + cur.consume_stream(consumer) + except StopReplication: # completed batch or reached `upto_lsn` + pass + finally: + cur.connection.close() + self.last_commit_lsn = consumer.last_commit_lsn + for rel_id, data_items in consumer.data_items.items(): + table_name = consumer.last_table_schema[rel_id]["name"] + yield data_items[0] # meta item with column hints only, no data + yield dlt.mark.with_table_name(data_items[1:], table_name) + self.generated_all = consumer.consumed_all + + +class MessageConsumer: + """Consumes messages from a ReplicationCursor sequentially. + + Generates data item for each `insert`, `update`, and `delete` message. + Processes in batches to limit memory usage. + Maintains message data needed by subsequent messages in internal state. + """ + + def __init__( + self, + upto_lsn: int, + pub_ops: Dict[str, bool], + target_batch_size: int = 1000, + include_columns: Optional[Dict[str, Sequence[str]]] = None, + columns: Optional[Dict[str, TTableSchemaColumns]] = None, + ) -> None: + self.upto_lsn = upto_lsn + self.pub_ops = pub_ops + self.target_batch_size = target_batch_size + self.include_columns = include_columns + self.columns = columns + + self.consumed_all: bool = False + # data_items attribute maintains all data items + self.data_items: Dict[ + int, List[Union[TDataItem, DataItemWithMeta]] + ] = dict() # maps relation_id to list of data items + # other attributes only maintain last-seen values + self.last_table_schema: Dict[ + int, TTableSchema + ] = dict() # maps relation_id to table schema + self.last_commit_ts: pendulum.DateTime + self.last_commit_lsn = None + + def __call__(self, msg: ReplicationMessage) -> None: + """Processes message received from stream.""" + self.process_msg(msg) + + def process_msg(self, msg: ReplicationMessage) -> None: + """Processes encoded replication message. + + Identifies message type and decodes accordingly. + Message treatment is different for various message types. + Breaks out of stream with StopReplication exception when + - `upto_lsn` is reached + - `target_batch_size` is reached + - a table's schema has changed + """ + op = msg.payload[:1] + if op == b"I": + self.process_change(Insert(msg.payload), msg.data_start) + elif op == b"U": + self.process_change(Update(msg.payload), msg.data_start) + elif op == b"D": + self.process_change(Delete(msg.payload), msg.data_start) + elif op == b"B": + self.last_commit_ts = Begin(msg.payload).commit_ts # type: ignore[assignment] + elif op == b"C": + self.process_commit(msg) + elif op == b"R": + self.process_relation(Relation(msg.payload)) + elif op == b"T": + logger.warning( + "The truncate operation is currently not supported. " + "Truncate replication messages are ignored." + ) + + def process_commit(self, msg: ReplicationMessage) -> None: + """Updates object state when Commit message is observed. + + Raises StopReplication when `upto_lsn` or `target_batch_size` is reached. 
+ """ + self.last_commit_lsn = msg.data_start + if msg.data_start >= self.upto_lsn: + self.consumed_all = True + n_items = sum( + [len(items) for items in self.data_items.values()] + ) # combine items for all tables + if self.consumed_all or n_items >= self.target_batch_size: + raise StopReplication + + def process_relation(self, decoded_msg: Relation) -> None: + """Processes a replication message of type Relation. + + Stores table schema in object state. + Creates meta item to emit column hints while yielding data. + + Raises StopReplication when a table's schema changes. + """ + if ( + self.data_items.get(decoded_msg.relation_id) is not None + ): # table schema change + raise StopReplication + # get table schema information from source and store in object state + table_name = decoded_msg.relation_name + columns: TTableSchemaColumns = { + c.name: _to_dlt_column_schema(c) for c in decoded_msg.columns + } + self.last_table_schema[decoded_msg.relation_id] = { + "name": table_name, + "columns": columns, + } + + # apply user input + # 1) exclude columns + include_columns = ( + None + if self.include_columns is None + else self.include_columns.get(table_name) + ) + if include_columns is not None: + columns = {k: v for k, v in columns.items() if k in include_columns} + # 2) override source hints + column_hints: TTableSchemaColumns = ( + dict() if self.columns is None else self.columns.get(table_name, dict()) + ) + for column_name, column_val in column_hints.items(): + columns[column_name] = merge_column(columns[column_name], column_val) + + # add hints for replication columns + columns["lsn"] = {"data_type": "bigint", "nullable": True} + if self.pub_ops["update"] or self.pub_ops["delete"]: + columns["lsn"]["dedup_sort"] = "desc" + if self.pub_ops["delete"]: + columns["deleted_ts"] = { + "hard_delete": True, + "data_type": "timestamp", + "nullable": True, + } + + # determine write disposition + write_disposition: TWriteDisposition = "append" + if self.pub_ops["update"] or self.pub_ops["delete"]: + write_disposition = "merge" + + # include meta item to emit hints while yielding data + meta_item = dlt.mark.with_hints( + [], + dlt.mark.make_hints( + table_name=table_name, + write_disposition=write_disposition, + columns=columns, + ), + create_table_variant=True, + ) + self.data_items[decoded_msg.relation_id] = [meta_item] + + def process_change( + self, decoded_msg: Union[Insert, Update, Delete], msg_start_lsn: int + ) -> None: + """Processes replication message of type Insert, Update, or Delete. + + Adds data item for inserted/updated/deleted record to instance attribute. 
+ """ + if isinstance(decoded_msg, (Insert, Update)): + column_data = decoded_msg.new_tuple.column_data + elif isinstance(decoded_msg, Delete): + column_data = decoded_msg.old_tuple.column_data + table_name = self.last_table_schema[decoded_msg.relation_id]["name"] + data_item = self.gen_data_item( + data=column_data, + column_schema=self.last_table_schema[decoded_msg.relation_id]["columns"], + lsn=msg_start_lsn, + commit_ts=self.last_commit_ts, + for_delete=isinstance(decoded_msg, Delete), + include_columns=None + if self.include_columns is None + else self.include_columns.get(table_name), + ) + self.data_items[decoded_msg.relation_id].append(data_item) + + @staticmethod + def gen_data_item( + data: List[ColumnData], + column_schema: TTableSchemaColumns, + lsn: int, + commit_ts: pendulum.DateTime, + for_delete: bool, + include_columns: Optional[Sequence[str]] = None, + ) -> TDataItem: + """Generates data item from replication message data and corresponding metadata.""" + data_item = { + schema["name"]: _to_dlt_val( + val=data.col_data, + data_type=schema["data_type"], + byte1=data.col_data_category, + for_delete=for_delete, + ) + for (schema, data) in zip(column_schema.values(), data) + if (True if include_columns is None else schema["name"] in include_columns) + } + data_item["lsn"] = lsn + if for_delete: + data_item["deleted_ts"] = commit_ts + return data_item diff --git a/sources/pg_replication/requirements.txt b/sources/pg_replication/requirements.txt new file mode 100644 index 000000000..7a49c8ab2 --- /dev/null +++ b/sources/pg_replication/requirements.txt @@ -0,0 +1,2 @@ +dlt>=0.4.8 +psycopg2-binary>=2.9.9 \ No newline at end of file diff --git a/sources/pg_replication/schema_types.py b/sources/pg_replication/schema_types.py new file mode 100644 index 000000000..a5758c32c --- /dev/null +++ b/sources/pg_replication/schema_types.py @@ -0,0 +1,119 @@ +import json +from typing import Optional, Any, Dict + +from dlt.common import Decimal +from dlt.common.data_types.typing import TDataType +from dlt.common.data_types.type_helpers import coerce_value +from dlt.common.schema.typing import TColumnSchema, TColumnType +from dlt.destinations.impl.postgres import capabilities +from dlt.destinations.impl.postgres.postgres import PostgresTypeMapper + +from .decoders import ColumnType + + +_DUMMY_VALS: Dict[TDataType, Any] = { + "bigint": 0, + "binary": b" ", + "bool": True, + "complex": [0], + "date": "2000-01-01", + "decimal": Decimal(0), + "double": 0.0, + "text": "", + "time": "00:00:00", + "timestamp": "2000-01-01T00:00:00", + "wei": 0, +} +"""Dummy values used to replace NULLs in NOT NULL colums in key-only delete records.""" + + +_PG_TYPES: Dict[int, str] = { + 16: "boolean", + 17: "bytea", + 20: "bigint", + 21: "smallint", + 23: "integer", + 701: "double precision", + 1043: "character varying", + 1082: "date", + 1083: "time without time zone", + 1184: "timestamp with time zone", + 1700: "numeric", + 3802: "jsonb", +} +"""Maps postgres type OID to type string. 
Only includes types present in PostgresTypeMapper.""" + + +def _get_precision(type_id: int, atttypmod: int) -> Optional[int]: + """Get precision from postgres type attributes.""" + # https://stackoverflow.com/a/3351120 + if type_id == 21: # smallint + return 16 + elif type_id == 23: # integer + return 32 + elif type_id == 20: # bigint + return 64 + if atttypmod != -1: + if type_id == 1700: # numeric + return ((atttypmod - 4) >> 16) & 65535 + elif type_id in ( + 1083, + 1184, + ): # time without time zone, timestamp with time zone + return atttypmod + elif type_id == 1043: # character varying + return atttypmod - 4 + return None + + +def _get_scale(type_id: int, atttypmod: int) -> Optional[int]: + """Get scale from postgres type attributes.""" + # https://stackoverflow.com/a/3351120 + if atttypmod != -1: + if type_id in (21, 23, 20): # smallint, integer, bigint + return 0 + if type_id == 1700: # numeric + return (atttypmod - 4) & 65535 + return None + + +def _to_dlt_column_type(type_id: int, atttypmod: int) -> TColumnType: + """Converts postgres type OID to dlt column type. + + Type OIDs not in _PG_TYPES mapping default to "text" type. + """ + pg_type = _PG_TYPES.get(type_id) + precision = _get_precision(type_id, atttypmod) + scale = _get_scale(type_id, atttypmod) + mapper = PostgresTypeMapper(capabilities()) + return mapper.from_db_type(pg_type, precision, scale) + + +def _to_dlt_column_schema(col: ColumnType) -> TColumnSchema: + """Converts pypgoutput ColumnType to dlt column schema.""" + dlt_column_type = _to_dlt_column_type(col.type_id, col.atttypmod) + partial_column_schema = { + "name": col.name, + "primary_key": bool(col.part_of_pkey), + } + return {**dlt_column_type, **partial_column_schema} # type: ignore[typeddict-item] + + +def _to_dlt_val(val: str, data_type: TDataType, byte1: str, for_delete: bool) -> Any: + """Converts pgoutput's text-formatted value into dlt-compatible data value.""" + if byte1 == "n": + if for_delete: + # replace None with dummy value to prevent NOT NULL violations in staging table + return _DUMMY_VALS[data_type] + return None + elif byte1 == "t": + if data_type == "binary": + # https://www.postgresql.org/docs/current/datatype-binary.html#DATATYPE-BINARY-BYTEA-HEX-FORMAT + return bytes.fromhex(val.replace("\\x", "")) + elif data_type == "complex": + return json.loads(val) + return coerce_value(data_type, "text", val) + else: + raise ValueError( + f"Byte1 in replication message must be 'n' or 't', not '{byte1}'." + ) diff --git a/sources/pg_replication_pipeline.py b/sources/pg_replication_pipeline.py new file mode 100644 index 000000000..811337ddc --- /dev/null +++ b/sources/pg_replication_pipeline.py @@ -0,0 +1,290 @@ +import dlt + + +from pg_replication import replication_resource +from pg_replication.helpers import init_replication + + +def replicate_single_table() -> None: + """Sets up replication for a single Postgres table and loads changes into a destination. + + Demonstrates basic usage of `init_replication` helper and `replication_resource` resource. + Uses `src_pl` to create and change the replicated Postgres table—this + is only for demonstration purposes, you won't need this when you run in production + as you'll probably have another process feeding your Postgres instance. 
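+
+    Note: running this demo requires Postgres credentials in `.dlt/secrets.toml`
+    for both the `postgres` destination (used by `src_pl`) and the
+    `pg_replication` source (see README).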
+ """ + # create source and destination pipelines + src_pl = dlt.pipeline( + pipeline_name="source_pipeline", + destination="postgres", + dataset_name="replicate_single_table", + full_refresh=True, + ) + dest_pl = dlt.pipeline( + pipeline_name="pg_replication_pipeline", + destination="duckdb", + dataset_name="replicate_single_table", + full_refresh=True, + ) + + # create table "my_source_table" in source to demonstrate replication + create_source_table( + src_pl, "CREATE TABLE {table_name} (id integer PRIMARY KEY, val bool);" + ) + + # initialize replication for the source table—this creates a replication slot and publication + slot_name = "example_slot" + pub_name = "example_pub" + init_replication( # requires the Postgres user to have the REPLICATION attribute assigned + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names="my_source_table", + reset=True, + ) + + # create a resource that generates items for each change in the source table + changes = replication_resource(slot_name, pub_name) + + # insert two records in source table and propagate changes to destination + change_source_table( + src_pl, "INSERT INTO {table_name} VALUES (1, true), (2, false);" + ) + dest_pl.run(changes) + show_destination_table(dest_pl) + + # update record in source table and propagate change to destination + change_source_table(src_pl, "UPDATE {table_name} SET val = true WHERE id = 2;") + dest_pl.run(changes) + show_destination_table(dest_pl) + + # delete record from source table and propagate change to destination + change_source_table(src_pl, "DELETE FROM {table_name} WHERE id = 2;") + dest_pl.run(changes) + show_destination_table(dest_pl) + + +def replicate_with_initial_load() -> None: + """Sets up replication with initial load. + + Demonstrates usage of `persist_snapshots` argument and snapshot resource + returned by `init_replication` helper. 
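+
+    Note: `init_replication` only returns snapshot resource(s) when the replication
+    slot is newly created; `reset=True` ensures that is the case here.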
+ """ + # create source and destination pipelines + src_pl = dlt.pipeline( + pipeline_name="source_pipeline", + destination="postgres", + dataset_name="replicate_with_initial_load", + full_refresh=True, + ) + dest_pl = dlt.pipeline( + pipeline_name="pg_replication_pipeline", + destination="duckdb", + dataset_name="replicate_with_initial_load", + full_refresh=True, + ) + + # create table "my_source_table" in source to demonstrate replication + create_source_table( + src_pl, "CREATE TABLE {table_name} (id integer PRIMARY KEY, val bool);" + ) + + # insert records before initializing replication + change_source_table( + src_pl, "INSERT INTO {table_name} VALUES (1, true), (2, false);" + ) + + # initialize replication for the source table + slot_name = "example_slot" + pub_name = "example_pub" + snapshot = init_replication( # requires the Postgres user to have the REPLICATION attribute assigned + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names="my_source_table", + persist_snapshots=True, # persist snapshot table(s) and let function return resource(s) for initial load + reset=True, + ) + + # perform initial load to capture all records present in source table prior to replication initialization + dest_pl.run(snapshot) + show_destination_table(dest_pl) + + # insert record in source table and propagate change to destination + change_source_table(src_pl, "INSERT INTO {table_name} VALUES (3, true);") + changes = replication_resource(slot_name, pub_name) + dest_pl.run(changes) + show_destination_table(dest_pl) + + +def replicate_entire_schema() -> None: + """Demonstrates setup and usage of schema replication. + + Schema replication requires a Postgres server version of 15 or higher. An + exception is raised if that's not the case. 
+ """ + # create source and destination pipelines + src_pl = dlt.pipeline( + pipeline_name="source_pipeline", + destination="postgres", + dataset_name="replicate_entire_schema", + full_refresh=True, + ) + dest_pl = dlt.pipeline( + pipeline_name="pg_replication_pipeline", + destination="duckdb", + dataset_name="replicate_entire_schema", + full_refresh=True, + ) + + # create two source tables to demonstrate schema replication + create_source_table( + src_pl, + "CREATE TABLE {table_name} (id integer PRIMARY KEY, val bool);", + "tbl_x", + ) + create_source_table( + src_pl, + "CREATE TABLE {table_name} (id integer PRIMARY KEY, val varchar);", + "tbl_y", + ) + + # initialize schema replication by omitting the `table_names` argument + slot_name = "example_slot" + pub_name = "example_pub" + init_replication( # initializing schema replication requires the Postgres user to be a superuser + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + reset=True, + ) + + # create a resource that generates items for each change in the schema's tables + changes = replication_resource(slot_name, pub_name) + + # insert records in source tables and propagate changes to destination + change_source_table( + src_pl, "INSERT INTO {table_name} VALUES (1, true), (2, false);", "tbl_x" + ) + change_source_table(src_pl, "INSERT INTO {table_name} VALUES (1, 'foo');", "tbl_y") + dest_pl.run(changes) + show_destination_table(dest_pl, "tbl_x") + show_destination_table(dest_pl, "tbl_y") + + # tables added to the schema later are also included in the replication + create_source_table( + src_pl, "CREATE TABLE {table_name} (id integer PRIMARY KEY, val date);", "tbl_z" + ) + change_source_table( + src_pl, "INSERT INTO {table_name} VALUES (1, '2023-03-18');", "tbl_z" + ) + dest_pl.run(changes) + show_destination_table(dest_pl, "tbl_z") + + +def replicate_with_column_selection() -> None: + """Sets up replication with column selection. + + Demonstrates usage of `include_columns` argument. 
+ """ + # create source and destination pipelines + src_pl = dlt.pipeline( + pipeline_name="source_pipeline", + destination="postgres", + dataset_name="replicate_with_column_selection", + full_refresh=True, + ) + dest_pl = dlt.pipeline( + pipeline_name="pg_replication_pipeline", + destination="duckdb", + dataset_name="replicate_with_column_selection", + full_refresh=True, + ) + + # create two source tables to demonstrate schema replication + create_source_table( + src_pl, + "CREATE TABLE {table_name} (c1 integer PRIMARY KEY, c2 bool, c3 varchar);", + "tbl_x", + ) + create_source_table( + src_pl, + "CREATE TABLE {table_name} (c1 integer PRIMARY KEY, c2 bool, c3 varchar);", + "tbl_y", + ) + + # initialize schema replication by omitting the `table_names` argument + slot_name = "example_slot" + pub_name = "example_pub" + init_replication( # requires the Postgres user to have the REPLICATION attribute assigned + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names=("tbl_x", "tbl_y"), + reset=True, + ) + + # create a resource that generates items for each change in the schema's tables + changes = replication_resource( + slot_name=slot_name, + pub_name=pub_name, + include_columns={ + "tbl_x": ("c1", "c2") + }, # columns not specified here are excluded from generated data items + ) + + # insert records in source tables and propagate changes to destination + change_source_table( + src_pl, "INSERT INTO {table_name} VALUES (1, true, 'foo');", "tbl_x" + ) + change_source_table( + src_pl, "INSERT INTO {table_name} VALUES (1, false, 'bar');", "tbl_y" + ) + dest_pl.run(changes) + + # show columns in schema for both tables + # column c3 is not in the schema for tbl_x because we did not include it + # tbl_y does have column c3 because we didn't specify include columns for this table and by default all columns are included + print("tbl_x", ":", list(dest_pl.default_schema.get_table_columns("tbl_x").keys())) + print("tbl_y", ":", list(dest_pl.default_schema.get_table_columns("tbl_y").keys())) + + +# define some helper methods to make examples more readable + + +def create_source_table( + src_pl: dlt.Pipeline, sql: str, table_name: str = "my_source_table" +) -> None: + with src_pl.sql_client() as c: + try: + c.create_dataset() + except dlt.destinations.exceptions.DatabaseTerminalException: + pass + qual_name = c.make_qualified_table_name(table_name) + c.execute_sql(sql.format(table_name=qual_name)) + + +def change_source_table( + src_pl: dlt.Pipeline, sql: str, table_name: str = "my_source_table" +) -> None: + with src_pl.sql_client() as c: + qual_name = c.make_qualified_table_name(table_name) + c.execute_sql(sql.format(table_name=qual_name)) + + +def show_destination_table( + dest_pl: dlt.Pipeline, + table_name: str = "my_source_table", + column_names: str = "id, val", +) -> None: + with dest_pl.sql_client() as c: + dest_qual_name = c.make_qualified_table_name(table_name) + dest_records = c.execute_sql(f"SELECT {column_names} FROM {dest_qual_name};") + print(table_name, ":", dest_records) + + +if __name__ == "__main__": + replicate_single_table() + # replicate_with_initial_load() + # replicate_entire_schema() + # replicate_with_column_selection() diff --git a/tests/pg_replication/__init__.py b/tests/pg_replication/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/pg_replication/cases.py b/tests/pg_replication/cases.py new file mode 100644 index 000000000..a17efcad7 --- /dev/null +++ b/tests/pg_replication/cases.py @@ -0,0 +1,94 @@ +from 
typing import List + +from dlt.common import Decimal +from dlt.common.schema import TColumnSchema, TTableSchemaColumns + + +TABLE_ROW_ALL_DATA_TYPES = { + "col1": 989127831, + "col2": 898912.821982, + "col3": True, + "col4": "2022-05-23T13:26:45.176451+00:00", + "col5": "string data \n \r \x8e 🦆", + "col6": Decimal("2323.34"), + "col7": b"binary data \n \r \x8e", + # "col8": 2**56 + 92093890840, # TODO: uncommment and make it work + "col9": { + "complex": [1, 2, 3, "a"], + "link": ( + "?commen\ntU\nrn=urn%3Ali%3Acomment%3A%28acti\012 \6" + " \\vity%3A69'08444473\n\n551163392%2C6n \r \x8e9085" + ), + }, + "col10": "2023-02-27", + "col11": "13:26:45.176451", + "col1_null": None, + "col2_null": None, + "col3_null": None, + "col4_null": None, + "col5_null": None, + "col6_null": None, + "col7_null": None, + # "col8_null": None, + "col9_null": None, + "col10_null": None, + "col11_null": None, + "col1_precision": 22324, + "col4_precision": "2022-05-23T13:26:46.167231+00:00", + "col5_precision": "string data 2 \n \r \x8e 🦆", + "col6_precision": Decimal("2323.34"), + "col7_precision": b"binary data 2 \n \r \x8e", + "col11_precision": "13:26:45.176451", +} +TABLE_UPDATE: List[TColumnSchema] = [ + {"name": "col1", "data_type": "bigint", "nullable": False}, + {"name": "col2", "data_type": "double", "nullable": False}, + {"name": "col3", "data_type": "bool", "nullable": False}, + {"name": "col4", "data_type": "timestamp", "nullable": False}, + {"name": "col5", "data_type": "text", "nullable": False}, + {"name": "col6", "data_type": "decimal", "nullable": False}, + {"name": "col7", "data_type": "binary", "nullable": False}, + # {"name": "col8", "data_type": "wei", "nullable": False}, + {"name": "col9", "data_type": "complex", "nullable": False, "variant": True}, + {"name": "col10", "data_type": "date", "nullable": False}, + {"name": "col11", "data_type": "time", "nullable": False}, + {"name": "col1_null", "data_type": "bigint", "nullable": True}, + {"name": "col2_null", "data_type": "double", "nullable": True}, + {"name": "col3_null", "data_type": "bool", "nullable": True}, + {"name": "col4_null", "data_type": "timestamp", "nullable": True}, + {"name": "col5_null", "data_type": "text", "nullable": True}, + {"name": "col6_null", "data_type": "decimal", "nullable": True}, + {"name": "col7_null", "data_type": "binary", "nullable": True}, + # {"name": "col8_null", "data_type": "wei", "nullable": True}, + {"name": "col9_null", "data_type": "complex", "nullable": True, "variant": True}, + {"name": "col10_null", "data_type": "date", "nullable": True}, + {"name": "col11_null", "data_type": "time", "nullable": True}, + { + "name": "col1_precision", + "data_type": "bigint", + "precision": 16, + "nullable": False, + }, + { + "name": "col4_precision", + "data_type": "timestamp", + "precision": 3, + "nullable": False, + }, + {"name": "col5_precision", "data_type": "text", "precision": 25, "nullable": False}, + { + "name": "col6_precision", + "data_type": "decimal", + "precision": 6, + "scale": 2, + "nullable": False, + }, + { + "name": "col7_precision", + "data_type": "binary", + "precision": 19, + "nullable": False, + }, + {"name": "col11_precision", "data_type": "time", "precision": 3, "nullable": False}, +] +TABLE_UPDATE_COLUMNS_SCHEMA: TTableSchemaColumns = {t["name"]: t for t in TABLE_UPDATE} diff --git a/tests/pg_replication/conftest.py b/tests/pg_replication/conftest.py new file mode 100644 index 000000000..ed74cc110 --- /dev/null +++ b/tests/pg_replication/conftest.py @@ -0,0 +1,43 @@ +import pytest + +from 
typing import Iterator, Tuple + +import dlt +from dlt.common.utils import uniq_id + + +@pytest.fixture() +def src_config() -> Iterator[Tuple[dlt.Pipeline, str, str]]: + # random slot and pub to enable parallel runs + slot = "test_slot_" + uniq_id(4) + pub = "test_pub" + uniq_id(4) + # setup + src_pl = dlt.pipeline( + pipeline_name="src_pl", + destination="postgres", + full_refresh=True, + credentials=dlt.secrets.get("sources.pg_replication.credentials"), + ) + yield src_pl, slot, pub + # teardown + with src_pl.sql_client() as c: + # drop tables + try: + c.drop_dataset() + except Exception as e: + print(e) + with c.with_staging_dataset(staging=True): + try: + c.drop_dataset() + except Exception as e: + print(e) + # drop replication slot + try: + c.execute_sql(f"SELECT pg_drop_replication_slot('{slot}');") + except Exception as e: + print(e) + # drop publication + try: + c.execute_sql(f"DROP PUBLICATION IF EXISTS {pub};") + except Exception as e: + print(e) diff --git a/tests/pg_replication/test_pg_replication.py b/tests/pg_replication/test_pg_replication.py new file mode 100644 index 000000000..6d63fa3bc --- /dev/null +++ b/tests/pg_replication/test_pg_replication.py @@ -0,0 +1,866 @@ +import pytest + +from typing import Set, Tuple +from copy import deepcopy +from psycopg2.errors import InsufficientPrivilege + +import dlt +from dlt.destinations.job_client_impl import SqlJobClientBase + +from tests.utils import ( + ALL_DESTINATIONS, + assert_load_info, + load_table_counts, + get_table_metrics, +) +from sources.pg_replication import replication_resource +from sources.pg_replication.helpers import init_replication, get_pg_version +from sources.pg_replication.exceptions import IncompatiblePostgresVersionException + +from .cases import TABLE_ROW_ALL_DATA_TYPES, TABLE_UPDATE_COLUMNS_SCHEMA +from .utils import add_pk, assert_loaded_data, is_super_user + + +@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) +def test_core_functionality( + src_config: Tuple[dlt.Pipeline, str, str], destination_name: str +) -> None: + @dlt.resource(write_disposition="merge", primary_key="id_x") + def tbl_x(data): + yield data + + @dlt.resource(write_disposition="merge", primary_key="id_y") + def tbl_y(data): + yield data + + src_pl, slot_name, pub_name = src_config + + src_pl.run( + [ + tbl_x({"id_x": 1, "val_x": "foo"}), + tbl_y({"id_y": 1, "val_y": True}), + ] + ) + add_pk(src_pl.sql_client, "tbl_x", "id_x") + add_pk(src_pl.sql_client, "tbl_y", "id_y") + + snapshots = init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names=("tbl_x", "tbl_y"), + persist_snapshots=True, + ) + + changes = replication_resource(slot_name, pub_name) + + src_pl.run( + [ + tbl_x([{"id_x": 2, "val_x": "bar"}, {"id_x": 3, "val_x": "baz"}]), + tbl_y({"id_y": 2, "val_y": False}), + ] + ) + + dest_pl = dlt.pipeline( + pipeline_name="dest_pl", destination=destination_name, full_refresh=True + ) + + # initial load + info = dest_pl.run(snapshots) + assert_load_info(info) + assert load_table_counts(dest_pl, "tbl_x", "tbl_y") == {"tbl_x": 1, "tbl_y": 1} + exp_tbl_x = [{"id_x": 1, "val_x": "foo"}] + exp_tbl_y = [{"id_y": 1, "val_y": True}] + assert_loaded_data(dest_pl, "tbl_x", ["id_x", "val_x"], exp_tbl_x, "id_x") + assert_loaded_data(dest_pl, "tbl_y", ["id_y", "val_y"], exp_tbl_y, "id_y") + + # process changes + info = dest_pl.run(changes) + assert_load_info(info) + assert load_table_counts(dest_pl, "tbl_x", "tbl_y") == {"tbl_x": 3, "tbl_y": 2} + exp_tbl_x = [ + {"id_x": 1, "val_x": 
"foo"}, + {"id_x": 2, "val_x": "bar"}, + {"id_x": 3, "val_x": "baz"}, + ] + exp_tbl_y = [{"id_y": 1, "val_y": True}, {"id_y": 2, "val_y": False}] + assert_loaded_data(dest_pl, "tbl_x", ["id_x", "val_x"], exp_tbl_x, "id_x") + assert_loaded_data(dest_pl, "tbl_y", ["id_y", "val_y"], exp_tbl_y, "id_y") + + # change single table + src_pl.run(tbl_y({"id_y": 3, "val_y": True})) + + # process changes + info = dest_pl.run(changes) + assert_load_info(info) + assert load_table_counts(dest_pl, "tbl_x", "tbl_y") == {"tbl_x": 3, "tbl_y": 3} + exp_tbl_y = [ + {"id_y": 1, "val_y": True}, + {"id_y": 2, "val_y": False}, + {"id_y": 3, "val_y": True}, + ] + assert_loaded_data(dest_pl, "tbl_x", ["id_x", "val_x"], exp_tbl_x, "id_x") + assert_loaded_data(dest_pl, "tbl_y", ["id_y", "val_y"], exp_tbl_y, "id_y") + + # update tables + with src_pl.sql_client() as c: + qual_name = src_pl.sql_client().make_qualified_table_name("tbl_x") + c.execute_sql(f"UPDATE {qual_name} SET val_x = 'foo_updated' WHERE id_x = 1;") + qual_name = src_pl.sql_client().make_qualified_table_name("tbl_y") + c.execute_sql(f"UPDATE {qual_name} SET val_y = false WHERE id_y = 1;") + + # process changes + info = dest_pl.run(changes) + assert_load_info(info) + assert load_table_counts(dest_pl, "tbl_x", "tbl_y") == {"tbl_x": 3, "tbl_y": 3} + exp_tbl_x = [ + {"id_x": 1, "val_x": "foo_updated"}, + {"id_x": 2, "val_x": "bar"}, + {"id_x": 3, "val_x": "baz"}, + ] + exp_tbl_y = [ + {"id_y": 1, "val_y": False}, + {"id_y": 2, "val_y": False}, + {"id_y": 3, "val_y": True}, + ] + assert_loaded_data(dest_pl, "tbl_x", ["id_x", "val_x"], exp_tbl_x, "id_x") + assert_loaded_data(dest_pl, "tbl_y", ["id_y", "val_y"], exp_tbl_y, "id_y") + + # delete from table + with src_pl.sql_client() as c: + qual_name = src_pl.sql_client().make_qualified_table_name("tbl_x") + c.execute_sql(f"DELETE FROM {qual_name} WHERE id_x = 1;") + + # process changes + info = dest_pl.run(changes) + assert_load_info(info) + assert load_table_counts(dest_pl, "tbl_x", "tbl_y") == {"tbl_x": 2, "tbl_y": 3} + exp_tbl_x = [{"id_x": 2, "val_x": "bar"}, {"id_x": 3, "val_x": "baz"}] + exp_tbl_y = [ + {"id_y": 1, "val_y": False}, + {"id_y": 2, "val_y": False}, + {"id_y": 3, "val_y": True}, + ] + assert_loaded_data(dest_pl, "tbl_x", ["id_x", "val_x"], exp_tbl_x, "id_x") + assert_loaded_data(dest_pl, "tbl_y", ["id_y", "val_y"], exp_tbl_y, "id_y") + + +@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) +def test_without_init_load( + src_config: Tuple[dlt.Pipeline, str, str], destination_name: str +) -> None: + @dlt.resource(write_disposition="merge", primary_key="id_x") + def tbl_x(data): + yield data + + @dlt.resource(write_disposition="merge", primary_key="id_y") + def tbl_y(data): + yield data + + src_pl, slot_name, pub_name = src_config + + # create postgres table + # since we're skipping initial load, these records should not be in the replicated table + src_pl.run( + [ + tbl_x({"id_x": 1, "val_x": "foo"}), + tbl_y({"id_y": 1, "val_y": True}), + ] + ) + add_pk(src_pl.sql_client, "tbl_x", "id_x") + add_pk(src_pl.sql_client, "tbl_y", "id_y") + + # initialize replication and create resource for changes + init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names=("tbl_x", "tbl_y"), + ) + changes = replication_resource(slot_name, pub_name) + + # change postgres table after replication has been initialized + # these records should be in the replicated table + src_pl.run( + [ + tbl_x([{"id_x": 2, "val_x": "bar"}, {"id_x": 3, "val_x": "baz"}]), 
+ tbl_y({"id_y": 2, "val_y": False}), + ] + ) + + # load changes to destination and assert expectations + dest_pl = dlt.pipeline( + pipeline_name="dest_pl", destination=destination_name, full_refresh=True + ) + info = dest_pl.run(changes) + assert_load_info(info) + assert load_table_counts(dest_pl, "tbl_x", "tbl_y") == {"tbl_x": 2, "tbl_y": 1} + exp_tbl_x = [{"id_x": 2, "val_x": "bar"}, {"id_x": 3, "val_x": "baz"}] + exp_tbl_y = [{"id_y": 2, "val_y": False}] + assert_loaded_data(dest_pl, "tbl_x", ["id_x", "val_x"], exp_tbl_x, "id_x") + assert_loaded_data(dest_pl, "tbl_y", ["id_y", "val_y"], exp_tbl_y, "id_y") + + # delete from table + with src_pl.sql_client() as c: + qual_name = src_pl.sql_client().make_qualified_table_name("tbl_x") + c.execute_sql(f"DELETE FROM {qual_name} WHERE id_x = 2;") + + # process change and assert expectations + info = dest_pl.run(changes) + assert_load_info(info) + assert load_table_counts(dest_pl, "tbl_x", "tbl_y") == {"tbl_x": 1, "tbl_y": 1} + exp_tbl_x = [{"id_x": 3, "val_x": "baz"}] + exp_tbl_y = [{"id_y": 2, "val_y": False}] + assert_loaded_data(dest_pl, "tbl_x", ["id_x", "val_x"], exp_tbl_x, "id_x") + assert_loaded_data(dest_pl, "tbl_y", ["id_y", "val_y"], exp_tbl_y, "id_y") + + +def test_insert_only(src_config: Tuple[dlt.Pipeline, str, str]) -> None: + def items(data): + yield data + + src_pl, slot_name, pub_name = src_config + + # create postgres table with single record + src_pl.run(items({"id": 1, "foo": "bar"})) + + # initialize replication and create resource for changes + init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names="items", + publish="insert", + ) + changes = replication_resource(slot_name, pub_name) + + # insert a record in postgres table + src_pl.run(items({"id": 2, "foo": "bar"})) + + # extract items from resource + dest_pl = dlt.pipeline(pipeline_name="dest_pl", full_refresh=True) + extract_info = dest_pl.extract(changes) + assert get_table_metrics(extract_info, "items")["items_count"] == 1 + + # do an update and a delete—these operations should not lead to items in the resource + with src_pl.sql_client() as c: + qual_name = src_pl.sql_client().make_qualified_table_name("items") + c.execute_sql(f"UPDATE {qual_name} SET foo = 'baz' WHERE id = 2;") + c.execute_sql(f"DELETE FROM {qual_name} WHERE id = 2;") + extract_info = dest_pl.extract(changes) + assert ( + get_table_metrics(extract_info, "items") is None + ) # there should be no metrics for the "items" table + + +@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) +@pytest.mark.parametrize("give_hints", [True, False]) +@pytest.mark.parametrize("init_load", [True, False]) +def test_mapped_data_types( + src_config: Tuple[dlt.Pipeline, str, str], + destination_name: str, + give_hints: bool, + init_load: bool, +) -> None: + """Assert common data types (the ones mapped in PostgresTypeMapper) are properly handled.""" + + data = deepcopy(TABLE_ROW_ALL_DATA_TYPES) + column_schema = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA) + + # resource to load data into postgres source table + @dlt.resource(primary_key="col1", write_disposition="merge", columns=column_schema) + def items(data): + yield data + + src_pl, slot_name, pub_name = src_config + + # create postgres table with single record containing all data types + src_pl.run(items(data)) + add_pk(src_pl.sql_client, "items", "col1") + + # initialize replication and create resources + snapshot = init_replication( + slot_name=slot_name, + pub_name=pub_name, + 
schema_name=src_pl.dataset_name, + table_names="items", + persist_snapshots=init_load, + columns={"items": column_schema} if give_hints else None, + ) + + changes = replication_resource( + slot_name=slot_name, + pub_name=pub_name, + columns={"items": column_schema} if give_hints else None, + ) + + # initial load + dest_pl = dlt.pipeline( + pipeline_name="dest_pl", destination=destination_name, full_refresh=True + ) + if init_load: + info = dest_pl.run(snapshot) + assert_load_info(info) + assert load_table_counts(dest_pl, "items")["items"] == 1 + + # insert two records in postgres table + r1 = deepcopy(data) + r2 = deepcopy(data) + r1["col1"] = 1 + r2["col1"] = 2 + src_pl.run(items([r1, r2])) + + info = dest_pl.run(changes) + assert_load_info(info) + assert load_table_counts(dest_pl, "items")["items"] == 3 if init_load else 2 + + if give_hints: + # compare observed with expected column types + observed = dest_pl.default_schema.get_table("items")["columns"] + for name, expected in column_schema.items(): + assert observed[name]["data_type"] == expected["data_type"] + # postgres bytea does not have precision + if ( + expected.get("precision") is not None + and expected["data_type"] != "binary" + ): + assert observed[name]["precision"] == expected["precision"] + + # update two records in postgres table + # this does two deletes and two inserts because dlt implements "merge" as "delete-and-insert" + # as such, postgres will create four replication messages: two of type Delete and two of type Insert + r1["col2"] = 1.5 + r2["col3"] = False + src_pl.run(items([r1, r2])) + + # process changes and assert expectations + info = dest_pl.run(changes) + assert_load_info(info) + assert load_table_counts(dest_pl, "items")["items"] == 3 if init_load else 2 + exp = [ + {"col1": 1, "col2": 1.5, "col3": True}, + {"col1": 2, "col2": 898912.821982, "col3": False}, + { + "col1": 989127831, + "col2": 898912.821982, + "col3": True, + }, # only present with init load + ] + if not init_load: + del exp[-1] + assert_loaded_data(dest_pl, "items", ["col1", "col2", "col3"], exp, "col1") + + # now do an actual update, so postgres will create a replication message of type Update + with src_pl.sql_client() as c: + qual_name = src_pl.sql_client().make_qualified_table_name("items") + c.execute_sql(f"UPDATE {qual_name} SET col2 = 2.5 WHERE col1 = 2;") + + # process change and assert expectation + info = dest_pl.run(changes) + assert_load_info(info) + assert load_table_counts(dest_pl, "items")["items"] == 3 if init_load else 2 + exp = [{"col1": 2, "col2": 2.5, "col3": False}] + assert_loaded_data( + dest_pl, "items", ["col1", "col2", "col3"], exp, "col1", "col1 = 2" + ) + + +@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) +def test_unmapped_data_types( + src_config: Tuple[dlt.Pipeline, str, str], destination_name: str +) -> None: + """Assert postgres data types that aren't explicitly mapped default to "text" type.""" + src_pl, slot_name, pub_name = src_config + + # create postgres table with some unmapped types + with src_pl.sql_client() as c: + c.create_dataset() + c.execute_sql( + "CREATE TABLE data_types (bit_col bit(1), box_col box, uuid_col uuid);" + ) + + # initialize replication and create resource + init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names="data_types", + publish="insert", + ) + changes = replication_resource(slot_name, pub_name) + + # insert record in source table to create replication item + with src_pl.sql_client() as c: + 
c.execute_sql( + "INSERT INTO data_types VALUES (B'1', box '((1,1), (0,0))', gen_random_uuid());" + ) + + # run destination pipeline and assert resulting data types + dest_pl = dlt.pipeline( + pipeline_name="dest_pl", destination=destination_name, full_refresh=True + ) + dest_pl.extract(changes) + dest_pl.normalize() + columns = dest_pl.default_schema.get_table_columns("data_types") + assert columns["bit_col"]["data_type"] == "text" + assert columns["box_col"]["data_type"] == "text" + assert columns["uuid_col"]["data_type"] == "text" + + +@pytest.mark.parametrize("publish", ["insert", "insert, update, delete"]) +def test_write_disposition( + src_config: Tuple[dlt.Pipeline, str, str], publish: str +) -> None: + @dlt.resource + def items(data): + yield data + + src_pl, slot_name, pub_name = src_config + + # create postgres table + src_pl.run(items({"id": 1, "val": True})) + + # create resources + snapshot = init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names="items", + publish=publish, + persist_snapshots=True, + ) + + # assert write disposition on snapshot resource + expected_write_disposition = "append" if publish == "insert" else "merge" + assert snapshot.write_disposition == expected_write_disposition + + # assert write disposition on tables dispatched by changes resource + changes = replication_resource(slot_name, pub_name) + src_pl.run(items({"id": 2, "val": True})) + dest_pl = dlt.pipeline(pipeline_name="dest_pl", full_refresh=True) + dest_pl.extract(changes) + assert ( + dest_pl.default_schema.get_table("items")["write_disposition"] + == expected_write_disposition + ) + + +@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) +@pytest.mark.parametrize("init_load", [True, False]) +def test_include_columns( + src_config: Tuple[dlt.Pipeline, str, str], destination_name: str, init_load: bool +) -> None: + def get_cols(pipeline: dlt.Pipeline, table_name: str) -> set: + with pipeline.destination_client(pipeline.default_schema_name) as client: + client: SqlJobClientBase + return { + k + for k in client.get_storage_table(table_name)[1].keys() + if not k.startswith("_dlt_") + } + + @dlt.resource + def tbl_x(data): + yield data + + @dlt.resource + def tbl_y(data): + yield data + + @dlt.resource + def tbl_z(data): + yield data + + src_pl, slot_name, pub_name = src_config + + # create three postgres tables + src_pl.run( + [ + tbl_x({"id_x": 1, "val_x": "foo", "another_col_x": 1}), + tbl_y({"id_y": 1, "val_y": "foo", "another_col_y": 1}), + tbl_z({"id_z": 1, "val_z": "foo", "another_col_z": 1}), + ] + ) + + # initialize replication and create resources + include_columns = { + "tbl_x": ["id_x", "val_x"], + "tbl_y": ["id_y", "val_y"], + # tbl_z is not specified, hence all columns should be included + } + snapshots = init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names=("tbl_x", "tbl_y", "tbl_z"), + publish="insert", + persist_snapshots=init_load, + include_columns=include_columns, + ) + changes = replication_resource( + slot_name=slot_name, pub_name=pub_name, include_columns=include_columns + ) + + # update three postgres tables + src_pl.run( + [ + tbl_x({"id_x": 2, "val_x": "foo", "another_col_x": 1}), + tbl_y({"id_y": 2, "val_y": "foo", "another_col_y": 1}), + tbl_z({"id_z": 2, "val_z": "foo", "another_col_z": 1}), + ] + ) + + # load to destination and assert column expectations + dest_pl = dlt.pipeline( + pipeline_name="dest_pl", destination=destination_name, 
full_refresh=True + ) + if init_load: + dest_pl.run(snapshots) + assert get_cols(dest_pl, "tbl_x") == {"id_x", "val_x"} + assert get_cols(dest_pl, "tbl_y") == {"id_y", "val_y"} + assert get_cols(dest_pl, "tbl_z") == {"id_z", "val_z", "another_col_z"} + dest_pl.run(changes) + assert get_cols(dest_pl, "tbl_x") == {"id_x", "val_x", "lsn"} + assert get_cols(dest_pl, "tbl_y") == {"id_y", "val_y", "lsn"} + assert get_cols(dest_pl, "tbl_z") == {"id_z", "val_z", "another_col_z", "lsn"} + + +@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) +@pytest.mark.parametrize("init_load", [True, False]) +def test_column_hints( + src_config: Tuple[dlt.Pipeline, str, str], destination_name: str, init_load: bool +) -> None: + @dlt.resource + def tbl_x(data): + yield data + + @dlt.resource + def tbl_y(data): + yield data + + @dlt.resource + def tbl_z(data): + yield data + + src_pl, slot_name, pub_name = src_config + + # create three postgres tables + src_pl.run( + [ + tbl_x({"id_x": 1, "val_x": "foo", "another_col_x": 1}), + tbl_y({"id_y": 1, "val_y": "foo", "another_col_y": 1}), + tbl_z({"id_z": 1, "val_z": "foo", "another_col_z": 1}), + ] + ) + + # initialize replication and create resources + column_hints = { + "tbl_x": {"another_col_x": {"data_type": "double"}}, + "tbl_y": {"another_col_y": {"precision": 32}}, + # tbl_z is not specified, hence all columns should be included + } + snapshots = init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names=("tbl_x", "tbl_y", "tbl_z"), + publish="insert", + persist_snapshots=init_load, + columns=column_hints, + ) + changes = replication_resource( + slot_name=slot_name, pub_name=pub_name, columns=column_hints + ) + + # update three postgres tables + src_pl.run( + [ + tbl_x({"id_x": 2, "val_x": "foo", "another_col_x": 1}), + tbl_y({"id_y": 2, "val_y": "foo", "another_col_y": 1}), + tbl_z({"id_z": 2, "val_z": "foo", "another_col_z": 1}), + ] + ) + + # load to destination and assert column expectations + dest_pl = dlt.pipeline( + pipeline_name="dest_pl", destination=destination_name, full_refresh=True + ) + if init_load: + dest_pl.run(snapshots) + assert ( + dest_pl.default_schema.get_table_columns("tbl_x")["another_col_x"][ + "data_type" + ] + == "double" + ) + assert ( + dest_pl.default_schema.get_table_columns("tbl_y")["another_col_y"][ + "precision" + ] + == 32 + ) + assert ( + dest_pl.default_schema.get_table_columns("tbl_z")["another_col_z"][ + "data_type" + ] + == "bigint" + ) + dest_pl.run(changes) + assert ( + dest_pl.default_schema.get_table_columns("tbl_x")["another_col_x"]["data_type"] + == "double" + ) + assert ( + dest_pl.default_schema.get_table_columns("tbl_y")["another_col_y"]["precision"] + == 32 + ) + assert ( + dest_pl.default_schema.get_table_columns("tbl_z")["another_col_z"]["data_type"] + == "bigint" + ) + + # the tests below should pass, but they don't because of a bug that causes + # column hints to be added to other tables when dispatching to multiple tables + assert "another_col_x" not in dest_pl.default_schema.get_table_columns("tbl_y") + assert "another_col_x" not in dest_pl.default_schema.get_table_columns("tbl_z") + assert "another_col_y" not in dest_pl.default_schema.get_table_columns( + "tbl_x", include_incomplete=True + ) + assert "another_col_y" not in dest_pl.default_schema.get_table_columns( + "tbl_z", include_incomplete=True + ) + + +@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) +def test_table_schema_change( + src_config: Tuple[dlt.Pipeline, str, 
str], destination_name: str +) -> None: + src_pl, slot_name, pub_name = src_config + + # create postgres table + src_pl.run([{"c1": 1, "c2": 1}], table_name="items") + + # initialize replication + init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names="items", + publish="insert", + ) + + # create resource and pipeline + changes = replication_resource(slot_name, pub_name) + dest_pl = dlt.pipeline( + pipeline_name="dest_pl", destination=destination_name, full_refresh=True + ) + + # add a column in one commit, this will create one Relation message + src_pl.run([{"c1": 2, "c2": 1}, {"c1": 3, "c2": 1, "c3": 1}], table_name="items") + info = dest_pl.run(changes) + assert_load_info(info) + assert load_table_counts(dest_pl, "items") == {"items": 2} + exp = [{"c1": 2, "c2": 1, "c3": None}, {"c1": 3, "c2": 1, "c3": 1}] + assert_loaded_data(dest_pl, "items", ["c1", "c2", "c3"], exp, "c1") + + # add a column in two commits, this will create two Relation messages + src_pl.run([{"c1": 4, "c2": 1, "c3": 1}], table_name="items") + src_pl.run([{"c1": 5, "c2": 1, "c3": 1, "c4": 1}], table_name="items") + dest_pl.run(changes) + assert_load_info(info) + assert load_table_counts(dest_pl, "items") == {"items": 4} + exp = [ + {"c1": 4, "c2": 1, "c3": 1, "c4": None}, + {"c1": 5, "c2": 1, "c3": 1, "c4": 1}, + ] + assert_loaded_data( + dest_pl, "items", ["c1", "c2", "c3", "c4"], exp, "c1", "c1 IN (4, 5)" + ) + + +def test_init_replication(src_config: Tuple[dlt.Pipeline, str, str]) -> None: + def get_table_names_in_pub() -> Set[str]: + with src_pl.sql_client() as c: + result = c.execute_sql( + f"SELECT tablename FROM pg_publication_tables WHERE pubname = '{pub_name}';" + ) + return {tup[0] for tup in result} + + @dlt.resource + def tbl_x(data): + yield data + + @dlt.resource + def tbl_y(data): + yield data + + @dlt.resource + def tbl_z(data): + yield data + + src_pl, slot_name, pub_name = src_config + + # create three postgres tables + src_pl.run( + [ + tbl_x({"id_x": 1, "val_x": "foo"}), + tbl_y({"id_y": 1, "val_y": "foo"}), + tbl_z({"id_z": 1, "val_z": "foo"}), + ] + ) + + # initialize replication with a single table + snapshot = init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names="tbl_x", + persist_snapshots=True, + ) + assert snapshot is not None + assert get_table_names_in_pub() == {"tbl_x"} + + # adding another table is supported, but snapshot tables won't be persisted + snapshots = init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names=("tbl_x", "tbl_y"), + persist_snapshots=True, + ) + assert snapshots is None + assert get_table_names_in_pub() == {"tbl_x", "tbl_y"} + + # removing a table is not supported + init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names="tbl_x", # "tbl_y" is no longer provided + ) + # "tbl_y" is still in the publication + assert get_table_names_in_pub() == {"tbl_x", "tbl_y"} + + # switching to whole schema replication is supported by omitting `table_names`, + # but only for Postgres server versions 15 or higher and with superuser privileges + is_su = is_super_user(src_pl.sql_client) + if get_pg_version() >= 150000 and is_su: + init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + ) + # includes dlt system tables + assert get_table_names_in_pub() >= {"tbl_x", "tbl_y", "tbl_z"} + else: + exp_err = ( + 
InsufficientPrivilege if not is_su else IncompatiblePostgresVersionException + ) + with pytest.raises(exp_err): + init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + ) + + +def test_replicate_schema(src_config: Tuple[dlt.Pipeline, str, str]) -> None: + if get_pg_version() < 150000: + pytest.skip("incompatible Postgres server version") + if not is_super_user(src_config[0].sql_client): + pytest.skip("Postgres user needs to be superuser") + + @dlt.resource + def tbl_x(data): + yield data + + @dlt.resource + def tbl_y(data): + yield data + + @dlt.resource + def tbl_z(data): + yield data + + src_pl, slot_name, pub_name = src_config + + # create two postgres tables + src_pl.run( + [ + tbl_x({"id_x": 1, "val_x": "foo"}), + tbl_y({"id_y": 1, "val_y": "foo"}), + ] + ) + + # initialize replication and create resource + init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, # we only specify `schema_name`, not `table_names` + publish="insert", + ) + changes = replication_resource(slot_name, pub_name) + + # change source tables and load to destination + src_pl.run( + [ + tbl_x({"id_x": 2, "val_x": "foo"}), + tbl_y({"id_y": 2, "val_y": "foo"}), + ] + ) + dest_pl = dlt.pipeline(pipeline_name="dest_pl", full_refresh=True) + dest_pl.extract(changes) + assert set(dest_pl.default_schema.data_table_names()) == {"tbl_x", "tbl_y"} + + # introduce new table in source and assert it gets included in the replication + src_pl.run( + [ + tbl_x({"id_x": 3, "val_x": "foo"}), + tbl_y({"id_y": 3, "val_y": "foo"}), + tbl_z({"id_z": 1, "val_z": "foo"}), + ] + ) + dest_pl.extract(changes) + assert set(dest_pl.default_schema.data_table_names()) == {"tbl_x", "tbl_y", "tbl_z"} + + +def test_batching(src_config: Tuple[dlt.Pipeline, str, str]) -> None: + # this test asserts the number of data items yielded by the replication resource + # is not affected by `target_batch_size` and the number of replication messages per transaction + src_pl, slot_name, pub_name = src_config + + # create postgres table with single record + data = {"id": 1000, "val": True} + src_pl.run([data], table_name="items") + + # initialize replication and create resource for changes + init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names="items", + ) + changes = replication_resource(slot_name, pub_name, target_batch_size=50) + + # create destination pipeline and resource + dest_pl = dlt.pipeline(pipeline_name="dest_pl", full_refresh=True) + + # insert 100 records into source table in one transaction + batch = [{**r, **{"id": key}} for r in [data] for key in range(1, 101)] + src_pl.run(batch, table_name="items") + extract_info = dest_pl.extract(changes) + assert extract_info.asdict()["job_metrics"][0]["items_count"] == 100 + + # insert 100 records into source table in 5 transactions + batch = [{**r, **{"id": key}} for r in [data] for key in range(101, 121)] + src_pl.run(batch, table_name="items") + batch = [{**r, **{"id": key}} for r in [data] for key in range(121, 141)] + src_pl.run(batch, table_name="items") + batch = [{**r, **{"id": key}} for r in [data] for key in range(141, 161)] + src_pl.run(batch, table_name="items") + batch = [{**r, **{"id": key}} for r in [data] for key in range(161, 181)] + src_pl.run(batch, table_name="items") + batch = [{**r, **{"id": key}} for r in [data] for key in range(181, 201)] + src_pl.run(batch, table_name="items") + extract_info = dest_pl.extract(changes) + assert 
extract_info.asdict()["job_metrics"][0]["items_count"] == 100 diff --git a/tests/pg_replication/utils.py b/tests/pg_replication/utils.py new file mode 100644 index 000000000..fe7695b91 --- /dev/null +++ b/tests/pg_replication/utils.py @@ -0,0 +1,52 @@ +from typing import Sequence, List, Dict, Any, Optional + +import dlt +from dlt import Pipeline +from dlt.common.data_writers.escape import escape_postgres_identifier +from dlt.common.configuration.specs import ConnectionStringCredentials + +from tests.utils import select_data + + +def add_pk(sql_client, table_name: str, column_name: str) -> None: + """Adds primary key to postgres table. + + In the context of replication, the primary key serves as REPLICA IDENTITY. + A REPLICA IDENTITY is required when publishing UPDATEs and/or DELETEs. + """ + with sql_client() as c: + qual_name = c.make_qualified_table_name(table_name) + c.execute_sql(f"ALTER TABLE {qual_name} ADD PRIMARY KEY ({column_name});") + + +def assert_loaded_data( + pipeline: Pipeline, + table_name: str, + column_names: Sequence[str], + expectation: List[Dict[str, Any]], + sort_column_name: str, + where_clause: Optional[str] = None, +) -> None: + """Asserts loaded data meets expectation.""" + qual_name = pipeline.sql_client().make_qualified_table_name(table_name) + escape_id = pipeline.destination_client().capabilities.escape_identifier + column_str = ", ".join(map(escape_id, column_names)) + qry = f"SELECT {column_str} FROM {qual_name}" + if where_clause is not None: + qry += " WHERE " + where_clause + observation = [ + {column_name: row[idx] for idx, column_name in enumerate(column_names)} + for row in select_data(pipeline, qry) + ] + assert sorted(observation, key=lambda d: d[sort_column_name]) == expectation + + +def is_super_user(sql_client) -> bool: + """Returns True if Postgres user is superuser, False otherwise.""" + username = dlt.secrets.get( + "sources.pg_replication.credentials", ConnectionStringCredentials + ).username + with sql_client() as c: + return c.execute_sql( + f"SELECT rolsuper FROM pg_roles WHERE rolname = '{username}';" + )[0][0] diff --git a/tests/utils.py b/tests/utils.py index 5f61ecdb7..a872c03d5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,7 @@ import os import platform import pytest -from typing import Any, Dict, Iterator, List, Set +from typing import Any, Iterator, List, Sequence, Dict, Optional, Set from os import environ from unittest.mock import patch @@ -18,7 +18,7 @@ ConfigTomlProvider, SecretsTomlProvider, ) -from dlt.common.pipeline import LoadInfo, PipelineContext +from dlt.common.pipeline import LoadInfo, PipelineContext, ExtractInfo from dlt.common.storages import FileStorage from dlt.common.schema.typing import TTableSchema @@ -227,6 +227,27 @@ def load_table_distinct_counts( return {r[0]: r[1] for r in rows} +def select_data( + p: dlt.Pipeline, sql: str, schema_name: str = None +) -> List[Sequence[Any]]: + """Returns select `sql` results as list.""" + with p.sql_client(schema_name=schema_name) as c: + with c.execute_query(sql) as cur: + return list(cur.fetchall()) + + +def get_table_metrics( + extract_info: ExtractInfo, table_name: str +) -> Optional[Dict[str, Any]]: + """Returns table metrics from ExtractInfo object.""" + table_metrics_list = [ + d + for d in extract_info.asdict()["table_metrics"] + if d["table_name"] == table_name + ] + return None if len(table_metrics_list) == 0 else table_metrics_list[0] + + def load_data_table_counts(p: dlt.Pipeline) -> DictStrAny: """Returns counts for all the data tables in default 
schema of `p` (excluding dlt tables)""" tables = [table["name"] for table in p.default_schema.data_tables()]