# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A data provider that talks to a gRPC server."""

import contextlib

import grpc

from tensorboard.util import tensor_util
from tensorboard.util import timing
from tensorboard import errors
from tensorboard.data import provider
from tensorboard.data.proto import data_provider_pb2
from tensorboard.data.proto import data_provider_pb2_grpc


def make_stub(channel):
    """Wraps a gRPC channel with a service stub."""
    return data_provider_pb2_grpc.TensorBoardDataProviderStub(channel)


class GrpcDataProvider(provider.DataProvider):
    """Data provider that talks over gRPC."""

    def __init__(self, addr, stub):
        """Initializes a GrpcDataProvider.

        Args:
          addr: String address of the remote peer. Used cosmetically for
            data location.
          stub: `data_provider_pb2_grpc.TensorBoardDataProviderStub`
            value. See `make_stub` to construct one from a channel.
        """
        self._addr = addr
        self._stub = stub

    def __str__(self):
        return "GrpcDataProvider(addr=%r)" % self._addr

    def experiment_metadata(self, ctx, *, experiment_id):
        req = data_provider_pb2.GetExperimentRequest()
        req.experiment_id = experiment_id
        with _translate_grpc_error():
            res = self._stub.GetExperiment(req)
        res = provider.ExperimentMetadata(
            data_location=res.data_location,
            experiment_name=res.name,
            experiment_description=res.description,
            creation_time=_timestamp_proto_to_float(res.creation_time),
        )
        return res

    def list_plugins(self, ctx, *, experiment_id):
        req = data_provider_pb2.ListPluginsRequest()
        req.experiment_id = experiment_id
        with _translate_grpc_error():
            res = self._stub.ListPlugins(req)
        return [p.name for p in res.plugins]

    def list_runs(self, ctx, *, experiment_id):
        req = data_provider_pb2.ListRunsRequest()
        req.experiment_id = experiment_id
        with _translate_grpc_error():
            res = self._stub.ListRuns(req)
        return [
            provider.Run(
                run_id=run.name,
                run_name=run.name,
                start_time=run.start_time,
            )
            for run in res.runs
        ]

    @timing.log_latency
    def list_scalars(
        self, ctx, *, experiment_id, plugin_name, run_tag_filter=None
    ):
        with timing.log_latency("build request"):
            req = data_provider_pb2.ListScalarsRequest()
            req.experiment_id = experiment_id
            req.plugin_filter.plugin_name = plugin_name
            _populate_rtf(run_tag_filter, req.run_tag_filter)
        with timing.log_latency("_stub.ListScalars"):
            with _translate_grpc_error():
                res = self._stub.ListScalars(req)
        with timing.log_latency("build result"):
            result = {}
            for run_entry in res.runs:
                tags = {}
                result[run_entry.run_name] = tags
                for tag_entry in run_entry.tags:
                    time_series = tag_entry.metadata
                    tags[tag_entry.tag_name] = provider.ScalarTimeSeries(
                        max_step=time_series.max_step,
                        max_wall_time=time_series.max_wall_time,
                        plugin_content=time_series.summary_metadata.plugin_data.content,
                        description=time_series.summary_metadata.summary_description,
                        display_name=time_series.summary_metadata.display_name,
                    )
            return result

    @timing.log_latency
    def read_scalars(
        self,
        ctx,
        *,
        experiment_id,
        plugin_name,
        downsample=None,
        run_tag_filter=None,
    ):
        with timing.log_latency("build request"):
            req = data_provider_pb2.ReadScalarsRequest()
            req.experiment_id = experiment_id
            req.plugin_filter.plugin_name = plugin_name
            _populate_rtf(run_tag_filter, req.run_tag_filter)
            req.downsample.num_points = downsample
        with timing.log_latency("_stub.ReadScalars"):
            with _translate_grpc_error():
                res = self._stub.ReadScalars(req)
        with timing.log_latency("build result"):
            result = {}
            for run_entry in res.runs:
                tags = {}
                result[run_entry.run_name] = tags
                for tag_entry in run_entry.tags:
                    series = []
                    tags[tag_entry.tag_name] = series
                    d = tag_entry.data
                    for (step, wt, value) in zip(d.step, d.wall_time, d.value):
                        point = provider.ScalarDatum(
                            step=step,
                            wall_time=wt,
                            value=value,
                        )
                        series.append(point)
            return result

    @timing.log_latency
    def list_tensors(
        self, ctx, *, experiment_id, plugin_name, run_tag_filter=None
    ):
        with timing.log_latency("build request"):
            req = data_provider_pb2.ListTensorsRequest()
            req.experiment_id = experiment_id
            req.plugin_filter.plugin_name = plugin_name
            _populate_rtf(run_tag_filter, req.run_tag_filter)
        with timing.log_latency("_stub.ListTensors"):
            with _translate_grpc_error():
                res = self._stub.ListTensors(req)
        with timing.log_latency("build result"):
            result = {}
            for run_entry in res.runs:
                tags = {}
                result[run_entry.run_name] = tags
                for tag_entry in run_entry.tags:
                    time_series = tag_entry.metadata
                    tags[tag_entry.tag_name] = provider.TensorTimeSeries(
                        max_step=time_series.max_step,
                        max_wall_time=time_series.max_wall_time,
                        plugin_content=time_series.summary_metadata.plugin_data.content,
                        description=time_series.summary_metadata.summary_description,
                        display_name=time_series.summary_metadata.display_name,
                    )
            return result

    @timing.log_latency
    def read_tensors(
        self,
        ctx,
        *,
        experiment_id,
        plugin_name,
        downsample=None,
        run_tag_filter=None,
    ):
        with timing.log_latency("build request"):
            req = data_provider_pb2.ReadTensorsRequest()
            req.experiment_id = experiment_id
            req.plugin_filter.plugin_name = plugin_name
            _populate_rtf(run_tag_filter, req.run_tag_filter)
            req.downsample.num_points = downsample
        with timing.log_latency("_stub.ReadTensors"):
            with _translate_grpc_error():
                res = self._stub.ReadTensors(req)
        with timing.log_latency("build result"):
            result = {}
            for run_entry in res.runs:
                tags = {}
                result[run_entry.run_name] = tags
                for tag_entry in run_entry.tags:
                    series = []
                    tags[tag_entry.tag_name] = series
                    d = tag_entry.data
                    for (step, wt, value) in zip(d.step, d.wall_time, d.value):
                        point = provider.TensorDatum(
                            step=step,
                            wall_time=wt,
                            numpy=tensor_util.make_ndarray(value),
                        )
                        series.append(point)
            return result

    @timing.log_latency
    def list_blob_sequences(
        self, ctx, experiment_id, plugin_name, run_tag_filter=None
    ):
        with timing.log_latency("build request"):
            req = data_provider_pb2.ListBlobSequencesRequest()
            req.experiment_id = experiment_id
            req.plugin_filter.plugin_name = plugin_name
            _populate_rtf(run_tag_filter, req.run_tag_filter)
        with timing.log_latency("_stub.ListBlobSequences"):
            with _translate_grpc_error():
                res = self._stub.ListBlobSequences(req)
        with timing.log_latency("build result"):
            result = {}
            for run_entry in res.runs:
                tags = {}
                result[run_entry.run_name] = tags
                for tag_entry in run_entry.tags:
                    time_series = tag_entry.metadata
                    tags[tag_entry.tag_name] = provider.BlobSequenceTimeSeries(
                        max_step=time_series.max_step,
                        max_wall_time=time_series.max_wall_time,
                        max_length=time_series.max_length,
                        plugin_content=time_series.summary_metadata.plugin_data.content,
                        description=time_series.summary_metadata.summary_description,
                        display_name=time_series.summary_metadata.display_name,
                    )
            return result

    @timing.log_latency
    def read_blob_sequences(
        self,
        ctx,
        experiment_id,
        plugin_name,
        downsample=None,
        run_tag_filter=None,
    ):
        with timing.log_latency("build request"):
            req = data_provider_pb2.ReadBlobSequencesRequest()
            req.experiment_id = experiment_id
            req.plugin_filter.plugin_name = plugin_name
            _populate_rtf(run_tag_filter, req.run_tag_filter)
            req.downsample.num_points = downsample
        with timing.log_latency("_stub.ReadBlobSequences"):
            with _translate_grpc_error():
                res = self._stub.ReadBlobSequences(req)
        with timing.log_latency("build result"):
            result = {}
            for run_entry in res.runs:
                tags = {}
                result[run_entry.run_name] = tags
                for tag_entry in run_entry.tags:
                    series = []
                    tags[tag_entry.tag_name] = series
                    d = tag_entry.data
                    for (step, wt, blob_sequence) in zip(
                        d.step, d.wall_time, d.values
                    ):
                        values = []
                        for ref in blob_sequence.blob_refs:
                            values.append(
                                provider.BlobReference(
                                    blob_key=ref.blob_key, url=ref.url or None
                                )
                            )
                        point = provider.BlobSequenceDatum(
                            step=step, wall_time=wt, values=tuple(values)
                        )
                        series.append(point)
            return result

    @timing.log_latency
    def read_blob(self, ctx, blob_key):
        with timing.log_latency("build request"):
            req = data_provider_pb2.ReadBlobRequest()
            req.blob_key = blob_key
        with timing.log_latency("list(_stub.ReadBlob)"):
            with _translate_grpc_error():
                responses = list(self._stub.ReadBlob(req))
        with timing.log_latency("build result"):
            return b"".join(res.data for res in responses)


@contextlib.contextmanager
def _translate_grpc_error():
    try:
        yield
    except grpc.RpcError as e:
        if e.code() == grpc.StatusCode.INVALID_ARGUMENT:
            raise errors.InvalidArgumentError(e.details())
        if e.code() == grpc.StatusCode.NOT_FOUND:
            raise errors.NotFoundError(e.details())
        if e.code() == grpc.StatusCode.PERMISSION_DENIED:
            raise errors.PermissionDeniedError(e.details())
        raise


def _populate_rtf(run_tag_filter, rtf_proto):
    """Copies `run_tag_filter` into `rtf_proto`."""
    if run_tag_filter is None:
        return
    if run_tag_filter.runs is not None:
        rtf_proto.runs.names[:] = sorted(run_tag_filter.runs)
    if run_tag_filter.tags is not None:
        rtf_proto.tags.names[:] = sorted(run_tag_filter.tags)


def _timestamp_proto_to_float(ts):
    """Converts `timestamp_pb2.Timestamp` to float seconds since epoch."""
    return ts.ToNanoseconds() / 1e9
