From 91167ff6bc264f2ad8bb145924b9c7058dffa5ed Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Edgar=20Ram=C3=ADrez=20Mondrag=C3=B3n?= <16805946+edgarrmondragon@users.noreply.github.com>
Date: Mon, 16 Dec 2024 11:16:10 -0600
Subject: [PATCH] chore: Add pre-commit config (#65)

---
 .gitignore                               |   2 +-
 .pre-commit-config.yaml                  |  25 ++
 LICENSE                                  |   1 -
 ruff.toml                                |  23 ++
 setup.py                                 |  53 +--
 tap_salesforce/__init__.py               | 388 ++++++++++----------
 tap_salesforce/salesforce/__init__.py    | 441 ++++++++++++-----------
 tap_salesforce/salesforce/bulk.py        | 225 ++++++------
 tap_salesforce/salesforce/bulk2.py       |  56 ++-
 tap_salesforce/salesforce/credentials.py |  54 ++-
 tap_salesforce/salesforce/exceptions.py  |   7 +-
 tap_salesforce/salesforce/rest.py        |  63 ++--
 tap_salesforce/sync.py                   | 125 ++++---
 13 files changed, 748 insertions(+), 715 deletions(-)
 create mode 100644 .pre-commit-config.yaml
 create mode 100644 ruff.toml

diff --git a/.gitignore b/.gitignore
index 4ca50f7..b542176 100644
--- a/.gitignore
+++ b/.gitignore
@@ -98,4 +98,4 @@
 env.sh
 config.json
 .autoenv.zsh
-*~
\ No newline at end of file
+*~
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..17ee9f3
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,25 @@
+ci:
+  autofix_prs: false
+  autoupdate_schedule: monthly
+  autoupdate_commit_msg: 'chore: pre-commit autoupdate'
+
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v5.0.0
+  hooks:
+  - id: check-json
+  - id: check-toml
+    exclude: |
+      (?x)^(
+        copier_template/.*/pyproject.toml
+      )$
+  - id: end-of-file-fixer
+    exclude: (copier_template/.*|docs/.*|samples/.*\.json)
+  - id: trailing-whitespace
+
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  rev: v0.7.2
+  hooks:
+  - id: ruff
+    args: [--fix]
+  - id: ruff-format
diff --git a/LICENSE b/LICENSE
index 627c3e9..4ec8c3f 100644
--- a/LICENSE
+++ b/LICENSE
@@ -617,4 +617,3 @@ Program, unless a warranty or assumption of liability accompanies a
 copy of the Program in return for a fee.
 
 END OF TERMS AND CONDITIONS
-
\ No newline at end of file
diff --git a/ruff.toml b/ruff.toml
new file mode 100644
index 0000000..f418bb5
--- /dev/null
+++ b/ruff.toml
@@ -0,0 +1,23 @@
+line-length = 120
+target-version = "py37"
+
+[lint]
+select = [
+    "F",    # pyflakes
+    "E",    # pycodestyle (errors)
+    "W",    # pycodestyle (warnings)
+    "C90",  # mccabe
+    "I",    # isort
+    "N",    # pep8-naming
+    "UP",   # pyupgrade
+    "B",    # flake8-bugbear
+    "C4",   # flake8-comprehensions
+    "DTZ",  # flake8-datetimez
+    "T10",  # flake8-debugger
+    "PIE",  # flake8-pie
+    "YTT",  # flake8-2020
+    "T20",  # flake8-print
+    "SIM",  # flake8-simplify
+    "PERF", # Perflint
+    "RUF",  # Ruff-specific rules
+]
diff --git a/setup.py b/setup.py
index d5237f1..df72e66 100644
--- a/setup.py
+++ b/setup.py
@@ -2,32 +2,33 @@
 from setuptools import setup
 
-setup(name='tap-salesforce',
-      version='1.7.0',
-      description='Singer.io tap for extracting data from the Salesforce API',
-      author='Stitch',
-      url='https://singer.io',
-      classifiers=['Programming Language :: Python :: 3 :: Only'],
-      py_modules=['tap_salesforce'],
-      install_requires=[
-          'requests==2.32.2',
-          'singer-python~=5.13',
-          'xmltodict==0.11.0',
-          'simple-salesforce<1.0',  # v1.0 requires `requests==2.22.0`
-          # fix version conflicts, see https://gitlab.com/meltano/meltano/issues/193
-          'idna==3.7',
-          'cryptography',
-          'pyOpenSSL',
-      ],
-      entry_points='''
+setup(
+    name="tap-salesforce",
+    version="1.7.0",
+    description="Singer.io tap for extracting data from the Salesforce API",
+    author="Stitch",
+    url="https://singer.io",
+    classifiers=["Programming Language :: Python :: 3 :: Only"],
+    py_modules=["tap_salesforce"],
+    install_requires=[
+        "requests==2.32.2",
+        "singer-python~=5.13",
+        "xmltodict==0.11.0",
+        "simple-salesforce<1.0",  # v1.0 requires `requests==2.22.0`
+        # fix version conflicts, see https://gitlab.com/meltano/meltano/issues/193
+        "idna==3.7",
+        "cryptography",
+        "pyOpenSSL",
+    ],
+    entry_points="""
     [console_scripts]
     tap-salesforce=tap_salesforce:main
-    ''',
-    packages=['tap_salesforce', 'tap_salesforce.salesforce'],
-    package_data = {
-        'tap_salesforce/schemas': [
-            # add schema.json filenames here
-        ]
-    },
-    include_package_data=True,
+    """,
+    packages=["tap_salesforce", "tap_salesforce.salesforce"],
+    package_data={
+        "tap_salesforce/schemas": [
+            # add schema.json filenames here
+        ]
+    },
+    include_package_data=True,
 )
diff --git a/tap_salesforce/__init__.py b/tap_salesforce/__init__.py
index 973eaa1..950f595 100644
--- a/tap_salesforce/__init__.py
+++ b/tap_salesforce/__init__.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 from __future__ import annotations
+
 import asyncio
 import concurrent.futures
 import json
@@ -11,21 +12,26 @@
 from singer import metadata, metrics
 
 import tap_salesforce.salesforce
-from tap_salesforce.sync import (sync_stream, resume_syncing_bulk_query, get_stream_version)
 from tap_salesforce.salesforce import Salesforce
-from tap_salesforce.salesforce.exceptions import (
-    TapSalesforceException, TapSalesforceQuotaExceededException)
 from tap_salesforce.salesforce.credentials import (
     OAuthCredentials,
     PasswordCredentials,
-    parse_credentials
+    parse_credentials,
+)
+from tap_salesforce.salesforce.exceptions import (
+    TapSalesforceExceptionError,
+    TapSalesforceQuotaExceededError,
+)
+from tap_salesforce.sync import (
+    get_stream_version,
+    resume_syncing_bulk_query,
+    sync_stream,
 )
 
 LOGGER = singer.get_logger()
 
 # the tap requires these keys
-REQUIRED_CONFIG_KEYS = ["api_type",
"select_fields_by_default"] # and either one of these credentials @@ -42,100 +48,95 @@ PASSWORD_CONFIG_KEYS = PasswordCredentials._fields CONFIG = { - 'refresh_token': None, - 'client_id': None, - 'client_secret': None, - 'start_date': None + "refresh_token": None, + "client_id": None, + "client_secret": None, + "start_date": None, } FORCED_FULL_TABLE = { - 'BackgroundOperationResult' # Does not support ordering by CreatedDate + "BackgroundOperationResult" # Does not support ordering by CreatedDate } + def get_replication_key(sobject_name, fields): if sobject_name in FORCED_FULL_TABLE: return None - fields_list = [f['name'] for f in fields] + fields_list = [f["name"] for f in fields] - if 'SystemModstamp' in fields_list: - return 'SystemModstamp' - elif 'LastModifiedDate' in fields_list: - return 'LastModifiedDate' - elif 'CreatedDate' in fields_list: - return 'CreatedDate' - elif 'LoginTime' in fields_list and sobject_name == 'LoginHistory': - return 'LoginTime' + if "SystemModstamp" in fields_list: + return "SystemModstamp" + elif "LastModifiedDate" in fields_list: + return "LastModifiedDate" + elif "CreatedDate" in fields_list: + return "CreatedDate" + elif "LoginTime" in fields_list and sobject_name == "LoginHistory": + return "LoginTime" return None + def stream_is_selected(mdata): - return mdata.get((), {}).get('selected', False) + return mdata.get((), {}).get("selected", False) + def build_state(raw_state, catalog): state = {} - for catalog_entry in catalog['streams']: - tap_stream_id = catalog_entry['tap_stream_id'] - catalog_metadata = metadata.to_map(catalog_entry['metadata']) - replication_method = catalog_metadata.get((), {}).get('replication-method') + for catalog_entry in catalog["streams"]: + tap_stream_id = catalog_entry["tap_stream_id"] + catalog_metadata = metadata.to_map(catalog_entry["metadata"]) + replication_method = catalog_metadata.get((), {}).get("replication-method") - version = singer.get_bookmark(raw_state, - tap_stream_id, - 'version') + version = singer.get_bookmark(raw_state, tap_stream_id, "version") # Preserve state that deals with resuming an incomplete bulk job - if singer.get_bookmark(raw_state, tap_stream_id, 'JobID'): - job_id = singer.get_bookmark(raw_state, tap_stream_id, 'JobID') - batches = singer.get_bookmark(raw_state, tap_stream_id, 'BatchIDs') - current_bookmark = singer.get_bookmark(raw_state, tap_stream_id, 'JobHighestBookmarkSeen') - state = singer.write_bookmark(state, tap_stream_id, 'JobID', job_id) - state = singer.write_bookmark(state, tap_stream_id, 'BatchIDs', batches) - state = singer.write_bookmark(state, tap_stream_id, 'JobHighestBookmarkSeen', current_bookmark) - - if replication_method == 'INCREMENTAL': - replication_key = catalog_metadata.get((), {}).get('replication-key') - replication_key_value = singer.get_bookmark(raw_state, - tap_stream_id, - replication_key) + if singer.get_bookmark(raw_state, tap_stream_id, "JobID"): + job_id = singer.get_bookmark(raw_state, tap_stream_id, "JobID") + batches = singer.get_bookmark(raw_state, tap_stream_id, "BatchIDs") + current_bookmark = singer.get_bookmark(raw_state, tap_stream_id, "JobHighestBookmarkSeen") + state = singer.write_bookmark(state, tap_stream_id, "JobID", job_id) + state = singer.write_bookmark(state, tap_stream_id, "BatchIDs", batches) + state = singer.write_bookmark(state, tap_stream_id, "JobHighestBookmarkSeen", current_bookmark) + + if replication_method == "INCREMENTAL": + replication_key = catalog_metadata.get((), {}).get("replication-key") + replication_key_value = 
singer.get_bookmark(raw_state, tap_stream_id, replication_key) if version is not None: - state = singer.write_bookmark( - state, tap_stream_id, 'version', version) + state = singer.write_bookmark(state, tap_stream_id, "version", version) if replication_key_value is not None: - state = singer.write_bookmark( - state, tap_stream_id, replication_key, replication_key_value) - elif replication_method == 'FULL_TABLE' and version is None: - state = singer.write_bookmark(state, tap_stream_id, 'version', version) + state = singer.write_bookmark(state, tap_stream_id, replication_key, replication_key_value) + elif replication_method == "FULL_TABLE" and version is None: + state = singer.write_bookmark(state, tap_stream_id, "version", version) return state + # pylint: disable=undefined-variable def create_property_schema(field, mdata): - field_name = field['name'] + field_name = field["name"] if field_name == "Id": - mdata = metadata.write( - mdata, ('properties', field_name), 'inclusion', 'automatic') + mdata = metadata.write(mdata, ("properties", field_name), "inclusion", "automatic") else: - mdata = metadata.write( - mdata, ('properties', field_name), 'inclusion', 'available') + mdata = metadata.write(mdata, ("properties", field_name), "inclusion", "available") - property_schema, mdata = salesforce.field_to_property_schema(field, mdata) + property_schema, mdata = tap_salesforce.salesforce.field_to_property_schema(field, mdata) return (property_schema, mdata) -# pylint: disable=too-many-branches,too-many-statements -def do_discover(sf: Salesforce, streams: list[str]): +def do_discover(sf: Salesforce, streams: list[str]): # noqa: C901 if not streams: """Describes a Salesforce instance's objects and generates a JSON schema for each field.""" - LOGGER.info(f"Start discovery for all streams") + LOGGER.info("Start discovery for all streams") global_description = sf.describe() - objects_to_discover = {o['name'] for o in global_description['sobjects']} + objects_to_discover = {o["name"] for o in global_description["sobjects"]} else: LOGGER.info(f"Start discovery: {streams=}") objects_to_discover = streams - key_properties = ['Id'] + key_properties = ["Id"] sf_custom_setting_objects = [] object_to_tag_references = {} @@ -143,11 +144,9 @@ def do_discover(sf: Salesforce, streams: list[str]): # For each SF Object describe it, loop its fields and build a schema entries = [] for sobject_name in objects_to_discover: - # Skip blacklisted SF objects depending on the api_type in use # ChangeEvent objects are not queryable via Bulk or REST (undocumented) - if sobject_name in sf.get_blacklisted_objects() \ - or sobject_name.endswith("ChangeEvent"): + if sobject_name in sf.get_blacklisted_objects() or sobject_name.endswith("ChangeEvent"): continue sobject_description = sf.describe(sobject_name) @@ -159,13 +158,13 @@ def do_discover(sf: Salesforce, streams: list[str]): elif sobject_name.endswith("__Tag"): relationship_field = next( (f for f in sobject_description["fields"] if f.get("relationshipName") == "Item"), - None) + None, + ) if relationship_field: # Map {"Object":"Object__Tag"} - object_to_tag_references[relationship_field["referenceTo"] - [0]] = sobject_name + object_to_tag_references[relationship_field["referenceTo"][0]] = sobject_name - fields = sobject_description['fields'] + fields = sobject_description["fields"] replication_key = get_replication_key(sobject_name, fields) unsupported_fields = set() @@ -176,129 +175,126 @@ def do_discover(sf: Salesforce, streams: list[str]): # Loop over the object's fields 
for f in fields: - field_name = f['name'] - field_type = f['type'] + field_name = f["name"] + field_type = f["type"] # noqa: F841 if field_name == "Id": found_id_field = True - property_schema, mdata = create_property_schema( - f, mdata) + property_schema, mdata = create_property_schema(f, mdata) # Compound Address fields cannot be queried by the Bulk API - if f['type'] in ("address", "location") and sf.api_type in [tap_salesforce.salesforce.BULK_API_TYPE, tap_salesforce.salesforce.BULK2_API_TYPE]: - unsupported_fields.add( - (field_name, 'cannot query compound address fields with bulk API')) + if f["type"] in ("address", "location") and sf.api_type in [ + tap_salesforce.salesforce.BULK_API_TYPE, + tap_salesforce.salesforce.BULK2_API_TYPE, + ]: + unsupported_fields.add((field_name, "cannot query compound address fields with bulk API")) # we haven't been able to observe any records with a json field, so we # are marking it as unavailable until we have an example to work with - if f['type'] == "json": + if f["type"] == "json": unsupported_fields.add( - (field_name, 'do not currently support json fields - please contact support')) + ( + field_name, + "do not currently support json fields - please contact support", + ) + ) # Blacklisted fields are dependent on the api_type being used field_pair = (sobject_name, field_name) if field_pair in sf.get_blacklisted_fields(): - unsupported_fields.add( - (field_name, sf.get_blacklisted_fields()[field_pair])) + unsupported_fields.add((field_name, sf.get_blacklisted_fields()[field_pair])) - inclusion = metadata.get( - mdata, ('properties', field_name), 'inclusion') + inclusion = metadata.get(mdata, ("properties", field_name), "inclusion") - if sf.select_fields_by_default and inclusion != 'unsupported': - mdata = metadata.write( - mdata, ('properties', field_name), 'selected-by-default', True) + if sf.select_fields_by_default and inclusion != "unsupported": + mdata = metadata.write(mdata, ("properties", field_name), "selected-by-default", True) properties[field_name] = property_schema if replication_key: - mdata = metadata.write( - mdata, ('properties', replication_key), 'inclusion', 'automatic') + mdata = metadata.write(mdata, ("properties", replication_key), "inclusion", "automatic") # There are cases where compound fields are referenced by the associated # subfields but are not actually present in the field list - field_name_set = {f['name'] for f in fields} + field_name_set = {f["name"] for f in fields} filtered_unsupported_fields = [f for f in unsupported_fields if f[0] in field_name_set] missing_unsupported_field_names = [f[0] for f in unsupported_fields if f[0] not in field_name_set] if missing_unsupported_field_names: - LOGGER.info("Ignoring the following unsupported fields for object %s as they are missing from the field list: %s", - sobject_name, - ', '.join(sorted(missing_unsupported_field_names))) + LOGGER.info( + "Ignoring the following unsupported fields for object %s as they are missing from the field list: %s", + sobject_name, + ", ".join(sorted(missing_unsupported_field_names)), + ) if filtered_unsupported_fields: - LOGGER.info("Not syncing the following unsupported fields for object %s: %s", - sobject_name, - ', '.join(sorted([k for k, _ in filtered_unsupported_fields]))) + LOGGER.info( + "Not syncing the following unsupported fields for object %s: %s", + sobject_name, + ", ".join(sorted([k for k, _ in filtered_unsupported_fields])), + ) # Salesforce Objects are skipped when they do not have an Id field if not found_id_field: - 
LOGGER.info( - "Skipping Salesforce Object %s, as it has no Id field", - sobject_name) + LOGGER.info("Skipping Salesforce Object %s, as it has no Id field", sobject_name) continue # Any property added to unsupported_fields has metadata generated and # removed for prop, description in filtered_unsupported_fields: - if metadata.get(mdata, ('properties', prop), - 'selected-by-default'): - metadata.delete( - mdata, ('properties', prop), 'selected-by-default') + if metadata.get(mdata, ("properties", prop), "selected-by-default"): + metadata.delete(mdata, ("properties", prop), "selected-by-default") - mdata = metadata.write( - mdata, ('properties', prop), 'unsupported-description', description) - mdata = metadata.write( - mdata, ('properties', prop), 'inclusion', 'unsupported') + mdata = metadata.write(mdata, ("properties", prop), "unsupported-description", description) + mdata = metadata.write(mdata, ("properties", prop), "inclusion", "unsupported") if replication_key: - mdata = metadata.write( - mdata, (), 'valid-replication-keys', [replication_key]) - mdata = metadata.write( - mdata, (), 'replication-key', replication_key - ) - mdata = metadata.write( - mdata, (), 'replication-method', "INCREMENTAL" - ) + mdata = metadata.write(mdata, (), "valid-replication-keys", [replication_key]) + mdata = metadata.write(mdata, (), "replication-key", replication_key) + mdata = metadata.write(mdata, (), "replication-method", "INCREMENTAL") else: mdata = metadata.write( mdata, (), - 'forced-replication-method', + "forced-replication-method", { - 'replication-method': 'FULL_TABLE', - 'reason': 'No replication keys found from the Salesforce API'}) + "replication-method": "FULL_TABLE", + "reason": "No replication keys found from the Salesforce API", + }, + ) - mdata = metadata.write(mdata, (), 'table-key-properties', key_properties) + mdata = metadata.write(mdata, (), "table-key-properties", key_properties) schema = { - 'type': 'object', - 'additionalProperties': False, - 'properties': properties + "type": "object", + "additionalProperties": False, + "properties": properties, } entry = { - 'stream': sobject_name, - 'tap_stream_id': sobject_name, - 'schema': schema, - 'metadata': metadata.to_list(mdata) + "stream": sobject_name, + "tap_stream_id": sobject_name, + "schema": schema, + "metadata": metadata.to_list(mdata), } entries.append(entry) # For each custom setting field, remove its associated tag from entries # See Blacklisting.md for more information - unsupported_tag_objects = [object_to_tag_references[f] - for f in sf_custom_setting_objects if f in object_to_tag_references] + unsupported_tag_objects = [ + object_to_tag_references[f] for f in sf_custom_setting_objects if f in object_to_tag_references + ] if unsupported_tag_objects: - LOGGER.info( #pylint:disable=logging-not-lazy - "Skipping the following Tag objects, Tags on Custom Settings Salesforce objects " + - "are not supported by the Bulk API:") + LOGGER.info( # pylint:disable=logging-not-lazy + "Skipping the following Tag objects, Tags on Custom Settings Salesforce objects " + + "are not supported by the Bulk API:" + ) LOGGER.info(unsupported_tag_objects) - entries = [e for e in entries if e['stream'] - not in unsupported_tag_objects] + entries = [e for e in entries if e["stream"] not in unsupported_tag_objects] - result = {'streams': entries} + result = {"streams": entries} json.dump(result, sys.stdout, indent=4) @@ -316,10 +312,10 @@ def is_object_type(property_schema): return False -def is_property_selected( # noqa: C901 # ignore 'too complex' - 
stream_name, - metadata_map, - breadcrumb +def is_property_selected( # noqa: C901 + stream_name, + metadata_map, + breadcrumb, ) -> bool: """ Return True if the property is selected for extract. @@ -331,7 +327,7 @@ def is_property_selected( # noqa: C901 # ignore 'too complex' """ breadcrumb = breadcrumb or () if isinstance(breadcrumb, str): - breadcrumb = tuple([breadcrumb]) + breadcrumb = (breadcrumb,) if not metadata: # Default to true if no metadata to say otherwise @@ -341,9 +337,7 @@ def is_property_selected( # noqa: C901 # ignore 'too complex' parent_value = None if len(breadcrumb) > 0: parent_breadcrumb = tuple(list(breadcrumb)[:-2]) - parent_value = is_property_selected( - stream_name, metadata_map, parent_breadcrumb - ) + parent_value = is_property_selected(stream_name, metadata_map, parent_breadcrumb) if parent_value is False: return parent_value @@ -354,8 +348,7 @@ def is_property_selected( # noqa: C901 # ignore 'too complex' if inclusion == "unsupported": if selected is True: LOGGER.debug( - "Property '%s' was selected but is not supported. " - "Ignoring selected==True input.", + "Property '%s' was selected but is not supported. " "Ignoring selected==True input.", ":".join(breadcrumb), ) return False @@ -376,8 +369,7 @@ def is_property_selected( # noqa: C901 # ignore 'too complex' return selected_by_default LOGGER.debug( - "Selection metadata omitted for '%s':'%s'. " - "Using parent value of selected=%s.", + "Selection metadata omitted for '%s':'%s'. " "Using parent value of selected=%s.", stream_name, breadcrumb, parent_value, @@ -385,48 +377,36 @@ def is_property_selected( # noqa: C901 # ignore 'too complex' return parent_value or False -def pop_deselected_schema( - schema, - stream_name, - breadcrumb, - metadata_map -): +def pop_deselected_schema(schema, stream_name, breadcrumb, metadata_map): """Remove anything from schema that is not selected. Walk through schema, starting at the index in breadcrumb, recursively updating in place. This code is based on https://github.com/meltano/sdk/blob/c9c0967b0caca51fe7c87082f9e7c5dd54fa5dfa/singer_sdk/helpers/_catalog.py#L146 """ for property_name, val in list(schema.get("properties", {}).items()): - property_breadcrumb = tuple( - list(breadcrumb) + ["properties", property_name] - ) - selected = is_property_selected( - stream_name, metadata_map, property_breadcrumb - ) - LOGGER.info(stream_name + '.' + property_name + ' - ' + str(selected)) + property_breadcrumb = (*list(breadcrumb), "properties", property_name) + selected = is_property_selected(stream_name, metadata_map, property_breadcrumb) + LOGGER.info(stream_name + "." + property_name + " - " + str(selected)) if not selected: schema["properties"].pop(property_name) continue if is_object_type(val): # call recursively in case any subproperties are deselected. 
- pop_deselected_schema( - val, stream_name, property_breadcrumb, metadata_map - ) + pop_deselected_schema(val, stream_name, property_breadcrumb, metadata_map) async def sync_catalog_entry(sf, catalog_entry, state): stream_version = get_stream_version(catalog_entry, state) - stream = catalog_entry['stream'] - stream_alias = catalog_entry.get('stream_alias') + stream = catalog_entry["stream"] + stream_alias = catalog_entry.get("stream_alias") stream_name = catalog_entry["tap_stream_id"] - activate_version_message = singer.ActivateVersionMessage( - stream=(stream_alias or stream), version=stream_version) + activate_version_message = singer.ActivateVersionMessage(stream=(stream_alias or stream), version=stream_version) - catalog_metadata = metadata.to_map(catalog_entry['metadata']) - replication_key = catalog_metadata.get((), {}).get('replication-key') + catalog_metadata = metadata.to_map(catalog_entry["metadata"]) + replication_key = catalog_metadata.get((), {}).get("replication-key") - mdata = metadata.to_map(catalog_entry['metadata']) + mdata = metadata.to_map(catalog_entry["metadata"]) if not stream_is_selected(mdata): LOGGER.debug("%s: Skipping - not selected", stream_name) @@ -435,57 +415,58 @@ async def sync_catalog_entry(sf, catalog_entry, state): LOGGER.info("%s: Starting", stream_name) singer.write_state(state) - key_properties = metadata.to_map(catalog_entry['metadata']).get((), {}).get('table-key-properties') + key_properties = metadata.to_map(catalog_entry["metadata"]).get((), {}).get("table-key-properties") # Filter the schema for selected fields - schema = deepcopy(catalog_entry['schema']) + schema = deepcopy(catalog_entry["schema"]) pop_deselected_schema(schema, stream_name, (), mdata) - singer.write_schema( - stream, - schema, - key_properties, - replication_key, - stream_alias) + singer.write_schema(stream, schema, key_properties, replication_key, stream_alias) loop = asyncio.get_event_loop() - job_id = singer.get_bookmark(state, catalog_entry['tap_stream_id'], 'JobID') + job_id = singer.get_bookmark(state, catalog_entry["tap_stream_id"], "JobID") if job_id: with metrics.record_counter(stream) as counter: - LOGGER.info("Found JobID from previous Bulk Query. Resuming sync for job: %s", job_id) + LOGGER.info( + "Found JobID from previous Bulk Query. 
Resuming sync for job: %s", + job_id, + ) # Resuming a sync should clear out the remaining state once finished - await loop.run_in_executor(None, resume_syncing_bulk_query, sf, catalog_entry, job_id, state, counter) - LOGGER.info("Completed sync for %s", stream_name) - state.get('bookmarks', {}).get(catalog_entry['tap_stream_id'], {}).pop('JobID', None) - state.get('bookmarks', {}).get(catalog_entry['tap_stream_id'], {}).pop('BatchIDs', None) - bookmark = state.get('bookmarks', {}).get(catalog_entry['tap_stream_id'], {}).pop('JobHighestBookmarkSeen', None) - state = singer.write_bookmark( + await loop.run_in_executor( + None, + resume_syncing_bulk_query, + sf, + catalog_entry, + job_id, state, - catalog_entry['tap_stream_id'], - replication_key, - bookmark) + counter, + ) + LOGGER.info("Completed sync for %s", stream_name) + state.get("bookmarks", {}).get(catalog_entry["tap_stream_id"], {}).pop("JobID", None) + state.get("bookmarks", {}).get(catalog_entry["tap_stream_id"], {}).pop("BatchIDs", None) + bookmark = ( + state.get("bookmarks", {}).get(catalog_entry["tap_stream_id"], {}).pop("JobHighestBookmarkSeen", None) + ) + state = singer.write_bookmark(state, catalog_entry["tap_stream_id"], replication_key, bookmark) singer.write_state(state) else: - state_msg_threshold = CONFIG.get('state_message_threshold', 1000) + state_msg_threshold = CONFIG.get("state_message_threshold", 1000) # Tables with a replication_key or an empty bookmark will emit an # activate_version at the beginning of their sync - bookmark_is_empty = state.get('bookmarks', {}).get( - catalog_entry['tap_stream_id']) is None + bookmark_is_empty = state.get("bookmarks", {}).get(catalog_entry["tap_stream_id"]) is None if replication_key or bookmark_is_empty: singer.write_message(activate_version_message) - state = singer.write_bookmark(state, - catalog_entry['tap_stream_id'], - 'version', - stream_version) + state = singer.write_bookmark(state, catalog_entry["tap_stream_id"], "version", stream_version) await loop.run_in_executor(None, sync_stream, sf, catalog_entry, state, state_msg_threshold) LOGGER.info("Completed sync for %s", stream_name) + def do_sync(sf, catalog, state): LOGGER.info("Starting sync") - max_workers = CONFIG.get('max_workers', 8) + max_workers = CONFIG.get("max_workers", 8) executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) loop = asyncio.get_event_loop() loop.set_default_executor(executor) @@ -495,8 +476,7 @@ def do_sync(sf, catalog, state): # Schedule one task for each catalog entry to be extracted # and run them concurrently. 
- sync_tasks = (sync_catalog_entry(sf, catalog_entry, state) - for catalog_entry in streams_to_sync) + sync_tasks = (sync_catalog_entry(sf, catalog_entry, state) for catalog_entry in streams_to_sync) tasks = asyncio.gather(*sync_tasks) loop.run_until_complete(tasks) finally: @@ -506,6 +486,7 @@ def do_sync(sf, catalog, state): singer.write_state(state) LOGGER.info("Finished sync") + def main_impl(): args = singer_utils.parse_args(REQUIRED_CONFIG_KEYS) CONFIG.update(args.config) @@ -515,12 +496,13 @@ def main_impl(): try: sf = Salesforce( credentials=credentials, - quota_percent_total=CONFIG.get('quota_percent_total'), - quota_percent_per_run=CONFIG.get('quota_percent_per_run'), - is_sandbox=CONFIG.get('is_sandbox'), - select_fields_by_default=CONFIG.get('select_fields_by_default'), - default_start_date=CONFIG.get('start_date'), - api_type=CONFIG.get('api_type')) + quota_percent_total=CONFIG.get("quota_percent_total"), + quota_percent_per_run=CONFIG.get("quota_percent_per_run"), + is_sandbox=CONFIG.get("is_sandbox"), + select_fields_by_default=CONFIG.get("select_fields_by_default"), + default_start_date=CONFIG.get("start_date"), + api_type=CONFIG.get("api_type"), + ) sf.login() if args.discover: @@ -534,11 +516,13 @@ def main_impl(): if sf.rest_requests_attempted > 0: LOGGER.debug( "This job used %s REST requests towards the Salesforce quota.", - sf.rest_requests_attempted) + sf.rest_requests_attempted, + ) if sf.jobs_completed > 0: LOGGER.debug( "Replication used %s Bulk API jobs towards the Salesforce quota.", - sf.jobs_completed) + sf.jobs_completed, + ) if sf.auth.login_timer: sf.auth.login_timer.cancel() @@ -546,10 +530,10 @@ def main_impl(): def main(): try: main_impl() - except TapSalesforceQuotaExceededException as e: + except TapSalesforceQuotaExceededError as e: LOGGER.critical(e) sys.exit(2) - except TapSalesforceException as e: + except TapSalesforceExceptionError as e: LOGGER.critical(e) sys.exit(1) except Exception as e: diff --git a/tap_salesforce/salesforce/__init__.py b/tap_salesforce/salesforce/__init__.py index f1ff070..1622a38 100644 --- a/tap_salesforce/salesforce/__init__.py +++ b/tap_salesforce/salesforce/__init__.py @@ -1,22 +1,21 @@ import re -import time +from datetime import timedelta + import backoff import requests -from datetime import timedelta -from requests.exceptions import RequestException import singer import singer.utils as singer_utils from singer import metadata, metrics from tap_salesforce.salesforce.bulk import Bulk from tap_salesforce.salesforce.bulk2 import Bulk2 -from tap_salesforce.salesforce.rest import Rest -from tap_salesforce.salesforce.exceptions import ( - TapSalesforceException, - TapSalesforceQuotaExceededException, - SFDCCustomNotAcceptableError) from tap_salesforce.salesforce.credentials import SalesforceAuth - +from tap_salesforce.salesforce.exceptions import ( + SFDCCustomNotAcceptableError, + TapSalesforceExceptionError, + TapSalesforceQuotaExceededError, +) +from tap_salesforce.salesforce.rest import Rest LOGGER = singer.get_logger() @@ -24,120 +23,117 @@ BULK2_API_TYPE = "BULK2" REST_API_TYPE = "REST" -STRING_TYPES = set([ - 'id', - 'string', - 'picklist', - 'textarea', - 'phone', - 'url', - 'reference', - 'multipicklist', - 'combobox', - 'encryptedstring', - 'email', - 'complexvalue', # TODO: Unverified - 'masterrecord', - 'datacategorygroupreference', - 'base64' -]) - -NUMBER_TYPES = set([ - 'double', - 'currency', - 'percent' -]) - -DATE_TYPES = set([ - 'datetime', - 'date' -]) - -BINARY_TYPES = set([ - 'byte' -]) - -LOOSE_TYPES 
= set([ - 'anyType', - +STRING_TYPES = { + "id", + "string", + "picklist", + "textarea", + "phone", + "url", + "reference", + "multipicklist", + "combobox", + "encryptedstring", + "email", + "complexvalue", # TODO: Unverified + "masterrecord", + "datacategorygroupreference", + "base64", +} + +NUMBER_TYPES = {"double", "currency", "percent"} + +DATE_TYPES = {"datetime", "date"} + +BINARY_TYPES = {"byte"} + +LOOSE_TYPES = { + "anyType", # A calculated field's type can be any of the supported # formula data types (see https://developer.salesforce.com/docs/#i1435527) - 'calculated' -]) + "calculated", +} # The following objects are not supported by the bulk API. -UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS = set(['AssetTokenEvent', - 'AttachedContentNote', - 'EventWhoRelation', - 'QuoteTemplateRichTextData', - 'TaskWhoRelation', - 'SolutionStatus', - 'ContractStatus', - 'RecentlyViewed', - 'DeclinedEventRelation', - 'AcceptedEventRelation', - 'TaskStatus', - 'PartnerRole', - 'TaskPriority', - 'CaseStatus', - 'UndecidedEventRelation', - 'OrderStatus']) +UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS = { + "AssetTokenEvent", + "AttachedContentNote", + "EventWhoRelation", + "QuoteTemplateRichTextData", + "TaskWhoRelation", + "SolutionStatus", + "ContractStatus", + "RecentlyViewed", + "DeclinedEventRelation", + "AcceptedEventRelation", + "TaskStatus", + "PartnerRole", + "TaskPriority", + "CaseStatus", + "UndecidedEventRelation", + "OrderStatus", +} # The following objects have certain WHERE clause restrictions so we exclude them. -QUERY_RESTRICTED_SALESFORCE_OBJECTS = set(['Announcement', - 'ContentDocumentLink', - 'CollaborationGroupRecord', - 'Vote', - 'IdeaComment', - 'FieldDefinition', - 'PlatformAction', - 'UserEntityAccess', - 'RelationshipInfo', - 'ContentFolderMember', - 'ContentFolderItem', - 'SearchLayout', - 'SiteDetail', - 'EntityParticle', - 'OwnerChangeOptionInfo', - 'DataStatistics', - 'UserFieldAccess', - 'PicklistValueInfo', - 'RelationshipDomain', - 'FlexQueueItem', - 'NetworkUserHistoryRecent', - 'FieldHistoryArchive', - 'RecordActionHistory', - 'FlowVersionView', - 'FlowVariableView', - 'AppTabMember', - 'ColorDefinition', - 'IconDefinition',]) +QUERY_RESTRICTED_SALESFORCE_OBJECTS = { + "Announcement", + "ContentDocumentLink", + "CollaborationGroupRecord", + "Vote", + "IdeaComment", + "FieldDefinition", + "PlatformAction", + "UserEntityAccess", + "RelationshipInfo", + "ContentFolderMember", + "ContentFolderItem", + "SearchLayout", + "SiteDetail", + "EntityParticle", + "OwnerChangeOptionInfo", + "DataStatistics", + "UserFieldAccess", + "PicklistValueInfo", + "RelationshipDomain", + "FlexQueueItem", + "NetworkUserHistoryRecent", + "FieldHistoryArchive", + "RecordActionHistory", + "FlowVersionView", + "FlowVariableView", + "AppTabMember", + "ColorDefinition", + "IconDefinition", +} # The following objects are not supported by the query method being used. 
-QUERY_INCOMPATIBLE_SALESFORCE_OBJECTS = set(['DataType', - 'ListViewChartInstance', - 'FeedLike', - 'OutgoingEmail', - 'OutgoingEmailRelation', - 'FeedSignal', - 'ActivityHistory', - 'EmailStatus', - 'UserRecordAccess', - 'Name', - 'AggregateResult', - 'OpenActivity', - 'ProcessInstanceHistory', - 'OwnedContentDocument', - 'FolderedContentDocument', - 'FeedTrackedChange', - 'CombinedAttachment', - 'AttachedContentDocument', - 'ContentBody', - 'NoteAndAttachment', - 'LookedUpFromActivity', - 'AttachedContentNote', - 'QuoteTemplateRichTextData']) +QUERY_INCOMPATIBLE_SALESFORCE_OBJECTS = { + "DataType", + "ListViewChartInstance", + "FeedLike", + "OutgoingEmail", + "OutgoingEmailRelation", + "FeedSignal", + "ActivityHistory", + "EmailStatus", + "UserRecordAccess", + "Name", + "AggregateResult", + "OpenActivity", + "ProcessInstanceHistory", + "OwnedContentDocument", + "FolderedContentDocument", + "FeedTrackedChange", + "CombinedAttachment", + "AttachedContentDocument", + "ContentBody", + "NoteAndAttachment", + "LookedUpFromActivity", + "AttachedContentNote", + "QuoteTemplateRichTextData", +} + def log_backoff_attempt(details): LOGGER.info("ConnectionError detected, triggering backoff: %d try", details.get("tries")) @@ -153,37 +149,34 @@ def raise_for_status(resp): that this error is ephemeral and resolved after retries. """ if resp.status_code != 200: - err_msg = ( - f"{resp.status_code} Client Error: {resp.reason} " - f"for url: {resp.url}" - ) + err_msg = f"{resp.status_code} Client Error: {resp.reason} " f"for url: {resp.url}" LOGGER.warning(err_msg) - if resp.status_code == 406 and 'CustomNotAcceptable' in resp.reason: + if resp.status_code == 406 and "CustomNotAcceptable" in resp.reason: raise SFDCCustomNotAcceptableError(err_msg) else: resp.raise_for_status() -def field_to_property_schema(field, mdata): +def field_to_property_schema(field, mdata): # noqa: C901 property_schema = {} - field_name = field['name'] - sf_type = field['type'] + field_name = field["name"] + sf_type = field["type"] if sf_type in STRING_TYPES: - property_schema['type'] = "string" + property_schema["type"] = "string" elif sf_type in DATE_TYPES: date_type = {"type": "string", "format": "date-time"} string_type = {"type": ["string", "null"]} property_schema["anyOf"] = [date_type, string_type] elif sf_type == "boolean": - property_schema['type'] = "boolean" + property_schema["type"] = "boolean" elif sf_type in NUMBER_TYPES: - property_schema['type'] = "number" + property_schema["type"] = "number" elif sf_type == "address": - property_schema['type'] = "object" - property_schema['properties'] = { + property_schema["type"] = "object" + property_schema["properties"] = { "street": {"type": ["null", "string"]}, "state": {"type": ["null", "string"]}, "postalCode": {"type": ["null", "string"]}, @@ -191,59 +184,63 @@ def field_to_property_schema(field, mdata): "country": {"type": ["null", "string"]}, "longitude": {"type": ["null", "number"]}, "latitude": {"type": ["null", "number"]}, - "geocodeAccuracy": {"type": ["null", "string"]} + "geocodeAccuracy": {"type": ["null", "string"]}, } elif sf_type in ("int", "long"): - property_schema['type'] = "integer" + property_schema["type"] = "integer" elif sf_type == "time": - property_schema['type'] = "string" + property_schema["type"] = "string" elif sf_type in LOOSE_TYPES: return property_schema, mdata # No type = all types elif sf_type in BINARY_TYPES: - mdata = metadata.write(mdata, ('properties', field_name), "inclusion", "unsupported") - mdata = metadata.write(mdata, 
('properties', field_name), - "unsupported-description", "binary data") + mdata = metadata.write(mdata, ("properties", field_name), "inclusion", "unsupported") + mdata = metadata.write(mdata, ("properties", field_name), "unsupported-description", "binary data") return property_schema, mdata - elif sf_type == 'location': + elif sf_type == "location": # geo coordinates are numbers or objects divided into two fields for lat/long - property_schema['type'] = ["number", "object", "null"] - property_schema['properties'] = { + property_schema["type"] = ["number", "object", "null"] + property_schema["properties"] = { "longitude": {"type": ["null", "number"]}, - "latitude": {"type": ["null", "number"]} + "latitude": {"type": ["null", "number"]}, } - elif sf_type == 'json': - property_schema['type'] = "string" + elif sf_type == "json": + property_schema["type"] = "string" else: - raise TapSalesforceException("Found unsupported type: {}".format(sf_type)) + raise TapSalesforceExceptionError(f"Found unsupported type: {sf_type}") # The nillable field cannot be trusted - if field_name != 'Id' and sf_type != 'location' and sf_type not in DATE_TYPES: - property_schema['type'] = ["null", property_schema['type']] + if field_name != "Id" and sf_type != "location" and sf_type not in DATE_TYPES: + property_schema["type"] = ["null", property_schema["type"]] return property_schema, mdata -class Salesforce(): + +class Salesforce: # pylint: disable=too-many-instance-attributes,too-many-arguments - def __init__(self, - credentials=None, - token=None, - quota_percent_per_run=None, - quota_percent_total=None, - is_sandbox=None, - select_fields_by_default=None, - default_start_date=None, - api_type=None): + def __init__( + self, + credentials=None, + token=None, + quota_percent_per_run=None, + quota_percent_total=None, + is_sandbox=None, + select_fields_by_default=None, + default_start_date=None, + api_type=None, + ): self.api_type = api_type.upper() if api_type else None self.session = requests.Session() - if isinstance(quota_percent_per_run, str) and quota_percent_per_run.strip() == '': + if isinstance(quota_percent_per_run, str) and quota_percent_per_run.strip() == "": quota_percent_per_run = None - if isinstance(quota_percent_total, str) and quota_percent_total.strip() == '': + if isinstance(quota_percent_total, str) and quota_percent_total.strip() == "": quota_percent_total = None self.quota_percent_per_run = float(quota_percent_per_run) if quota_percent_per_run is not None else 25 self.quota_percent_total = float(quota_percent_total) if quota_percent_total is not None else 80 - self.is_sandbox = is_sandbox is True or (isinstance(is_sandbox, str) and is_sandbox.lower() == 'true') - self.select_fields_by_default = select_fields_by_default is True or (isinstance(select_fields_by_default, str) and select_fields_by_default.lower() == 'true') + self.is_sandbox = is_sandbox is True or (isinstance(is_sandbox, str) and is_sandbox.lower() == "true") + self.select_fields_by_default = select_fields_by_default is True or ( + isinstance(select_fields_by_default, str) and select_fields_by_default.lower() == "true" + ) self.rest_requests_attempted = 0 self.jobs_completed = 0 self.data_url = "{}/services/data/v60.0/{}" @@ -259,11 +256,15 @@ def __init__(self, ).isoformat() if default_start_date: - LOGGER.info("Parsed start date '%s' from value '%s'", self.default_start_date, default_start_date) + LOGGER.info( + "Parsed start date '%s' from value '%s'", + self.default_start_date, + default_start_date, + ) # pylint: 
disable=anomalous-backslash-in-string,line-too-long def check_rest_quota_usage(self, headers): - match = re.search('^api-usage=(\d+)/(\d+)$', headers.get('Sforce-Limit-Info')) + match = re.search(r"^api-usage=(\d+)/(\d+)$", headers.get("Sforce-Limit-Info")) if match is None: return @@ -276,21 +277,24 @@ def check_rest_quota_usage(self, headers): max_requests_for_run = int((self.quota_percent_per_run * allotted) / 100) if percent_used_from_total > self.quota_percent_total: - total_message = ("Salesforce has reported {}/{} ({:3.2f}%) total REST quota " + - "used across all Salesforce Applications. Terminating " + - "replication to not continue past configured percentage " + - "of {}% total quota.").format(remaining, - allotted, - percent_used_from_total, - self.quota_percent_total) - raise TapSalesforceQuotaExceededException(total_message) + total_message = ( + "Salesforce has reported {}/{} ({:3.2f}%) total REST quota " + + "used across all Salesforce Applications. Terminating " + + "replication to not continue past configured percentage " + + "of {}% total quota." + ).format(remaining, allotted, percent_used_from_total, self.quota_percent_total) + raise TapSalesforceQuotaExceededError(total_message) elif self.rest_requests_attempted > max_requests_for_run: - partial_message = ("This replication job has made {} REST requests ({:3.2f}% of " + - "total quota). Terminating replication due to allotted " + - "quota of {}% per replication.").format(self.rest_requests_attempted, - (self.rest_requests_attempted / allotted) * 100, - self.quota_percent_per_run) - raise TapSalesforceQuotaExceededException(partial_message) + partial_message = ( + "This replication job has made {} REST requests ({:3.2f}% of " + + "total quota). Terminating replication due to allotted " + + "quota of {}% per replication." 
+ ).format( + self.rest_requests_attempted, + (self.rest_requests_attempted / allotted) * 100, + self.quota_percent_per_run, + ) + raise TapSalesforceQuotaExceededError(partial_message) def login(self): self.auth.login() @@ -300,11 +304,13 @@ def instance_url(self): return self.auth.instance_url # pylint: disable=too-many-arguments - @backoff.on_exception(backoff.expo, - (requests.exceptions.ConnectionError, SFDCCustomNotAcceptableError), - max_tries=10, - factor=2, - on_backoff=log_backoff_attempt) + @backoff.on_exception( + backoff.expo, + (requests.exceptions.ConnectionError, SFDCCustomNotAcceptableError), + max_tries=10, + factor=2, + on_backoff=log_backoff_attempt, + ) def _make_request(self, http_method, url, headers=None, body=None, stream=False, params=None): if http_method == "GET": LOGGER.info("Making %s request to %s with params: %s", http_method, url, params) @@ -313,11 +319,11 @@ def _make_request(self, http_method, url, headers=None, body=None, stream=False, LOGGER.info("Making %s request to %s with body %s", http_method, url, body) resp = self.session.post(url, headers=headers, data=body) else: - raise TapSalesforceException("Unsupported HTTP method") + raise TapSalesforceExceptionError("Unsupported HTTP method") raise_for_status(resp) - if resp.headers.get('Sforce-Limit-Info') is not None: + if resp.headers.get("Sforce-Limit-Info") is not None: self.rest_requests_attempted += 1 self.check_rest_quota_usage(resp.headers) @@ -332,53 +338,50 @@ def describe(self, sobject=None): endpoint_tag = "sobjects" url = self.data_url.format(instance_url, endpoint) else: - endpoint = "sobjects/{}/describe".format(sobject) + endpoint = f"sobjects/{sobject}/describe" endpoint_tag = sobject url = self.data_url.format(instance_url, endpoint) with metrics.http_request_timer("describe") as timer: - timer.tags['endpoint'] = endpoint_tag - resp = self._make_request('GET', url, headers=headers) + timer.tags["endpoint"] = endpoint_tag + resp = self._make_request("GET", url, headers=headers) return resp.json() # pylint: disable=no-self-use def _get_selected_properties(self, catalog_entry): - mdata = metadata.to_map(catalog_entry['metadata']) - properties = catalog_entry['schema'].get('properties', {}) - - return [k for k in properties.keys() - if singer.should_sync_field(metadata.get(mdata, ('properties', k), 'inclusion'), - metadata.get(mdata, ('properties', k), 'selected'), - self.select_fields_by_default)] - + mdata = metadata.to_map(catalog_entry["metadata"]) + properties = catalog_entry["schema"].get("properties", {}) + + return [ + k + for k in properties + if singer.should_sync_field( + metadata.get(mdata, ("properties", k), "inclusion"), + metadata.get(mdata, ("properties", k), "selected"), + self.select_fields_by_default, + ) + ] def get_start_date(self, state, catalog_entry): - catalog_metadata = metadata.to_map(catalog_entry['metadata']) - replication_key = catalog_metadata.get((), {}).get('replication-key') + catalog_metadata = metadata.to_map(catalog_entry["metadata"]) + replication_key = catalog_metadata.get((), {}).get("replication-key") - return (singer.get_bookmark(state, - catalog_entry['tap_stream_id'], - replication_key) or self.default_start_date) + return singer.get_bookmark(state, catalog_entry["tap_stream_id"], replication_key) or self.default_start_date def _build_query_string(self, catalog_entry, start_date, end_date=None, order_by_clause=True): selected_properties = self._get_selected_properties(catalog_entry) - query = "SELECT {} FROM 
{}".format(",".join(selected_properties), catalog_entry['stream']) + query = "SELECT {} FROM {}".format(",".join(selected_properties), catalog_entry["stream"]) - catalog_metadata = metadata.to_map(catalog_entry['metadata']) - replication_key = catalog_metadata.get((), {}).get('replication-key') + catalog_metadata = metadata.to_map(catalog_entry["metadata"]) + replication_key = catalog_metadata.get((), {}).get("replication-key") if replication_key: - where_clause = " WHERE {} >= {} ".format( - replication_key, - start_date) - if end_date: - end_date_clause = " AND {} < {}".format(replication_key, end_date) - else: - end_date_clause = "" - - order_by = " ORDER BY {} ASC".format(replication_key) + where_clause = f" WHERE {replication_key} >= {start_date} " + end_date_clause = f" AND {replication_key} < {end_date}" if end_date else "" + + order_by = f" ORDER BY {replication_key} ASC" if order_by_clause: return query + where_clause + end_date_clause + order_by @@ -397,28 +400,28 @@ def query(self, catalog_entry, state): rest = Rest(self) return rest.query(catalog_entry, state) else: - raise TapSalesforceException( - "api_type should be REST or BULK was: {}".format( - self.api_type)) + raise TapSalesforceExceptionError(f"api_type should be REST or BULK was: {self.api_type}") def get_blacklisted_objects(self): if self.api_type in [BULK_API_TYPE, BULK2_API_TYPE]: - return UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS.union( - QUERY_RESTRICTED_SALESFORCE_OBJECTS).union(QUERY_INCOMPATIBLE_SALESFORCE_OBJECTS) + return UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS.union(QUERY_RESTRICTED_SALESFORCE_OBJECTS).union( + QUERY_INCOMPATIBLE_SALESFORCE_OBJECTS + ) elif self.api_type == REST_API_TYPE: return QUERY_RESTRICTED_SALESFORCE_OBJECTS.union(QUERY_INCOMPATIBLE_SALESFORCE_OBJECTS) else: - raise TapSalesforceException( - "api_type should be REST or BULK was: {}".format( - self.api_type)) + raise TapSalesforceExceptionError(f"api_type should be REST or BULK was: {self.api_type}") # pylint: disable=line-too-long def get_blacklisted_fields(self): if self.api_type == BULK_API_TYPE or self.api_type == BULK2_API_TYPE: - return {('EntityDefinition', 'RecordTypesSupported'): "this field is unsupported by the Bulk API."} + return { + ( + "EntityDefinition", + "RecordTypesSupported", + ): "this field is unsupported by the Bulk API." 
+ } elif self.api_type == REST_API_TYPE: return {} else: - raise TapSalesforceException( - "api_type should be REST or BULK was: {}".format( - self.api_type)) + raise TapSalesforceExceptionError(f"api_type should be REST or BULK was: {self.api_type}") diff --git a/tap_salesforce/salesforce/bulk.py b/tap_salesforce/salesforce/bulk.py index 15414c7..24b7429 100644 --- a/tap_salesforce/salesforce/bulk.py +++ b/tap_salesforce/salesforce/bulk.py @@ -2,16 +2,18 @@ import csv import json import sys -import time import tempfile -import singer -from singer import metrics -from requests.exceptions import RequestException +import time +import singer import xmltodict +from requests.exceptions import RequestException +from singer import metrics from tap_salesforce.salesforce.exceptions import ( - TapSalesforceException, TapSalesforceQuotaExceededException) + TapSalesforceExceptionError, + TapSalesforceQuotaExceededError, +) BATCH_STATUS_POLLING_SLEEP = 20 PK_CHUNKED_BATCH_STATUS_POLLING_SLEEP = 60 @@ -20,25 +22,25 @@ LOGGER = singer.get_logger() + # pylint: disable=inconsistent-return-statements def find_parent(stream): parent_stream = stream if stream.endswith("CleanInfo"): - parent_stream = stream[:stream.find("CleanInfo")] + parent_stream = stream[: stream.find("CleanInfo")] elif stream.endswith("FieldHistory"): - parent_stream = stream[:stream.find("FieldHistory")] + parent_stream = stream[: stream.find("FieldHistory")] elif stream.endswith("History"): - parent_stream = stream[:stream.find("History")] + parent_stream = stream[: stream.find("History")] # If the stripped stream ends with "__" we can assume the parent is a custom table if parent_stream.endswith("__"): - parent_stream += 'c' + parent_stream += "c" return parent_stream -class Bulk(): - +class Bulk: bulk_url = "{}/services/async/60.0/{}" def __init__(self, sf): @@ -49,8 +51,7 @@ def __init__(self, sf): def query(self, catalog_entry, state): self.check_bulk_quota_usage() - for record in self._bulk_query(catalog_entry, state): - yield record + yield from self._bulk_query(catalog_entry, state) self.sf.jobs_completed += 1 @@ -60,30 +61,38 @@ def check_bulk_quota_usage(self): url = self.sf.data_url.format(self.sf.instance_url, endpoint) with metrics.http_request_timer(endpoint): - resp = self.sf._make_request('GET', url, headers=self.sf.auth.rest_headers).json() + resp = self.sf._make_request("GET", url, headers=self.sf.auth.rest_headers).json() - quota_max = resp['DailyBulkApiBatches']['Max'] + quota_max = resp["DailyBulkApiBatches"]["Max"] max_requests_for_run = int((self.sf.quota_percent_per_run * quota_max) / 100) - quota_remaining = resp['DailyBulkApiBatches']['Remaining'] + quota_remaining = resp["DailyBulkApiBatches"]["Remaining"] percent_used = (1 - (quota_remaining / quota_max)) * 100 if percent_used > self.sf.quota_percent_total: - total_message = ("Salesforce has reported {}/{} ({:3.2f}%) total Bulk API quota " + - "used across all Salesforce Applications. Terminating " + - "replication to not continue past configured percentage " + - "of {}% total quota.").format(quota_max - quota_remaining, - quota_max, - percent_used, - self.sf.quota_percent_total) - raise TapSalesforceQuotaExceededException(total_message) + total_message = ( + "Salesforce has reported {}/{} ({:3.2f}%) total Bulk API quota " + + "used across all Salesforce Applications. Terminating " + + "replication to not continue past configured percentage " + + "of {}% total quota." 
+ ).format( + quota_max - quota_remaining, + quota_max, + percent_used, + self.sf.quota_percent_total, + ) + raise TapSalesforceQuotaExceededError(total_message) elif self.sf.jobs_completed > max_requests_for_run: - partial_message = ("This replication job has completed {} Bulk API jobs ({:3.2f}% of " + - "total quota). Terminating replication due to allotted " + - "quota of {}% per replication.").format(self.sf.jobs_completed, - (self.sf.jobs_completed / quota_max) * 100, - self.sf.quota_percent_per_run) - raise TapSalesforceQuotaExceededException(partial_message) + partial_message = ( + "This replication job has completed {} Bulk API jobs ({:3.2f}% of " + + "total quota). Terminating replication due to allotted " + + "quota of {}% per replication." + ).format( + self.sf.jobs_completed, + (self.sf.jobs_completed / quota_max) * 100, + self.sf.quota_percent_per_run, + ) + raise TapSalesforceQuotaExceededError(partial_message) def _get_bulk_headers(self): return {**self.sf.auth.bulk_headers, "Content-Type": "application/json"} @@ -98,29 +107,35 @@ def _bulk_query(self, catalog_entry, state): batch_status = self._poll_on_batch_status(job_id, batch_id) - if batch_status['state'] == 'Failed': - if "QUERY_TIMEOUT" in batch_status['stateMessage']: + if batch_status["state"] == "Failed": + if "QUERY_TIMEOUT" in batch_status["stateMessage"]: batch_status = self._bulk_query_with_pk_chunking(catalog_entry, start_date) - job_id = batch_status['job_id'] + job_id = batch_status["job_id"] # Set pk_chunking to True to indicate that we should write a bookmark differently self.sf.pk_chunking = True # Add the bulk Job ID and its batches to the state so it can be resumed if necessary - tap_stream_id = catalog_entry['tap_stream_id'] - state = singer.write_bookmark(state, tap_stream_id, 'JobID', job_id) - state = singer.write_bookmark(state, tap_stream_id, 'BatchIDs', batch_status['completed'][:]) + tap_stream_id = catalog_entry["tap_stream_id"] + state = singer.write_bookmark(state, tap_stream_id, "JobID", job_id) + state = singer.write_bookmark(state, tap_stream_id, "BatchIDs", batch_status["completed"][:]) - for completed_batch_id in batch_status['completed']: + for completed_batch_id in batch_status["completed"]: for result in self.get_batch_results(job_id, completed_batch_id, catalog_entry): yield result # Remove the completed batch ID and write state - state['bookmarks'][catalog_entry['tap_stream_id']]["BatchIDs"].remove(completed_batch_id) - LOGGER.info("Finished syncing batch %s. Removing batch from state.", completed_batch_id) - LOGGER.info("Batches to go: %d", len(state['bookmarks'][catalog_entry['tap_stream_id']]["BatchIDs"])) + state["bookmarks"][catalog_entry["tap_stream_id"]]["BatchIDs"].remove(completed_batch_id) + LOGGER.info( + "Finished syncing batch %s. 
Removing batch from state.", + completed_batch_id, + ) + LOGGER.info( + "Batches to go: %d", + len(state["bookmarks"][catalog_entry["tap_stream_id"]]["BatchIDs"]), + ) singer.write_state(state) else: - raise TapSalesforceException(batch_status['stateMessage']) + raise TapSalesforceExceptionError(batch_status["stateMessage"]) else: for result in self.get_batch_results(job_id, batch_id, catalog_entry): yield result @@ -134,10 +149,10 @@ def _bulk_query_with_pk_chunking(self, catalog_entry, start_date): self._add_batch(catalog_entry, job_id, start_date, False) batch_status = self._poll_on_pk_chunked_batch_status(job_id) - batch_status['job_id'] = job_id + batch_status["job_id"] = job_id - if batch_status['failed']: - raise TapSalesforceException("One or more batches failed during PK chunked job") + if batch_status["failed"]: + raise TapSalesforceExceptionError("One or more batches failed during PK chunked job") # Close the job after all the batches are complete self._close_job(job_id) @@ -146,154 +161,148 @@ def _bulk_query_with_pk_chunking(self, catalog_entry, start_date): def _create_job(self, catalog_entry, pk_chunking=False): url = self.bulk_url.format(self.sf.instance_url, "job") - body = {"operation": "queryAll", "object": catalog_entry['stream'], "contentType": "CSV"} + body = { + "operation": "queryAll", + "object": catalog_entry["stream"], + "contentType": "CSV", + } headers = self._get_bulk_headers() - headers['Sforce-Disable-Batch-Retry'] = "true" + headers["Sforce-Disable-Batch-Retry"] = "true" if pk_chunking: LOGGER.info("ADDING PK CHUNKING HEADER") - headers['Sforce-Enable-PKChunking'] = "true; chunkSize={}".format(DEFAULT_CHUNK_SIZE) + headers["Sforce-Enable-PKChunking"] = f"true; chunkSize={DEFAULT_CHUNK_SIZE}" # If the stream ends with 'CleanInfo' or 'History', we can PK Chunk on the object's parent - if any(catalog_entry['stream'].endswith(suffix) for suffix in ["CleanInfo", "History"]): - parent = find_parent(catalog_entry['stream']) - headers['Sforce-Enable-PKChunking'] = headers['Sforce-Enable-PKChunking'] + "; parent={}".format(parent) + if any(catalog_entry["stream"].endswith(suffix) for suffix in ["CleanInfo", "History"]): + parent = find_parent(catalog_entry["stream"]) + headers["Sforce-Enable-PKChunking"] = headers["Sforce-Enable-PKChunking"] + f"; parent={parent}" with metrics.http_request_timer("create_job") as timer: - timer.tags['sobject'] = catalog_entry['stream'] - resp = self.sf._make_request( - 'POST', - url, - headers=headers, - body=json.dumps(body)) + timer.tags["sobject"] = catalog_entry["stream"] + resp = self.sf._make_request("POST", url, headers=headers, body=json.dumps(body)) job = resp.json() - return job['id'] + return job["id"] def _add_batch(self, catalog_entry, job_id, start_date, order_by_clause=True): - endpoint = "job/{}/batch".format(job_id) + endpoint = f"job/{job_id}/batch" url = self.bulk_url.format(self.sf.instance_url, endpoint) body = self.sf._build_query_string(catalog_entry, start_date, order_by_clause=order_by_clause) headers = self._get_bulk_headers() - headers['Content-Type'] = 'text/csv' + headers["Content-Type"] = "text/csv" with metrics.http_request_timer("add_batch") as timer: - timer.tags['sobject'] = catalog_entry['stream'] - resp = self.sf._make_request('POST', url, headers=headers, body=body) + timer.tags["sobject"] = catalog_entry["stream"] + resp = self.sf._make_request("POST", url, headers=headers, body=body) batch = xmltodict.parse(resp.text) - return batch['batchInfo']['id'] + return batch["batchInfo"]["id"] def 
_poll_on_pk_chunked_batch_status(self, job_id): batches = self._get_batches(job_id) while True: - queued_batches = [b['id'] for b in batches if b['state'] == "Queued"] - in_progress_batches = [b['id'] for b in batches if b['state'] == "InProgress"] + queued_batches = [b["id"] for b in batches if b["state"] == "Queued"] + in_progress_batches = [b["id"] for b in batches if b["state"] == "InProgress"] if not queued_batches and not in_progress_batches: - completed_batches = [b['id'] for b in batches if b['state'] == "Completed"] - failed_batches = [b['id'] for b in batches if b['state'] == "Failed"] - return {'completed': completed_batches, 'failed': failed_batches} + completed_batches = [b["id"] for b in batches if b["state"] == "Completed"] + failed_batches = [b["id"] for b in batches if b["state"] == "Failed"] + return {"completed": completed_batches, "failed": failed_batches} else: time.sleep(PK_CHUNKED_BATCH_STATUS_POLLING_SLEEP) batches = self._get_batches(job_id) def _poll_on_batch_status(self, job_id, batch_id): - batch_status = self._get_batch(job_id=job_id, - batch_id=batch_id) + batch_status = self._get_batch(job_id=job_id, batch_id=batch_id) - while batch_status['state'] not in ['Completed', 'Failed', 'Not Processed']: + while batch_status["state"] not in ["Completed", "Failed", "Not Processed"]: time.sleep(BATCH_STATUS_POLLING_SLEEP) - batch_status = self._get_batch(job_id=job_id, - batch_id=batch_id) + batch_status = self._get_batch(job_id=job_id, batch_id=batch_id) return batch_status def job_exists(self, job_id): try: - endpoint = "job/{}".format(job_id) + endpoint = f"job/{job_id}" url = self.bulk_url.format(self.sf.instance_url, endpoint) headers = self._get_bulk_headers() with metrics.http_request_timer("get_job"): - self.sf._make_request('GET', url, headers=headers) + self.sf._make_request("GET", url, headers=headers) - return True # requests will raise for a 400 InvalidJob + return True # requests will raise for a 400 InvalidJob except RequestException as ex: - if ex.response.headers["Content-Type"] == 'application/json': - exception_code = ex.response.json()['exceptionCode'] - if exception_code == 'InvalidJob': + if ex.response.headers["Content-Type"] == "application/json": + exception_code = ex.response.json()["exceptionCode"] + if exception_code == "InvalidJob": return False raise def _get_batches(self, job_id): - endpoint = "job/{}/batch".format(job_id) + endpoint = f"job/{job_id}/batch" url = self.bulk_url.format(self.sf.instance_url, endpoint) headers = self._get_bulk_headers() with metrics.http_request_timer("get_batches"): - resp = self.sf._make_request('GET', url, headers=headers) + resp = self.sf._make_request("GET", url, headers=headers) - batches = xmltodict.parse(resp.text, - xml_attribs=False, - force_list=('batchInfo',))['batchInfoList']['batchInfo'] + batches = xmltodict.parse(resp.text, xml_attribs=False, force_list=("batchInfo",))["batchInfoList"]["batchInfo"] return batches def _get_batch(self, job_id, batch_id): - endpoint = "job/{}/batch/{}".format(job_id, batch_id) + endpoint = f"job/{job_id}/batch/{batch_id}" url = self.bulk_url.format(self.sf.instance_url, endpoint) headers = self._get_bulk_headers() with metrics.http_request_timer("get_batch"): - resp = self.sf._make_request('GET', url, headers=headers) + resp = self.sf._make_request("GET", url, headers=headers) batch = xmltodict.parse(resp.text) - return batch['batchInfo'] + return batch["batchInfo"] def get_batch_results(self, job_id, batch_id, catalog_entry): """Given a job_id and batch_id, 
queries the batches results and reads CSV lines yielding each line as a record.""" headers = self._get_bulk_headers() - endpoint = "job/{}/batch/{}/result".format(job_id, batch_id) + endpoint = f"job/{job_id}/batch/{batch_id}/result" url = self.bulk_url.format(self.sf.instance_url, endpoint) with metrics.http_request_timer("batch_result_list") as timer: - timer.tags['sobject'] = catalog_entry['stream'] - batch_result_resp = self.sf._make_request('GET', url, headers=headers) + timer.tags["sobject"] = catalog_entry["stream"] + batch_result_resp = self.sf._make_request("GET", url, headers=headers) # Returns a Dict where input: # <result-list><result>1</result><result>2</result></result-list> # will return: {'result', ['1', '2']} - batch_result_list = xmltodict.parse(batch_result_resp.text, - xml_attribs=False, - force_list={'result'})['result-list'] + batch_result_list = xmltodict.parse(batch_result_resp.text, xml_attribs=False, force_list={"result"})[ + "result-list" + ] - for result in batch_result_list['result']: - endpoint = "job/{}/batch/{}/result/{}".format(job_id, batch_id, result) + for result in batch_result_list["result"]: + endpoint = f"job/{job_id}/batch/{batch_id}/result/{result}" url = self.bulk_url.format(self.sf.instance_url, endpoint) - headers['Content-Type'] = 'text/csv' + headers["Content-Type"] = "text/csv" with tempfile.NamedTemporaryFile(mode="w+", encoding="utf8") as csv_file: - resp = self.sf._make_request('GET', url, headers=headers, stream=True) + resp = self.sf._make_request("GET", url, headers=headers, stream=True) for chunk in resp.iter_content(chunk_size=ITER_CHUNK_SIZE, decode_unicode=True): if chunk: # Replace any NULL bytes in the chunk so it can be safely given to the CSV reader - csv_file.write(chunk.replace('\0', '')) + csv_file.write(chunk.replace("\0", "")) csv_file.seek(0) - csv_reader = csv.reader(csv_file, - delimiter=',', - quotechar='"') + csv_reader = csv.reader(csv_file, delimiter=",", quotechar='"') column_name_list = next(csv_reader) @@ -302,16 +311,12 @@ def get_batch_results(self, job_id, batch_id, catalog_entry): yield rec def _close_job(self, job_id): - endpoint = "job/{}".format(job_id) + endpoint = f"job/{job_id}" url = self.bulk_url.format(self.sf.instance_url, endpoint) body = {"state": "Closed"} with metrics.http_request_timer("close_job"): - self.sf._make_request( - 'POST', - url, - headers=self._get_bulk_headers(), - body=json.dumps(body)) + self.sf._make_request("POST", url, headers=self._get_bulk_headers(), body=json.dumps(body)) # pylint: disable=no-self-use def _iter_lines(self, response): @@ -326,13 +331,9 @@ def _iter_lines(self, response): lines = chunk.splitlines(keepends=True) - if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]: - pending = lines.pop() - else: - pending = None + pending = lines.pop() if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1] else None - for line in lines: - yield line + yield from lines if pending is not None: yield pending diff --git a/tap_salesforce/salesforce/bulk2.py b/tap_salesforce/salesforce/bulk2.py index 6f1ba22..330ec8e 100644 --- a/tap_salesforce/salesforce/bulk2.py +++ b/tap_salesforce/salesforce/bulk2.py @@ -1,34 +1,30 @@ -import time import csv -import sys import json +import sys +import time + import singer from singer import metrics - BATCH_STATUS_POLLING_SLEEP = 20 DEFAULT_CHUNK_SIZE = 50000 LOGGER = singer.get_logger() -class Bulk2(): - bulk_url = '{}/services/data/v60.0/jobs/query' + +class Bulk2: + bulk_url = "{}/services/data/v60.0/jobs/query" def __init__(self, sf): csv.field_size_limit(sys.maxsize) self.sf = sf 
- def query(self, catalog_entry, state): job_id = self._create_job(catalog_entry, state) self._wait_for_job(job_id) for batch in self._get_next_batch(job_id): - reader = csv.DictReader(batch.decode('utf-8').splitlines()) - - for row in reader: - yield row - + yield from csv.DictReader(batch.decode("utf-8").splitlines()) def _get_bulk_headers(self): return {**self.sf.auth.rest_headers, "Content-Type": "application/json"} @@ -45,46 +41,42 @@ def _create_job(self, catalog_entry, state): } with metrics.http_request_timer("create_job") as timer: - timer.tags['sobject'] = catalog_entry['stream'] - resp = self.sf._make_request( - 'POST', - url, - headers=self._get_bulk_headers(), - body=json.dumps(body)) + timer.tags["sobject"] = catalog_entry["stream"] + resp = self.sf._make_request("POST", url, headers=self._get_bulk_headers(), body=json.dumps(body)) job = resp.json() - return job['id'] + return job["id"] def _wait_for_job(self, job_id): - status_url = self.bulk_url + '/{}' + status_url = self.bulk_url + "/{}" url = status_url.format(self.sf.instance_url, job_id) status = None - while status not in ('JobComplete', 'Failed'): - resp = self.sf._make_request('GET', url, headers=self._get_bulk_headers()).json() - status = resp['state'] + while status not in ("JobComplete", "Failed"): + resp = self.sf._make_request("GET", url, headers=self._get_bulk_headers()).json() + status = resp["state"] - if status == 'JobComplete': + if status == "JobComplete": break - if status == 'Failed': - raise Exception("Job failed: {}".format(resp.json())) + if status == "Failed": + raise Exception(f"Job failed: {resp.json()}") time.sleep(BATCH_STATUS_POLLING_SLEEP) def _get_next_batch(self, job_id): - url = self.bulk_url + '/{}/results' + url = self.bulk_url + "/{}/results" url = url.format(self.sf.instance_url, job_id) - locator = '' + locator = "" - while locator != 'null': + while locator != "null": params = {"maxRecords": DEFAULT_CHUNK_SIZE} - if locator != '': - params['locator'] = locator + if locator != "": + params["locator"] = locator - resp = self.sf._make_request('GET', url, headers=self._get_bulk_headers(), params=params) - locator = resp.headers.get('Sforce-Locator') + resp = self.sf._make_request("GET", url, headers=self._get_bulk_headers(), params=params) + locator = resp.headers.get("Sforce-Locator") yield resp.content diff --git a/tap_salesforce/salesforce/credentials.py b/tap_salesforce/salesforce/credentials.py index dd80344..9c9add5 100644 --- a/tap_salesforce/salesforce/credentials.py +++ b/tap_salesforce/salesforce/credentials.py @@ -1,24 +1,16 @@ -import threading import logging -import requests +import threading from collections import namedtuple -from simple_salesforce import SalesforceLogin +import requests +from simple_salesforce import SalesforceLogin LOGGER = logging.getLogger(__name__) -OAuthCredentials = namedtuple('OAuthCredentials', ( - "client_id", - "client_secret", - "refresh_token" -)) +OAuthCredentials = namedtuple("OAuthCredentials", ("client_id", "client_secret", "refresh_token")) -PasswordCredentials = namedtuple('PasswordCredentials', ( - "username", - "password", - "security_token" -)) +PasswordCredentials = namedtuple("PasswordCredentials", ("username", "password", "security_token")) def parse_credentials(config): @@ -30,7 +22,7 @@ def parse_credentials(config): raise Exception("Cannot create credentials from config.") -class SalesforceAuth(): +class SalesforceAuth: def __init__(self, credentials, is_sandbox=False): self.is_sandbox = is_sandbox self._credentials = credentials 
@@ -41,16 +33,17 @@ def __init__(self, credentials, is_sandbox=False): def login(self): """Attempt to login and set the `instance_url` and `access_token` on success.""" - pass @property def rest_headers(self): - return {"Authorization": "Bearer {}".format(self._access_token)} + return {"Authorization": f"Bearer {self._access_token}"} @property def bulk_headers(self): - return {"X-SFDC-Session": self._access_token, - "Content-Type": "application/json"} + return { + "X-SFDC-Session": self._access_token, + "Content-Type": "application/json", + } @property def instance_url(self): @@ -73,14 +66,14 @@ class SalesforceAuthOAuth(SalesforceAuth): @property def _login_body(self): - return {'grant_type': 'refresh_token', **self._credentials._asdict()} + return {"grant_type": "refresh_token", **self._credentials._asdict()} @property def _login_url(self): - login_url = 'https://login.salesforce.com/services/oauth2/token' + login_url = "https://login.salesforce.com/services/oauth2/token" if self.is_sandbox: - login_url = 'https://test.salesforce.com/services/oauth2/token' + login_url = "https://test.salesforce.com/services/oauth2/token" return login_url @@ -88,20 +81,22 @@ def login(self): try: LOGGER.info("Attempting login via OAuth2") - resp = requests.post(self._login_url, - data=self._login_body, - headers={"Content-Type": "application/x-www-form-urlencoded"}) + resp = requests.post( + self._login_url, + data=self._login_body, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) resp.raise_for_status() auth = resp.json() LOGGER.info("OAuth2 login successful") - self._access_token = auth['access_token'] - self._instance_url = auth['instance_url'] + self._access_token = auth["access_token"] + self._instance_url = auth["instance_url"] except Exception as e: error_message = str(e) if resp: - error_message = error_message + ", Response from Salesforce: {}".format(resp.text) + error_message = error_message + f", Response from Salesforce: {resp.text}" raise Exception(error_message) from e finally: LOGGER.info("Starting new login timer") @@ -111,10 +106,7 @@ def login(self): class SalesforceAuthPassword(SalesforceAuth): def login(self): - login = SalesforceLogin( - sandbox=self.is_sandbox, - **self._credentials._asdict() - ) + login = SalesforceLogin(sandbox=self.is_sandbox, **self._credentials._asdict()) self._access_token, host = login self._instance_url = "https://" + host diff --git a/tap_salesforce/salesforce/exceptions.py b/tap_salesforce/salesforce/exceptions.py index bd333ff..f0f90dd 100644 --- a/tap_salesforce/salesforce/exceptions.py +++ b/tap_salesforce/salesforce/exceptions.py @@ -1,10 +1,11 @@ # pylint: disable=super-init-not-called -class TapSalesforceException(Exception): + +class TapSalesforceExceptionError(Exception): pass -class TapSalesforceQuotaExceededException(TapSalesforceException): +class TapSalesforceQuotaExceededError(TapSalesforceExceptionError): pass @@ -17,6 +18,6 @@ class SFDCCustomNotAcceptableError(Exception): on any salesforce documentation page or forum. 
Example Error Message: ``` - requests.exceptions.HTTPError: 406 Client Error: CustomNotAcceptable for + requests.exceptions.HTTPError: 406 Client Error: CustomNotAcceptable for url: https://XXX.salesforce.com/services/data/v53.0/sobjects/XXX/describe """ diff --git a/tap_salesforce/salesforce/rest.py b/tap_salesforce/salesforce/rest.py index 297869d..6ab83f0 100644 --- a/tap_salesforce/salesforce/rest.py +++ b/tap_salesforce/salesforce/rest.py @@ -2,14 +2,15 @@ import singer import singer.utils as singer_utils from requests.exceptions import HTTPError -from tap_salesforce.salesforce.exceptions import TapSalesforceException + +from tap_salesforce.salesforce.exceptions import TapSalesforceExceptionError LOGGER = singer.get_logger() MAX_RETRIES = 4 -class Rest(): +class Rest: def __init__(self, sf): self.sf = sf @@ -20,15 +21,9 @@ def query(self, catalog_entry, state): return self._query_recur(query, catalog_entry, start_date) # pylint: disable=too-many-arguments - def _query_recur( - self, - query, - catalog_entry, - start_date_str, - end_date=None, - retries=MAX_RETRIES): + def _query_recur(self, query, catalog_entry, start_date_str, end_date=None, retries=MAX_RETRIES): params = {"q": query} - url = "{}/services/data/v60.0/queryAll".format(self.sf.instance_url) + url = f"{self.sf.instance_url}/services/data/v60.0/queryAll" headers = self.sf.auth.rest_headers sync_start = singer_utils.now() @@ -36,25 +31,20 @@ def _query_recur( end_date = sync_start if retries == 0: - raise TapSalesforceException( - "Ran out of retries attempting to query Salesforce Object {}".format( - catalog_entry['stream'])) + raise TapSalesforceExceptionError( + "Ran out of retries attempting to query Salesforce Object {}".format(catalog_entry["stream"]) + ) retryable = False try: - for rec in self._sync_records(url, headers, params): - yield rec + yield from self._sync_records(url, headers, params) # If the date range was chunked (an end_date was passed), sync # from the end_date -> now if end_date < sync_start: next_start_date_str = singer_utils.strftime(end_date) query = self.sf._build_query_string(catalog_entry, next_start_date_str) - for record in self._query_recur( - query, - catalog_entry, - next_start_date_str, - retries=retries): + for record in self._query_recur(query, catalog_entry, next_start_date_str, retries=retries): yield record except HTTPError as ex: @@ -65,7 +55,8 @@ def _query_recur( LOGGER.info( "Salesforce returned QUERY_TIMEOUT querying %d days of %s", day_range, - catalog_entry['stream']) + catalog_entry["stream"], + ) retryable = True else: raise ex @@ -76,30 +67,28 @@ def _query_recur( end_date = end_date - half_day_range if half_day_range.days == 0: - raise TapSalesforceException( - "Attempting to query by 0 day range, this would cause infinite looping.") - - query = self.sf._build_query_string(catalog_entry, singer_utils.strftime(start_date), - singer_utils.strftime(end_date)) - for record in self._query_recur( - query, - catalog_entry, - start_date_str, - end_date, - retries - 1): + raise TapSalesforceExceptionError( + "Attempting to query by 0 day range, this would cause infinite looping." 
+ ) + + query = self.sf._build_query_string( + catalog_entry, + singer_utils.strftime(start_date), + singer_utils.strftime(end_date), + ) + for record in self._query_recur(query, catalog_entry, start_date_str, end_date, retries - 1): yield record def _sync_records(self, url, headers, params): while True: - resp = self.sf._make_request('GET', url, headers=headers, params=params) + resp = self.sf._make_request("GET", url, headers=headers, params=params) resp_json = resp.json() - for rec in resp_json.get('records'): - yield rec + yield from resp_json.get("records") - next_records_url = resp_json.get('nextRecordsUrl') + next_records_url = resp_json.get("nextRecordsUrl") if next_records_url is None: break else: - url = "{}{}".format(self.sf.instance_url, next_records_url) + url = f"{self.sf.instance_url}{next_records_url}" diff --git a/tap_salesforce/sync.py b/tap_salesforce/sync.py index dc1157e..f900e75 100644 --- a/tap_salesforce/sync.py +++ b/tap_salesforce/sync.py @@ -1,17 +1,21 @@ import time + import singer import singer.utils as singer_utils -from singer import Transformer, metadata, metrics from requests.exceptions import RequestException +from singer import Transformer, metadata, metrics + from tap_salesforce.salesforce.bulk import Bulk LOGGER = singer.get_logger() -BLACKLISTED_FIELDS = set(['attributes']) +BLACKLISTED_FIELDS = {"attributes"} + def remove_blacklisted_fields(data): return {k: v for k, v in data.items() if k not in BLACKLISTED_FIELDS} + # pylint: disable=unused-argument def transform_bulk_data_hook(data, typ, schema): result = data @@ -21,38 +25,42 @@ def transform_bulk_data_hook(data, typ, schema): # Salesforce Bulk API returns CSV's with empty strings for text fields. # When the text field is nillable and the data value is an empty string, # change the data so that it is None. 
- if data == "" and "null" in schema['type']: + if data == "" and "null" in schema["type"]: result = None return result + def get_stream_version(catalog_entry, state): - tap_stream_id = catalog_entry['tap_stream_id'] - catalog_metadata = metadata.to_map(catalog_entry['metadata']) - replication_key = catalog_metadata.get((), {}).get('replication-key') + tap_stream_id = catalog_entry["tap_stream_id"] + catalog_metadata = metadata.to_map(catalog_entry["metadata"]) + replication_key = catalog_metadata.get((), {}).get("replication-key") - if singer.get_bookmark(state, tap_stream_id, 'version') is None: + if singer.get_bookmark(state, tap_stream_id, "version") is None: stream_version = int(time.time() * 1000) else: - stream_version = singer.get_bookmark(state, tap_stream_id, 'version') + stream_version = singer.get_bookmark(state, tap_stream_id, "version") if replication_key: return stream_version return int(time.time() * 1000) + def resume_syncing_bulk_query(sf, catalog_entry, job_id, state, counter): bulk = Bulk(sf) - current_bookmark = singer.get_bookmark(state, catalog_entry['tap_stream_id'], 'JobHighestBookmarkSeen') or sf.get_start_date(state, catalog_entry) + current_bookmark = singer.get_bookmark( + state, catalog_entry["tap_stream_id"], "JobHighestBookmarkSeen" + ) or sf.get_start_date(state, catalog_entry) current_bookmark = singer_utils.strptime_with_tz(current_bookmark) - batch_ids = singer.get_bookmark(state, catalog_entry['tap_stream_id'], 'BatchIDs') + batch_ids = singer.get_bookmark(state, catalog_entry["tap_stream_id"], "BatchIDs") start_time = singer_utils.now() - stream = catalog_entry['stream'] - stream_alias = catalog_entry.get('stream_alias') - catalog_metadata = metadata.to_map(catalog_entry.get('metadata')) - replication_key = catalog_metadata.get((), {}).get('replication-key') + stream = catalog_entry["stream"] + stream_alias = catalog_entry.get("stream_alias") + catalog_metadata = metadata.to_map(catalog_entry.get("metadata")) + replication_key = catalog_metadata.get((), {}).get("replication-key") stream_version = get_stream_version(catalog_entry, state) - schema = catalog_entry['schema'] + schema = catalog_entry["schema"] if not bulk.job_exists(job_id): LOGGER.info("Found stored Job ID that no longer exists, resetting bookmark and removing JobID from state.") @@ -67,28 +75,36 @@ def resume_syncing_bulk_query(sf, catalog_entry, job_id, state, counter): rec = fix_record_anytype(rec, schema) singer.write_message( singer.RecordMessage( - stream=( - stream_alias or stream), + stream=(stream_alias or stream), record=rec, version=stream_version, - time_extracted=start_time)) + time_extracted=start_time, + ) + ) # Update bookmark if necessary replication_key_value = replication_key and singer_utils.strptime_with_tz(rec[replication_key]) - if replication_key_value and replication_key_value <= start_time and replication_key_value > current_bookmark: + if ( + replication_key_value + and replication_key_value <= start_time + and replication_key_value > current_bookmark + ): current_bookmark = singer_utils.strptime_with_tz(rec[replication_key]) - state = singer.write_bookmark(state, - catalog_entry['tap_stream_id'], - 'JobHighestBookmarkSeen', - singer_utils.strftime(current_bookmark)) + state = singer.write_bookmark( + state, + catalog_entry["tap_stream_id"], + "JobHighestBookmarkSeen", + singer_utils.strftime(current_bookmark), + ) batch_ids.remove(batch_id) LOGGER.info("Finished syncing batch %s. 
Removing batch from state.", batch_id) LOGGER.info("Batches to go: %d", len(batch_ids)) singer.write_state(state) + def sync_stream(sf, catalog_entry, state, state_msg_threshold): - stream = catalog_entry['stream'] + stream = catalog_entry["stream"] with metrics.record_counter(stream) as counter: try: @@ -96,26 +112,24 @@ def sync_stream(sf, catalog_entry, state, state_msg_threshold): # Write the state generated for the last record generated by sf.query singer.write_state(state) except RequestException as ex: - raise Exception("Error syncing {}: {} Response: {}".format( - stream, ex, ex.response.text)) + raise Exception(f"Error syncing {stream}: {ex} Response: {ex.response.text}") # noqa: B904 except Exception as ex: - raise Exception("Error syncing {}: {}".format( - stream, ex)) from ex + raise Exception(f"Error syncing {stream}: {ex}") from ex + def sync_records(sf, catalog_entry, state, counter, state_msg_threshold): chunked_bookmark = singer_utils.strptime_with_tz(sf.get_start_date(state, catalog_entry)) - stream = catalog_entry['stream'] - schema = catalog_entry['schema'] - stream_alias = catalog_entry.get('stream_alias') - catalog_metadata = metadata.to_map(catalog_entry['metadata']) - replication_key = catalog_metadata.get((), {}).get('replication-key') + stream = catalog_entry["stream"] + schema = catalog_entry["schema"] + stream_alias = catalog_entry.get("stream_alias") + catalog_metadata = metadata.to_map(catalog_entry["metadata"]) + replication_key = catalog_metadata.get((), {}).get("replication-key") stream_version = get_stream_version(catalog_entry, state) - activate_version_message = singer.ActivateVersionMessage(stream=(stream_alias or stream), - version=stream_version) + activate_version_message = singer.ActivateVersionMessage(stream=(stream_alias or stream), version=stream_version) start_time = singer_utils.now() - LOGGER.info('Syncing Salesforce data for stream %s', stream) + LOGGER.info("Syncing Salesforce data for stream %s", stream) for rec in sf.query(catalog_entry, state): counter.increment() @@ -124,23 +138,29 @@ def sync_records(sf, catalog_entry, state, counter, state_msg_threshold): rec = fix_record_anytype(rec, schema) singer.write_message( singer.RecordMessage( - stream=( - stream_alias or stream), + stream=(stream_alias or stream), record=rec, version=stream_version, - time_extracted=start_time)) + time_extracted=start_time, + ) + ) replication_key_value = replication_key and singer_utils.strptime_with_tz(rec[replication_key]) if sf.pk_chunking: - if replication_key_value and replication_key_value <= start_time and replication_key_value > chunked_bookmark: + if ( + replication_key_value + and replication_key_value <= start_time + and replication_key_value > chunked_bookmark + ): # Replace the highest seen bookmark and save the state in case we need to resume later chunked_bookmark = singer_utils.strptime_with_tz(rec[replication_key]) state = singer.write_bookmark( state, - catalog_entry['tap_stream_id'], - 'JobHighestBookmarkSeen', - singer_utils.strftime(chunked_bookmark)) + catalog_entry["tap_stream_id"], + "JobHighestBookmarkSeen", + singer_utils.strftime(chunked_bookmark), + ) if counter.value % state_msg_threshold == 0: singer.write_state(state) @@ -149,9 +169,10 @@ def sync_records(sf, catalog_entry, state, counter, state_msg_threshold): elif replication_key_value and replication_key_value <= start_time: state = singer.write_bookmark( state, - catalog_entry['tap_stream_id'], + catalog_entry["tap_stream_id"], replication_key, - rec[replication_key]) + 
rec[replication_key], + ) if counter.value % state_msg_threshold == 0: singer.write_state(state) @@ -160,21 +181,23 @@ def sync_records(sf, catalog_entry, state, counter, state_msg_threshold): # activate_version message for the next sync if not replication_key: singer.write_message(activate_version_message) - state = singer.write_bookmark( - state, catalog_entry['tap_stream_id'], 'version', None) + state = singer.write_bookmark(state, catalog_entry["tap_stream_id"], "version", None) # If pk_chunking is set, only write a bookmark at the end if sf.pk_chunking: # Write a bookmark with the highest value we've seen state = singer.write_bookmark( state, - catalog_entry['tap_stream_id'], + catalog_entry["tap_stream_id"], replication_key, - singer_utils.strftime(chunked_bookmark)) + singer_utils.strftime(chunked_bookmark), + ) + def fix_record_anytype(rec, schema): """Modifies a record when the schema has no 'type' element due to a SF type of 'anyType.' Attempts to set the record's value for that element to an int, float, or string.""" + def try_cast(val, coercion): try: return coercion(val) @@ -182,12 +205,12 @@ def try_cast(val, coercion): return val for k, v in rec.items(): - if schema['properties'][k].get("type") is None: + if schema["properties"][k].get("type") is None: val = v val = try_cast(v, int) val = try_cast(v, float) if v in ["true", "false"]: - val = (v == "true") + val = v == "true" if v == "": val = None