diff --git a/pipit/tests/trace.py b/pipit/tests/trace.py index 7731fe61..34607e1f 100644 --- a/pipit/tests/trace.py +++ b/pipit/tests/trace.py @@ -216,3 +216,50 @@ def test_time_profile(data_dir, ping_pong_otf2_trace): assert np.isclose(norm.loc[61]["MPI_Comm_size"], 0.0) assert np.isclose(norm.loc[61]["MPI_Comm_rank"], 0.0) assert np.isclose(norm.loc[61]["MPI_Finalize"], 0.01614835) + + +def generic_test_message_histogram(trace): + message_histogram, bin_ranges = trace.message_histogram(bins=40) + print(message_histogram) + # check the length + assert len(message_histogram) == 40 + + # comm_matrix was already subject to test routines earlier + comm_matrix_count = trace.comm_matrix(output="count") + total_messages = comm_matrix_count.sum() + + # ensure that all messages are included + assert message_histogram.sum() == total_messages + + # ensure that the bin endpoints increase monotonically + assert np.all(bin_ranges[1:] > bin_ranges[:-1]) + + comm_matrix = trace.comm_matrix() + total_volume = comm_matrix.sum() + + # multiply each count by the right endpoint of the bin to estimate total volume + # this should always be >= actual total volume + # when done with left endpoint, always <= actual total volume + + upper_total_volume = np.dot(message_histogram, bin_ranges[1:]) + lower_total_volume = np.dot(message_histogram, bin_ranges[:-1]) + assert upper_total_volume >= total_volume >= lower_total_volume + + +def test_message_histogram(ping_pong_otf2_trace): + trace = Trace.from_otf2(str(ping_pong_otf2_trace)) + trace.calc_exc_metrics(["Timestamp (ns)"]) + generic_test_message_histogram(trace) + + # smallest message in ping-pong was 16384 B + # largest message was 2097152 B + message_histogram, bin_ranges = trace.message_histogram(bins=40) + assert bin_ranges[0] == 16384 + assert bin_ranges[-1] == 2097152 + + # all communications occur in a pair + assert np.all(message_histogram % 2 == 0) + + # only the first bin can have more than one pair in it + # all the others are either 0 or 2 in ping-pong + assert np.max(message_histogram[1:]) == 2