Distributed RDF tests with RNTuple #1105

Open · wants to merge 2 commits into master
python/distrdf/backends/check_backend.py (25 changes: 13 additions & 12 deletions)
@@ -123,7 +123,7 @@ def init(value):

df = df.Define("u", "userValue").Histo1D(
("name", "title", 1, 100, 130), "u")

h = df.GetValue()
assert h.GetMean() == 123
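For context: the `init(value)` callback named in the hunk header is DistRDF's worker-initialization hook. A minimal sketch of that pattern, assuming the `initialize` helper from `ROOT.RDF.Experimental.Distributed` and a hypothetical body for `init` (the PR does not show it):

```python
import ROOT

# Hypothetical body for init(value): DistRDF runs the registered callback on
# every worker before the event loop, so a variable declared here is visible
# to Define() expressions such as "userValue" in the test above.
def init(value):
    ROOT.gInterpreter.Declare(f"int userValue = {value};")

# Register the callback; each distributed task then calls init(123) before
# processing its range, which is why the histogram mean is exactly 123.
ROOT.RDF.Experimental.Distributed.initialize(init, 123)
```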

@@ -133,22 +133,23 @@ class TestEmptyTreeError:
Tests with empty trees.
"""

def test_histo_from_empty_root_file(self, payload):
@pytest.mark.parametrize("datasource", ["ttree", "rntuple"])
def test_histo_from_empty_root_file(self, payload, datasource):
"""
Check that when performing operations with the distributed backend on
an RDataFrame without entries, DistRDF raises an error.
"""

connection, backend = payload
datasetname = "empty"
filename = f"../data/{datasource}/empty.root"
# Create an RDataFrame from a file with an empty tree
if backend == "dask":
RDataFrame = ROOT.RDF.Experimental.Distributed.Dask.RDataFrame
rdf = RDataFrame(
"empty", "../data/ttree/empty.root", daskclient=connection)
rdf = RDataFrame(datasetname, filename, daskclient=connection)
elif backend == "spark":
RDataFrame = ROOT.RDF.Experimental.Distributed.Spark.RDataFrame
rdf = RDataFrame("empty", "../data/ttree/empty.root",
sparkcontext=connection)
rdf = RDataFrame(datasetname, filename, sparkcontext=connection)
histo = rdf.Histo1D(("empty", "empty", 10, 0, 10), "mybranch")

# Get entries in the histogram, raises error
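The `rntuple` variant of this test needs an empty RNTuple fixture at `../data/rntuple/empty.root`. A sketch of how such a file could be produced, assuming the Experimental RNTuple Python bindings of the ROOT versions this PR targets (the field name mirrors the `mybranch` column read above; the writer code itself is not part of the PR):

```python
import ROOT

# Write an RNTuple with a schema but no entries, mirroring the TTree fixture.
model = ROOT.Experimental.RNTupleModel.Create()
model.MakeField["double"]("mybranch")  # declare the column; nothing is filled
writer = ROOT.Experimental.RNTupleWriter.Recreate(
    model, "empty", "../data/rntuple/empty.root")
del writer  # destroying the writer commits the (empty) dataset to disk
```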
@@ -161,7 +162,6 @@ def test_count_with_some_empty_trees(self, payload):
not contribute to how many entries are processed in the distributed
execution.
"""

connection, backend = payload
treenames = [f"tree_{i}" for i in range(3)]
filenames = [
@@ -200,21 +200,22 @@ class TestWithRepeatedTree:
is used multiple times.
"""

def test_count_with_same_tree_repeated(self, payload):
@pytest.mark.parametrize("datasource", ["ttree", "rntuple"])
def test_count_with_same_tree_repeated(self, payload, datasource):
"""
Count entries of a dataset that contains the same tree three times.
"""
connection, backend = payload
treename = "tree_0"
filename = "../data/ttree/distrdf_roottest_check_backend_0.root"
datasetname = "tree_0"
filename = f"../data/{datasource}/distrdf_roottest_check_backend_0.root"
filenames = [filename] * 3

if backend == "dask":
RDataFrame = ROOT.RDF.Experimental.Distributed.Dask.RDataFrame
rdf = RDataFrame(treename, filenames, daskclient=connection)
rdf = RDataFrame(datasetname, filenames, daskclient=connection)
elif backend == "spark":
RDataFrame = ROOT.RDF.Experimental.Distributed.Spark.RDataFrame
rdf = RDataFrame(treename, filenames, sparkcontext=connection)
rdf = RDataFrame(datasetname, filenames, sparkcontext=connection)
assert rdf.Count().GetValue() == 300
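A quick local cross-check of the same expectation, assuming each fixture file holds 100 entries (implied by the distributed assertion): repeating a file in the input list makes RDataFrame process the dataset once per repetition.

```python
import ROOT

# Plain (non-distributed) RDataFrame applies the same rule: three copies of a
# 100-entry dataset in the file list yield 300 processed entries.
filename = "../data/ttree/distrdf_roottest_check_backend_0.root"
localdf = ROOT.RDataFrame("tree_0", [filename] * 3)
assert localdf.Count().GetValue() == 300
```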


python/distrdf/backends/check_cloned_actions.py (22 changes: 15 additions & 7 deletions)
@@ -10,31 +10,39 @@ class TestAsNumpy:
distributed configurations.
"""

@pytest.mark.parametrize("nparts", range(1, 21))
def test_clone_asnumpyresult(self, payload, nparts):
@pytest.mark.parametrize("nfiles", [1, 3, 7])
@pytest.mark.parametrize("nparts", [1, 2, 3, 7, 8, 15, 16, 21])
@pytest.mark.parametrize("datasource", ["ttree", "rntuple"])
def test_clone_asnumpyresult(self, payload, nfiles, nparts, datasource):
"""
Test that the correct values of the numpy array are retrieved from
distributed execution irrespective of the number of partitions.
"""

datasetname = "Events"
filename = "../data/ttree/distrdf_roottest_check_cloned_actions_asnumpy.root"
filename = f"../data/{datasource}/distrdf_roottest_check_cloned_actions_asnumpy.root"
inputfiles = [filename] * nfiles
connection, backend = payload
if backend == "dask":
RDataFrame = ROOT.RDF.Experimental.Distributed.Dask.RDataFrame
distrdf = RDataFrame(datasetname, filename,
distrdf = RDataFrame(datasetname, inputfiles,
daskclient=connection, npartitions=nparts)
elif backend == "spark":
RDataFrame = ROOT.RDF.Experimental.Distributed.Spark.RDataFrame
distrdf = RDataFrame(datasetname, filename,
distrdf = RDataFrame(datasetname, inputfiles,
sparkcontext=connection, npartitions=nparts)

localdf = ROOT.RDataFrame("Events", filename)
localdf = ROOT.RDataFrame("Events", inputfiles)

vals_distrdf = distrdf.AsNumpy(["event"])
vals_localdf = localdf.AsNumpy(["event"])

assert all(vals_localdf["event"] == numpy.sort(vals_distrdf["event"]))
# Distributed execution does not guarantee the order in which tasks run,
# so the distributed output array is unsorted. The local output array is
# sorted as well: when the test uses multiple files, the values of the two
# arrays would otherwise never align and the comparison would always fail.
assert all(numpy.sort(vals_localdf["event"]) == numpy.sort(vals_distrdf["event"]))
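A toy illustration of why both sides are sorted, with hypothetical values: once the same dataset appears several times in the input list, task-completion order can interleave the copies, so only the sorted arrays compare equal.

```python
import numpy

local = numpy.array([1, 2, 3, 1, 2, 3])        # two copies, natural order
distrib = numpy.array([2, 3, 1, 1, 3, 2])      # tasks finished out of order
assert not all(local == distrib)               # element-wise order differs...
assert all(numpy.sort(local) == numpy.sort(distrib))  # ...but contents match
```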


if __name__ == "__main__":
python/distrdf/backends/check_histo_write.py (5 changes: 3 additions & 2 deletions)
@@ -16,7 +16,8 @@ class TestDaskHistoWrite:
gaus_stdev = 1
delta_equal = 0.01

def test_write_histo(self, payload):
@pytest.mark.parametrize("datasource", ["ttree", "rntuple"])
def test_write_histo(self, payload, datasource):
"""
Tests that a histogram is correctly written to a .root file created
before the execution of the event loop.
@@ -26,7 +27,7 @@ def test_write_histo(self, payload):
with ROOT.TFile("out_file.root", "recreate") as outfile:
# We can reuse the same dataset from another test
treename = "T"
filename = "../data/ttree/distrdf_roottest_check_friend_trees_main.root"
filename = f"../data/{datasource}/distrdf_roottest_check_friend_trees_main.root"
# Create a DistRDF RDataFrame with the parent and the friend trees
connection, backend = payload
if backend == "dask":
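The hunk is truncated here, but the pattern under test is: open the output file before the event loop, run the loop, then write the result into the still-open file. A minimal runnable sketch with a local RDataFrame (names and histogram parameters are illustrative, not from the PR):

```python
import ROOT

with ROOT.TFile("out_file.root", "recreate") as outfile:
    # The event loop runs while the file is open; the filled histogram is
    # then written into it like any other local object.
    df = ROOT.RDataFrame(100).Define("v", "gRandom->Gaus(10, 1)")
    histo = df.Histo1D(("gaus", "gaus", 64, 5, 15), "v")
    outfile.WriteObject(histo.GetValue(), "gaus")
```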
python/distrdf/backends/check_reducer_merge.py (10 changes: 6 additions & 4 deletions)
@@ -490,12 +490,13 @@ def test_redefine_one_column(self, payload):
assert sum_before.GetValue() == 10.0
assert sum_after.GetValue() == 20.0

def test_distributed_stddev(self, payload):
@pytest.mark.parametrize("datasource", ["ttree", "rntuple"])
def test_distributed_stddev(self, payload, datasource):
"""Test support for the StdDev action."""

# Create dataset with fixed series of entries
treename = "tree"
filename = "../data/ttree/distrdf_roottest_check_reducer_merge_1.root"
filename = f"../data/{datasource}/distrdf_roottest_check_reducer_merge_1.root"

connection, backend = payload
if backend == "dask":
@@ -511,11 +512,12 @@ def test_distributed_stddev(self, payload):

assert std.GetValue() == pytest.approx(expected, rel), f"{std.GetValue()}!={expected}"
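For reference, the `expected` value asserted above can be derived locally: RDataFrame's StdDev returns, to my reading of the RDF docs, the unbiased sample estimator, which corresponds to `numpy.std` with `ddof=1` (the column values below are a hypothetical stand-in, not the test data).

```python
import numpy

values = numpy.arange(10, dtype=float)   # stand-in for the column contents
expected = numpy.std(values, ddof=1)     # unbiased sample standard deviation
```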

def test_distributed_stats(self, payload):
@pytest.mark.parametrize("datasource", ["ttree", "rntuple"])
def test_distributed_stats(self, payload, datasource):
"""Test support for the Stats action."""
# Create dataset with fixed series of entries
treename = "tree"
filename = "../data/ttree/distrdf_roottest_check_reducer_merge_1.root"
filename = f"../data/{datasource}/distrdf_roottest_check_reducer_merge_1.root"

connection, backend = payload
if backend == "dask":
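This hunk is truncated before its assertions, so for reference, a local sketch of the Stats action itself: it aggregates count, mean, and RMS of a column in a single pass and returns a TStatistic.

```python
import ROOT

df = ROOT.RDataFrame(100).Define("v", "(double)rdfentry_")
stats = df.Stats("v")  # returns a TStatistic result
print(stats.GetN(), stats.GetMean(), stats.GetRMS())
```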
python/distrdf/backends/check_rungraphs.py (5 changes: 3 additions & 2 deletions)
@@ -8,13 +8,14 @@
class TestRunGraphs:
"""Tests usage of RunGraphs function with Dask backend"""

def test_rungraphs_dask_3histos(self, payload):
@pytest.mark.parametrize("datasource", ["ttree", "rntuple"])
def test_rungraphs_dask_3histos(self, payload, datasource):
"""
Submit three different Dask RDF graphs concurrently
"""
# Create a test file for processing
treename = "tree"
filename = "../data/ttree/distrdf_roottest_check_rungraphs.root"
filename = f"../data/{datasource}/distrdf_roottest_check_rungraphs.root"
nentries = 10000
connection, backend = payload
if backend == "dask":
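This hunk is also truncated, so for reference, the shape of the RunGraphs pattern being tested, assuming the distributed `RunGraphs` helper and reusing `treename`, `filename`, and `connection` from the surrounding test ("b1" is a hypothetical branch name): booking several independent computation graphs and triggering them through one `RunGraphs` call submits all their tasks concurrently instead of running three sequential event loops.

```python
import ROOT

RunGraphs = ROOT.RDF.Experimental.Distributed.RunGraphs
RDataFrame = ROOT.RDF.Experimental.Distributed.Dask.RDataFrame

# Three independent graphs, one histogram each; treename, filename and
# connection come from the test above, the branch name is an assumption.
histos = [
    RDataFrame(treename, filename, daskclient=connection)
    .Histo1D((f"h{i}", f"h{i}", 100, 0, 100), "b1")
    for i in range(3)
]
RunGraphs(histos)  # one concurrent submission instead of three event loops
results = [h.GetValue() for h in histos]
```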