From 1d8e79c77968df69575c1b118c71e37369d2ebfd Mon Sep 17 00:00:00 2001
From: Yating Jing
Date: Mon, 23 Oct 2023 15:39:32 +0000
Subject: [PATCH] implement GrpcDataProvider.read_last_scalars; remove plugin_name arg

---
 tensorboard/backend/application_test.py |  3 ++
 tensorboard/data/grpc_provider.py       | 37 ++++++++++++++++++
 tensorboard/data/grpc_provider_test.py  | 52 +++++++++++++++++++++++++
 tensorboard/data/provider.py            |  3 --
 4 files changed, 92 insertions(+), 3 deletions(-)

diff --git a/tensorboard/backend/application_test.py b/tensorboard/backend/application_test.py
index 6b0499df5f..86f76f52f9 100644
--- a/tensorboard/backend/application_test.py
+++ b/tensorboard/backend/application_test.py
@@ -153,6 +153,9 @@ def list_scalars(self, ctx=None, *, experiment_id):
     def read_scalars(self, ctx=None, *, experiment_id):
         raise NotImplementedError()

+    def read_last_scalars(self, ctx=None, *, experiment_id):
+        raise NotImplementedError()
+

 class HandlingErrorsTest(tb_test.TestCase):
     def test_successful_response_passes_through(self):
diff --git a/tensorboard/data/grpc_provider.py b/tensorboard/data/grpc_provider.py
index af879743f6..1d465d3c72 100644
--- a/tensorboard/data/grpc_provider.py
+++ b/tensorboard/data/grpc_provider.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """A data provider that talks to a gRPC server."""

+import collections
 import contextlib

 import grpc
@@ -148,6 +149,42 @@ def read_scalars(
             series.append(point)
         return result

+    @timing.log_latency
+    def read_last_scalars(
+        self,
+        ctx,
+        *,
+        experiment_id,
+        run_tag_filter=None,
+    ):
+        with timing.log_latency("build request"):
+            req = data_provider_pb2.ReadScalarsRequest()
+            req.experiment_id = experiment_id
+            _populate_rtf(run_tag_filter, req.run_tag_filter)
+            # `ReadScalars` always includes the most recent datum, so
+            # downsampling to one point means fetching the latest value.
+            req.downsample.num_points = 1
+        with timing.log_latency("_stub.ReadScalars"):
+            with _translate_grpc_error():
+                res = self._stub.ReadScalars(req)
+        with timing.log_latency("build result"):
+            result = collections.defaultdict(dict)
+            for run_entry in res.runs:
+                run_name = run_entry.run_name
+                for tag_entry in run_entry.tags:
+                    d = tag_entry.data
+                    # There should be no more than one datum in
+                    # `tag_entry.data` since downsample was set to 1.
+ for (step, wt, value) in zip(d.step, d.wall_time, d.value): + result[run_name][ + tag_entry.tag_name + ] = provider.ScalarDatum( + step=step, + wall_time=wt, + value=value, + ) + return result + @timing.log_latency def list_tensors( self, ctx, *, experiment_id, plugin_name, run_tag_filter=None diff --git a/tensorboard/data/grpc_provider_test.py b/tensorboard/data/grpc_provider_test.py index ed2136cc46..513e130d6f 100644 --- a/tensorboard/data/grpc_provider_test.py +++ b/tensorboard/data/grpc_provider_test.py @@ -230,6 +230,58 @@ def test_read_scalars(self): req.downsample.num_points = 4 self.stub.ReadScalars.assert_called_once_with(req) + def test_read_last_scalars(self): + tag1 = data_provider_pb2.ReadScalarsResponse.TagEntry( + tag_name="tag1", + data=data_provider_pb2.ScalarData( + step=[10000], wall_time=[1234.0], value=[1] + ), + ) + tag2 = data_provider_pb2.ReadScalarsResponse.TagEntry( + tag_name="tag2", + data=data_provider_pb2.ScalarData( + step=[10000], wall_time=[1235.0], value=[0.50] + ), + ) + run1 = data_provider_pb2.ReadScalarsResponse.RunEntry( + run_name="run1", tags=[tag1] + ) + run2 = data_provider_pb2.ReadScalarsResponse.RunEntry( + run_name="run2", tags=[tag2] + ) + res = data_provider_pb2.ReadScalarsResponse(runs=[run1, run2]) + self.stub.ReadScalars.return_value = res + + actual = self.provider.read_last_scalars( + self.ctx, + experiment_id="123", + run_tag_filter=provider.RunTagFilter( + runs=["train", "test", "nope"] + ), + ) + expected = { + "run1": { + "tag1": provider.ScalarDatum( + step=10000, wall_time=1234.0, value=1 + ), + }, + "run2": { + "tag2": provider.ScalarDatum( + step=10000, wall_time=1235.0, value=0.50 + ), + }, + } + + self.assertEqual(actual, expected) + + req = data_provider_pb2.ReadScalarsRequest() + req.experiment_id = "123" + req.run_tag_filter.runs.names.extend( + ["nope", "test", "train"] + ) # sorted + req.downsample.num_points = 1 + self.stub.ReadScalars.assert_called_once_with(req) + def test_list_tensors(self): res = data_provider_pb2.ListTensorsResponse() run1 = res.runs.add(run_name="val") diff --git a/tensorboard/data/provider.py b/tensorboard/data/provider.py index 47040e89b4..82fc498927 100644 --- a/tensorboard/data/provider.py +++ b/tensorboard/data/provider.py @@ -215,7 +215,6 @@ def read_scalars( ctx=None, *, experiment_id, - plugin_name, downsample=None, run_tag_filter=None, ): @@ -224,8 +223,6 @@ def read_scalars( Args: ctx: A TensorBoard `RequestContext` value. experiment_id: ID of enclosing experiment. - plugin_name: String name of the TensorBoard plugin that created - the data to be queried. Required. downsample: Integer number of steps to which to downsample the results (e.g., `1000`). The most recent datum (last scalar) should always be included. See `DataProvider` class docstring
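Note (illustrative; not part of the patch): a minimal sketch of how the new
method might be exercised end to end, assuming a data server is reachable at
the hypothetical address below and hosts an experiment with the hypothetical
ID "123". The `GrpcDataProvider(addr, stub)` wiring mirrors the construction
used in grpc_provider_test.py.

    import grpc

    from tensorboard import context as tb_context
    from tensorboard.data import grpc_provider
    from tensorboard.data import provider

    # Hypothetical server address; point this at a running data server.
    addr = "localhost:6806"
    channel = grpc.insecure_channel(addr)
    stub = grpc_provider.make_stub(channel)
    data_provider = grpc_provider.GrpcDataProvider(addr, stub)

    # Fetch only the most recent scalar per (run, tag); under the hood the
    # provider issues a ReadScalars RPC with downsample.num_points = 1.
    last = data_provider.read_last_scalars(
        tb_context.RequestContext(),
        experiment_id="123",  # hypothetical experiment ID
        run_tag_filter=provider.RunTagFilter(runs=["train"]),
    )
    for run, tags in last.items():
        for tag, datum in tags.items():
            print(run, tag, datum.step, datum.wall_time, datum.value)

Because the returned `defaultdict` compares equal to a plain `dict` (as the
new unit test relies on), callers can consume the result with ordinary dict
iteration as above.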