Skip to content

Commit

Permalink
implement GrpcDataProvider.read_last_scalars; remove plugin_name arg
Browse files Browse the repository at this point in the history
  • Loading branch information
yatbear committed Oct 23, 2023
1 parent f30a9f8 commit 1d8e79c
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 3 deletions.
3 changes: 3 additions & 0 deletions tensorboard/backend/application_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ def list_scalars(self, ctx=None, *, experiment_id):
def read_scalars(self, ctx=None, *, experiment_id):
    """Unsupported by this test provider; always raises NotImplementedError."""
    del ctx, experiment_id  # Unused: this fake never serves scalar data.
    raise NotImplementedError()

def read_last_scalars(self, ctx=None, *, experiment_id):
    """Unsupported by this test provider; always raises NotImplementedError."""
    del ctx, experiment_id  # Unused: this fake never serves scalar data.
    raise NotImplementedError()


class HandlingErrorsTest(tb_test.TestCase):
def test_successful_response_passes_through(self):
Expand Down
37 changes: 37 additions & 0 deletions tensorboard/data/grpc_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
"""A data provider that talks to a gRPC server."""

import collections
import contextlib

import grpc
Expand Down Expand Up @@ -148,6 +149,42 @@ def read_scalars(
series.append(point)
return result

@timing.log_latency
def read_last_scalars(
    self,
    ctx,
    *,
    experiment_id,
    run_tag_filter=None,
):
    """Reads the most recent scalar datum for each run/tag pair.

    Args:
      ctx: A TensorBoard `RequestContext` value (unused here; the RPC
        carries the experiment ID instead).
      experiment_id: ID of the enclosing experiment.
      run_tag_filter: Optional filter restricting which runs and tags
        are queried; `None` means no restriction.

    Returns:
      A dict mapping run name to a dict mapping tag name to the latest
      `provider.ScalarDatum` for that series.
    """
    with timing.log_latency("build request"):
        request = data_provider_pb2.ReadScalarsRequest()
        request.experiment_id = experiment_id
        _populate_rtf(run_tag_filter, request.run_tag_filter)
        # `ReadScalars` always keeps the most recent datum when it
        # downsamples, so asking for a single point fetches exactly the
        # latest value of each series.
        request.downsample.num_points = 1
    with timing.log_latency("_stub.ReadScalars"):
        with _translate_grpc_error():
            response = self._stub.ReadScalars(request)
    with timing.log_latency("build result"):
        result = collections.defaultdict(dict)
        for run_entry in response.runs:
            for tag_entry in run_entry.tags:
                series = tag_entry.data
                # Since the request downsampled to one point, each series
                # carries at most a single (step, wall_time, value) triple.
                for step, wall_time, value in zip(
                    series.step, series.wall_time, series.value
                ):
                    result[run_entry.run_name][
                        tag_entry.tag_name
                    ] = provider.ScalarDatum(
                        step=step,
                        wall_time=wall_time,
                        value=value,
                    )
        return result

@timing.log_latency
def list_tensors(
self, ctx, *, experiment_id, plugin_name, run_tag_filter=None
Expand Down
52 changes: 52 additions & 0 deletions tensorboard/data/grpc_provider_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,58 @@ def test_read_scalars(self):
req.downsample.num_points = 4
self.stub.ReadScalars.assert_called_once_with(req)

def test_read_last_scalars(self):
    """Latest-scalar reads downsample to one point and map run -> tag -> datum."""
    response = data_provider_pb2.ReadScalarsResponse(
        runs=[
            data_provider_pb2.ReadScalarsResponse.RunEntry(
                run_name="run1",
                tags=[
                    data_provider_pb2.ReadScalarsResponse.TagEntry(
                        tag_name="tag1",
                        data=data_provider_pb2.ScalarData(
                            step=[10000], wall_time=[1234.0], value=[1]
                        ),
                    )
                ],
            ),
            data_provider_pb2.ReadScalarsResponse.RunEntry(
                run_name="run2",
                tags=[
                    data_provider_pb2.ReadScalarsResponse.TagEntry(
                        tag_name="tag2",
                        data=data_provider_pb2.ScalarData(
                            step=[10000], wall_time=[1235.0], value=[0.50]
                        ),
                    )
                ],
            ),
        ]
    )
    self.stub.ReadScalars.return_value = response

    actual = self.provider.read_last_scalars(
        self.ctx,
        experiment_id="123",
        run_tag_filter=provider.RunTagFilter(
            runs=["train", "test", "nope"]
        ),
    )

    self.assertEqual(
        actual,
        {
            "run1": {
                "tag1": provider.ScalarDatum(
                    step=10000, wall_time=1234.0, value=1
                ),
            },
            "run2": {
                "tag2": provider.ScalarDatum(
                    step=10000, wall_time=1235.0, value=0.50
                ),
            },
        },
    )

    expected_request = data_provider_pb2.ReadScalarsRequest()
    expected_request.experiment_id = "123"
    # Run names are serialized into the request in sorted order.
    expected_request.run_tag_filter.runs.names.extend(
        ["nope", "test", "train"]
    )
    expected_request.downsample.num_points = 1
    self.stub.ReadScalars.assert_called_once_with(expected_request)

def test_list_tensors(self):
res = data_provider_pb2.ListTensorsResponse()
run1 = res.runs.add(run_name="val")
Expand Down
3 changes: 0 additions & 3 deletions tensorboard/data/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ def read_scalars(
ctx=None,
*,
experiment_id,
plugin_name,
downsample=None,
run_tag_filter=None,
):
Expand All @@ -224,8 +223,6 @@ def read_scalars(
Args:
ctx: A TensorBoard `RequestContext` value.
experiment_id: ID of enclosing experiment.
plugin_name: String name of the TensorBoard plugin that created
the data to be queried. Required.
downsample: Integer number of steps to which to downsample the
results (e.g., `1000`). The most recent datum (last scalar)
should always be included. See `DataProvider` class docstring
Expand Down

0 comments on commit 1d8e79c

Please sign in to comment.