data: define interface for read_last_scalars #6657

Merged (10 commits, Oct 30, 2023)
3 changes: 3 additions & 0 deletions tensorboard/backend/application_test.py
@@ -153,6 +153,9 @@ def list_scalars(self, ctx=None, *, experiment_id):
    def read_scalars(self, ctx=None, *, experiment_id):
        raise NotImplementedError()

    def read_last_scalars(self, ctx=None, *, experiment_id, plugin_name):
        raise NotImplementedError()


class HandlingErrorsTest(tb_test.TestCase):
    def test_successful_response_passes_through(self):
25 changes: 25 additions & 0 deletions tensorboard/backend/event_processing/data_provider.py
@@ -16,6 +16,7 @@


import base64
import collections
import json
import random

@@ -137,6 +138,30 @@ def read_scalars(
        )
        return self._read(_convert_scalar_event, index, downsample)

    def read_last_scalars(
        self,
        ctx=None,
        *,
        experiment_id,
        plugin_name,
        run_tag_filter=None,
    ):
        self._validate_context(ctx)
        self._validate_experiment_id(experiment_id)
        index = self._index(
            plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR
        )
        run_tag_to_last_scalar_datum = collections.defaultdict(dict)
        # For each (run, tag) in the filtered index, convert only the most
        # recent event into a ScalarDatum.
        for (run, tags_for_run) in index.items():
            for (tag, metadata) in tags_for_run.items():
                events = self._multiplexer.Tensors(run, tag)
                if events:
                    run_tag_to_last_scalar_datum[run][
                        tag
                    ] = _convert_scalar_event(events[-1])

        return run_tag_to_last_scalar_datum

    def list_tensors(
        self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
    ):
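For orientation, here is a minimal caller-side sketch of the new method on `MultiplexerDataProvider`. This is not part of the PR: the logdir path is a placeholder, and the setup assumes the same pattern as the test fixtures below (note that `MultiplexerDataProvider` does not use `experiment_id`, which is why the tests pass "unused").

# Hypothetical usage sketch, not part of this PR.
from tensorboard import context
from tensorboard.backend.event_processing import data_provider
from tensorboard.backend.event_processing import plugin_event_multiplexer
from tensorboard.plugins.scalar import metadata as scalar_metadata

logdir = "/tmp/logdir"  # placeholder: a logdir containing scalar summaries
multiplexer = plugin_event_multiplexer.EventMultiplexer()
multiplexer.AddRunsFromDirectory(logdir)
multiplexer.Reload()
provider = data_provider.MultiplexerDataProvider(multiplexer, logdir)

# The result maps run -> tag -> ScalarDatum (the last point of each series).
last = provider.read_last_scalars(
    context.RequestContext(),
    experiment_id="unused",
    plugin_name=scalar_metadata.PLUGIN_NAME,
)
for run, tags in last.items():
    for tag, datum in tags.items():
        print(f"{run}/{tag}: step={datum.step} value={datum.value}")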
36 changes: 36 additions & 0 deletions tensorboard/backend/event_processing/data_provider_test.py
@@ -328,6 +328,42 @@ def test_read_scalars_but_not_rank_0(self):
                downsample=100,
            )

    def test_read_last_scalars(self):
        multiplexer = self.create_multiplexer()
        provider = data_provider.MultiplexerDataProvider(
            multiplexer, self.logdir
        )

        run_tag_filter = base_provider.RunTagFilter(
            runs=["waves", "polynomials", "unicorns"],
            tags=["sine", "square", "cube", "iridescence"],
        )
        result = provider.read_last_scalars(
            self.ctx,
            experiment_id="unused",
            plugin_name=scalar_metadata.PLUGIN_NAME,
            run_tag_filter=run_tag_filter,
        )

        self.assertCountEqual(result.keys(), ["polynomials", "waves"])
        self.assertCountEqual(result["polynomials"].keys(), ["square", "cube"])
        self.assertCountEqual(result["waves"].keys(), ["square", "sine"])
        for run in result:
            for tag in result[run]:
                events = multiplexer.Tensors(run, tag)
                if events:
                    last_event = events[-1]
                    datum = result[run][tag]
                    self.assertIsInstance(datum, base_provider.ScalarDatum)
                    self.assertEqual(datum.step, last_event.step)
                    self.assertEqual(datum.wall_time, last_event.wall_time)
                    self.assertEqual(
                        datum.value,
                        tensor_util.make_ndarray(
                            last_event.tensor_proto
                        ).item(),
                    )

    def test_list_tensors_all(self):
        provider = self.create_provider()
        result = provider.list_tensors(
39 changes: 39 additions & 0 deletions tensorboard/data/grpc_provider.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""A data provider that talks to a gRPC server."""

import collections
import contextlib

import grpc
@@ -148,6 +149,44 @@ def read_scalars(
                        series.append(point)
        return result

    @timing.log_latency
    def read_last_scalars(
        self,
        ctx,
        *,
        experiment_id,
        plugin_name,
        run_tag_filter=None,
    ):
        with timing.log_latency("build request"):
            req = data_provider_pb2.ReadScalarsRequest()
            req.experiment_id = experiment_id
            req.plugin_filter.plugin_name = plugin_name
            _populate_rtf(run_tag_filter, req.run_tag_filter)
            # `ReadScalars` always includes the most recent datum, so
            # downsampling to one point means fetching only the latest value.
            req.downsample.num_points = 1
        with timing.log_latency("_stub.ReadScalars"):
            with _translate_grpc_error():
                res = self._stub.ReadScalars(req)
        with timing.log_latency("build result"):
            result = collections.defaultdict(dict)
            for run_entry in res.runs:
                run_name = run_entry.run_name
                for tag_entry in run_entry.tags:
                    d = tag_entry.data
                    # There should be at most one datum in `tag_entry.data`,
                    # since downsample was set to 1.
                    for (step, wt, value) in zip(d.step, d.wall_time, d.value):
                        result[run_name][
                            tag_entry.tag_name
                        ] = provider.ScalarDatum(
                            step=step,
                            wall_time=wt,
                            value=value,
                        )
        return result

    @timing.log_latency
    def list_tensors(
        self, ctx, *, experiment_id, plugin_name, run_tag_filter=None
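To make the downsample-to-one trick concrete, here is a small sketch (not from the PR) of the column-oriented `ScalarData` proto that the "build result" loop consumes. The import path matches the one `grpc_provider.py` uses, and the literal values are invented.

from tensorboard.data.proto import data_provider_pb2

# ScalarData is column-oriented: parallel `step`, `wall_time`, and `value`
# arrays of equal length (invented values here).
d = data_provider_pb2.ScalarData(
    step=[1, 2, 3], wall_time=[10.0, 11.0, 12.0], value=[0.5, 0.25, 0.125]
)
# zip() walks the columns in lockstep. With downsample.num_points = 1, the
# server returns at most one element per column, so the loop body in
# read_last_scalars runs at most once per tag and keeps the latest datum.
for step, wall_time, value in zip(d.step, d.wall_time, d.value):
    print(step, wall_time, value)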
54 changes: 54 additions & 0 deletions tensorboard/data/grpc_provider_test.py
@@ -230,6 +230,60 @@ def test_read_scalars(self):
        req.downsample.num_points = 4
        self.stub.ReadScalars.assert_called_once_with(req)

    def test_read_last_scalars(self):
        tag1 = data_provider_pb2.ReadScalarsResponse.TagEntry(
            tag_name="tag1",
            data=data_provider_pb2.ScalarData(
                step=[10000], wall_time=[1234.0], value=[1]
            ),
        )
        tag2 = data_provider_pb2.ReadScalarsResponse.TagEntry(
            tag_name="tag2",
            data=data_provider_pb2.ScalarData(
                step=[10000], wall_time=[1235.0], value=[0.50]
            ),
        )
        run1 = data_provider_pb2.ReadScalarsResponse.RunEntry(
            run_name="run1", tags=[tag1]
        )
        run2 = data_provider_pb2.ReadScalarsResponse.RunEntry(
            run_name="run2", tags=[tag2]
        )
        res = data_provider_pb2.ReadScalarsResponse(runs=[run1, run2])
        self.stub.ReadScalars.return_value = res

        actual = self.provider.read_last_scalars(
            self.ctx,
            experiment_id="123",
            plugin_name="scalars",
            run_tag_filter=provider.RunTagFilter(
                runs=["train", "test", "nope"]
            ),
        )
        expected = {
            "run1": {
                "tag1": provider.ScalarDatum(
                    step=10000, wall_time=1234.0, value=1
                ),
            },
            "run2": {
                "tag2": provider.ScalarDatum(
                    step=10000, wall_time=1235.0, value=0.50
                ),
            },
        }

        self.assertEqual(actual, expected)

        req = data_provider_pb2.ReadScalarsRequest()
        req.experiment_id = "123"
        req.plugin_filter.plugin_name = "scalars"
        req.run_tag_filter.runs.names.extend(
            ["nope", "test", "train"]
        )  # sorted
        req.downsample.num_points = 1
        self.stub.ReadScalars.assert_called_once_with(req)

    def test_list_tensors(self):
        res = data_provider_pb2.ListTensorsResponse()
        run1 = res.runs.add(run_name="val")
34 changes: 34 additions & 0 deletions tensorboard/data/provider.py
@@ -248,6 +248,40 @@ def read_scalars(
"""
pass

@abc.abstractmethod
def read_last_scalars(
self,
ctx=None,
*,
experiment_id,
plugin_name,
[Review thread on the `plugin_name` parameter]

Member: Since we're requiring `plugin_name` to be specified, I would think having the method be `read_latest_values`, or something like it, would be more generic and flexible; we could say in the docstring that not all plugins are supported by all implementations. The return value does get a bit more difficult to specify and deal with, though, so if you prefer the current form for that reason, I'm OK with it.

Member Author: I specified `plugin_name` because `read_scalars` also requires it. I prefer the simpler definition for just scalars; in that case, do you think we can just remove the `plugin_name` arg?

Member Author: I decided to remove `plugin_name` completely. I think it's unlikely that we'll have a new plugin type that contains scalar data in the near future, and since scalars seem to be the most used type, we don't have to generalize to `read_last_values` just yet. WDYT?

Member: Sorry for the late response; I thought I had replied to this. Technically, the plugin type is somewhat independent of the data type: e.g., there's a `custom_scalars` plugin in our examples, I believe, which also uses scalar values. For this reason, and to be consistent with the rest of the data provider interface, I'd lean a bit towards keeping `plugin_name`... but this is OK too. I agree it's unlikely we'll ever need anything other than "scalars", and if we do, we can change it later (but again, unlikely).

Member Author: Great point, added it back.

Member Author: Also, I didn't realize I needed to implement this method in `event_processing/data_provider` as well. PTAL, thanks!

        run_tag_filter=None,
    ):
        """Read the most recent values from scalar time series.

        Args:
          ctx: A TensorBoard `RequestContext` value.
          experiment_id: ID of enclosing experiment.
          plugin_name: String name of the TensorBoard plugin that created
            the data to be queried. Required.
          run_tag_filter: Optional `RunTagFilter` value. If provided, a datum
            series will only be included in the result if its run and tag
            both pass this filter. If `None`, all time series will be
            included.

            The result will only contain keys for run-tag combinations that
            actually exist, which may not include all entries in the
            `run_tag_filter`.

        Returns:
          A nested map `d` such that `d[run][tag]` is a `ScalarDatum`
          representing the latest scalar in the time series.

        Raises:
          tensorboard.errors.PublicError: See `DataProvider` class docstring.
        """
        pass

    def list_tensors(
        self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
    ):
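Since the abstract method leaves the implementation strategy open, a third-party provider could plausibly derive `read_last_scalars` from its existing `read_scalars`, mirroring the gRPC provider's downsample-to-one approach above. The following sketch is an assumption, not part of the PR, and it only works if the provider's downsampling always retains the most recent datum, as the gRPC service's does.

    def read_last_scalars(
        self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
    ):
        # Assumption: this provider's read_scalars keeps the most recent
        # datum when downsampling, as the gRPC service does.
        series = self.read_scalars(
            ctx,
            experiment_id=experiment_id,
            plugin_name=plugin_name,
            run_tag_filter=run_tag_filter,
            downsample=1,
        )
        # Collapse the nested map of singleton lists into the documented
        # run -> tag -> ScalarDatum shape.
        result = {}
        for run, tags in series.items():
            for tag, data in tags.items():
                if data:
                    result.setdefault(run, {})[tag] = data[-1]
        return result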