diff --git a/tests/test_sharded_dataset.py b/tests/test_sharded_dataset.py index 3cf0d78e6..2fb9cf4d8 100644 --- a/tests/test_sharded_dataset.py +++ b/tests/test_sharded_dataset.py @@ -35,7 +35,7 @@ def test_sniff_format_for_parquet(): f.flush() assert _sniff_format_for_dataset(f.name) == ".parquet" - + @skip_if_no_soundlibs def test_resolve_audio_pointer(): @@ -56,15 +56,17 @@ def test_basic_parquet_datasource_read_row(): table = pa.Table.from_pydict(data) pq.write_table(table, f.name) - # Instantiate the ParquetDataSource datasource = ParquetDataSource([f.name]) + assert len(datasource.shard_names) == 1, "Expected only one shard" + shard_name = datasource.shard_names[0] + # sanity check: Read data starting from row 1 - row_data = list(datasource.open_shard_at_row(shard_name=f.name.replace(".", "_"), row=1)) + row_data = list(datasource.open_shard_at_row(shard_name=shard_name, row=1)) # Verify the output assert len(row_data) == 2 # We expect 2 rows starting from index 1 assert row_data[0]["column1"] == "value2" assert row_data[0]["column2"] == 20 assert row_data[1]["column1"] == "value3" - assert row_data[1]["column2"] == 30 \ No newline at end of file + assert row_data[1]["column2"] == 30