diff --git a/sqlframe/base/mixins/readwriter_mixins.py b/sqlframe/base/mixins/readwriter_mixins.py
index a5029f6..eb69dda 100644
--- a/sqlframe/base/mixins/readwriter_mixins.py
+++ b/sqlframe/base/mixins/readwriter_mixins.py
@@ -75,7 +75,7 @@ def load(
 
         assert path is not None, "path is required"
         assert isinstance(path, str), "path must be a string"
-        format = format or _infer_format(path)
+        format = format or self.state_format_to_read or _infer_format(path)  # explicit arg wins
         kwargs = {k: v for k, v in options.items() if v is not None}
         if format == "json":
             df = pd.read_json(path, lines=True, **kwargs)  # type: ignore
diff --git a/sqlframe/base/readerwriter.py b/sqlframe/base/readerwriter.py
index 531879d..57b8057 100644
--- a/sqlframe/base/readerwriter.py
+++ b/sqlframe/base/readerwriter.py
@@ -36,6 +36,7 @@ class _BaseDataFrameReader(t.Generic[SESSION, DF]):
 
     def __init__(self, spark: SESSION):
         self._session = spark
+        self.state_format_to_read: t.Optional[str] = None  # set by format(); fallback for load()
 
     @property
     def session(self) -> SESSION:
@@ -67,6 +68,44 @@ def _to_casted_columns(self, column_mapping: t.Dict) -> t.List[Column]:
             for k, v in column_mapping.items()
         ]
 
+    def format(self, source: str) -> "Self":
+        """Specifies the input data source format.
+
+        .. versionadded:: 1.4.0
+
+        .. versionchanged:: 3.4.0
+            Supports Spark Connect.
+
+        Parameters
+        ----------
+        source : str
+            string, name of the data source, e.g. 'json', 'parquet'.
+
+        Examples
+        --------
+        >>> spark.read.format('json')
+        <...readwriter.DataFrameReader object ...>
+
+        Write a DataFrame into a JSON file and read it back.
+
+        >>> import tempfile
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a DataFrame into a JSON file
+        ...     spark.createDataFrame(
+        ...         [{"age": 100, "name": "Hyukjin Kwon"}]
+        ...     ).write.mode("overwrite").format("json").save(d)
+        ...
+        ...     # Read the JSON file as a DataFrame.
+        ...     spark.read.format('json').load(d).show()
+        +---+------------+
+        |age|        name|
+        +---+------------+
+        |100|Hyukjin Kwon|
+        +---+------------+
+        """
+        self.state_format_to_read = source  # kept on the reader for a later load()
+        return self  # same reader back, so calls can be chained
+
     def load(
         self,
         path: t.Optional[PathOrPaths] = None,
diff --git a/sqlframe/duckdb/readwriter.py b/sqlframe/duckdb/readwriter.py
index 498ce41..73aae08 100644
--- a/sqlframe/duckdb/readwriter.py
+++ b/sqlframe/duckdb/readwriter.py
@@ -72,6 +72,7 @@ def load(
         |100|NULL|
         +---+----+
         """
+        format = format or self.state_format_to_read  # explicit arg wins over format()
         if schema:
             column_mapping = ensure_column_mapping(schema)
             select_column_mapping = column_mapping.copy()