
| Current Path : /usr/local/lib/python3.8/dist-packages/tensorboard/data/experimental/ |
Linux ift1.ift-informatik.de 5.4.0-216-generic #236-Ubuntu SMP Fri Apr 11 19:53:21 UTC 2025 x86_64 |
| Current File : //usr/local/lib/python3.8/dist-packages/tensorboard/data/experimental/experiment_from_dev.py |
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experiment Data Access API for tensorboard.dev."""
import sys
import time
import grpc
import numpy as np
from tensorboard.data.experimental import base_experiment
from tensorboard.uploader import auth
from tensorboard.uploader import util
from tensorboard.uploader import server_info as server_info_lib
from tensorboard.uploader.proto import export_service_pb2
from tensorboard.uploader.proto import export_service_pb2_grpc
from tensorboard.uploader.proto import server_info_pb2
from tensorboard.util import grpc_util
DEFAULT_ORIGIN = "https://tensorboard.dev"
def import_pandas():
"""Import pandas, guarded by a user-friendly error message on failure."""
try:
import pandas
except ImportError:
raise ImportError(
"The get_scalars() feature requires the pandas package, "
"which does not seem to be available in your Python "
"environment. You can install it with command:\n\n"
" pip install pandas\n"
)
return pandas
class ExperimentFromDev(base_experiment.BaseExperiment):
"""Implementation of BaseExperiment, specialized for tensorboard.dev."""
def __init__(self, experiment_id, api_endpoint=None):
"""Constructor of ExperimentFromDev.
Args:
experiment_id: String ID of the experiment on tensorboard.dev (e.g.,
"AdYd1TgeTlaLWXx6I8JUbA").
api_endpoint: Optional override value for API endpoint. Used for
development only.
"""
super(ExperimentFromDev, self).__init__()
self._experiment_id = experiment_id
self._api_client = get_api_client(api_endpoint=api_endpoint)
def get_scalars(
self,
runs_filter=None,
tags_filter=None,
pivot=False,
include_wall_time=False,
):
# NOTE(#3650): Import pandas early in this method, so if the
# Python environment does not have pandas installed, an error can be
# raised early, before any rpc call is made.
pandas = import_pandas()
if runs_filter is not None:
raise NotImplementedError(
"runs_filter support for get_scalars() is not implemented yet."
)
if tags_filter is not None:
raise NotImplementedError(
"tags_filter support for get_scalars() is not implemented yet."
)
request = export_service_pb2.StreamExperimentDataRequest()
request.experiment_id = self._experiment_id
read_time = time.time()
util.set_timestamp(request.read_timestamp, read_time)
# TODO(cais, wchargin): Use another rpc to check for staleness and avoid
# a new StreamExperimentData rpc request if data is not stale.
stream = self._api_client.StreamExperimentData(
request, metadata=grpc_util.version_metadata()
)
runs = []
tags = []
steps = []
wall_times = []
values = []
for response in stream:
# TODO(cais, wchargin): Display progress bar during data loading.
num_values = len(response.points.values)
runs.extend([response.run_name] * num_values)
tags.extend([response.tag_name] * num_values)
steps.extend(list(response.points.steps))
wall_times.extend(
[t.ToNanoseconds() / 1e9 for t in response.points.wall_times]
)
values.extend(list(response.points.values))
data = {
"run": runs,
"tag": tags,
"step": steps,
"value": values,
}
if include_wall_time:
data["wall_time"] = wall_times
dataframe = pandas.DataFrame(data)
if pivot:
dataframe = self._pivot_dataframe(dataframe)
return dataframe
def _pivot_dataframe(self, dataframe):
num_missing_0 = np.count_nonzero(dataframe.isnull().values)
dataframe = dataframe.pivot_table(
values=(
["value", "wall_time"]
if "wall_time" in dataframe.columns
else "value"
),
index=["run", "step"],
columns="tag",
dropna=False,
)
num_missing_1 = np.count_nonzero(dataframe.isnull().values)
if num_missing_1 > num_missing_0:
raise ValueError(
"pivoted DataFrame contains missing value(s). "
"This is likely due to two timeseries having different "
"sets of steps in your experiment. "
"You can avoid this error by calling `get_scalars()` with "
"`pivot=False` to disable the DataFrame pivoting."
)
# `reset_index()` removes the MultiIndex structure of the pivoted
# DataFrame. Before the call, the DataFrame consits of two levels
# of index: "run" and "step". After the call, the index become a
# single range index (e.g,. `dataframe[:2]` works).
dataframe = dataframe.reset_index()
# Remove the columns name "tag".
dataframe.columns.name = None
dataframe.columns.names = [None for name in dataframe.columns.names]
return dataframe
def get_api_client(api_endpoint=None):
server_info = _get_server_info(api_endpoint=api_endpoint)
_handle_server_info(server_info)
channel_creds = grpc.ssl_channel_credentials()
credentials = auth.CredentialsStore().read_credentials()
if credentials:
channel_creds = grpc.composite_channel_credentials(
channel_creds, auth.id_token_call_credentials(credentials)
)
channel = grpc.secure_channel(
server_info.api_server.endpoint, channel_creds
)
return export_service_pb2_grpc.TensorBoardExporterServiceStub(channel)
def _get_server_info(api_endpoint=None):
# TODO(cais): Add more plugins to the list when more plugin/data types
# are supported
plugins = ["scalars"]
if api_endpoint:
return server_info_lib.create_server_info(
DEFAULT_ORIGIN, api_endpoint, plugins
)
return server_info_lib.fetch_server_info(DEFAULT_ORIGIN, plugins)
def _handle_server_info(info):
compat = info.compatibility
if compat.verdict == server_info_pb2.VERDICT_WARN:
sys.stderr.write("Warning [from server]: %s\n" % compat.details)
sys.stderr.flush()
elif compat.verdict == server_info_pb2.VERDICT_ERROR:
raise ValueError("Error [from server]: %s" % compat.details)