diff --git a/parsons/databases/redshift/redshift.py b/parsons/databases/redshift/redshift.py
index fff64c4dbd..64de0a13e3 100644
--- a/parsons/databases/redshift/redshift.py
+++ b/parsons/databases/redshift/redshift.py
@@ -711,6 +711,7 @@ def unload(
         max_file_size="6.2 GB",
         extension=None,
         aws_region=None,
+        format=None,
         aws_access_key_id=None,
         aws_secret_access_key=None,
     ):
@@ -761,6 +762,8 @@
                 The AWS Region where the target Amazon S3 bucket is located. REGION is required
                 for UNLOAD to an Amazon S3 bucket that is not in the same AWS Region as the
                 Amazon Redshift cluster.
+            format: str
+                The format of the unload file (CSV, PARQUET, JSON) - Optional.
             aws_access_key_id:
                 An AWS access key granted to the bucket where the file is located. Not required
                 if keys are stored as environmental variables.
@@ -780,26 +783,34 @@
                     PARALLEL {parallel} \n
                     MAXFILESIZE {max_file_size}
                     """
-        if manifest:
-            statement += "MANIFEST \n"
-        if header:
-            statement += "HEADER \n"
-        if delimiter:
-            statement += f"DELIMITER as '{delimiter}' \n"
-        if compression:
-            statement += f"{compression.upper()} \n"
-        if add_quotes:
-            statement += "ADDQUOTES \n"
-        if null_as:
-            statement += f"NULL {null_as} \n"
-        if escape:
-            statement += "ESCAPE \n"
-        if allow_overwrite:
-            statement += "ALLOWOVERWRITE \n"
-        if extension:
-            statement += f"EXTENSION '{extension}' \n"
-        if aws_region:
-            statement += f"REGION {aws_region} \n"
+        statement += "ALLOWOVERWRITE \n" if allow_overwrite else ""
+        statement += f"REGION {aws_region} \n" if aws_region else ""
+        statement += "MANIFEST \n" if manifest else ""
+        statement += f"EXTENSION '{extension}' \n" if extension else ""
+
+        # Format-specific parameters
+        if format:
+            format = format.lower()
+            if format == "csv":
+                statement += f"DELIMITER AS '{delimiter}' \n" if delimiter else ""
+                statement += f"NULL AS '{null_as}' \n" if null_as else ""
+                statement += "HEADER \n" if header else ""
+                statement += "ESCAPE \n" if escape else ""
+                statement += "FORMAT AS CSV \n"
+                statement += f"{compression.upper()} \n" if compression else ""
+            elif format == "parquet":
+                statement += "FORMAT AS PARQUET \n"
+            elif format == "json":
+                statement += "FORMAT AS JSON \n"
+                statement += f"{compression.upper()} \n" if compression else ""
+        else:
+            # Default text file settings
+            statement += f"DELIMITER AS '{delimiter}' \n" if delimiter else ""
+            statement += "ADDQUOTES \n" if add_quotes else ""
+            statement += f"NULL AS '{null_as}' \n" if null_as else ""
+            statement += "ESCAPE \n" if escape else ""
+            statement += "HEADER \n" if header else ""
+            statement += f"{compression.upper()} \n" if compression else ""

         logger.info(f"Unloading data to s3://{bucket}/{key_prefix}")
         # Censor sensitive data
@@ -847,7 +858,6 @@ def drop_and_unload(
             None
         """
         query_end = "cascade" if cascade else ""
-
        self.unload(
            sql=f"select * from {rs_table}",
            bucket=bucket,
diff --git a/test/test_databases/test_redshift.py b/test/test_databases/test_redshift.py
index a42e824258..ac150929ee 100644
--- a/test/test_databases/test_redshift.py
+++ b/test/test_databases/test_redshift.py
@@ -595,6 +595,50 @@ def test_unload(self):
         # Check that files are there
         self.assertTrue(self.s3.key_exists(self.temp_s3_bucket, "unload_test"))

+    def test_unload_json_format(self):
+        # Setup
+        self.rs.copy(self.tbl, f"{self.temp_schema}.test_copy", if_exists="drop")
+
+        # Unload with JSON format
+        self.rs.unload(
+            f"select * from {self.temp_schema}.test_copy",
+            self.temp_s3_bucket,
+            "unload_test_json",
+            format="json",
+        )
+
+        # Check that files are there
+        self.assertTrue(self.s3.key_exists(self.temp_s3_bucket, "unload_test_json"))
+
+    def test_unload_parquet_format(self):
+        # Setup
+        self.rs.copy(self.tbl, f"{self.temp_schema}.test_copy", if_exists="drop")
+
+        # Unload with Parquet format
+        self.rs.unload(
+            f"select * from {self.temp_schema}.test_copy",
+            self.temp_s3_bucket,
+            "unload_test_parquet",
+            format="parquet",
+        )
+
+        self.assertTrue(self.s3.key_exists(self.temp_s3_bucket, "unload_test_parquet"))
+
+    def test_unload_csv_format(self):
+        # Setup
+        self.rs.copy(self.tbl, f"{self.temp_schema}.test_copy", if_exists="drop")
+
+        # Unload with CSV format
+        self.rs.unload(
+            f"select * from {self.temp_schema}.test_copy",
+            self.temp_s3_bucket,
+            "unload_test_csv",
+            format="csv",
+        )
+
+        # Check that files are there
+        self.assertTrue(self.s3.key_exists(self.temp_s3_bucket, "unload_test_csv"))
+
     def test_drop_and_unload(self):
         rs_table_test = f"{self.temp_schema}.test_copy"