# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experiment Data Access API for tensorboard.dev."""


import sys
import time

import grpc

from tensorboard.data.experimental import base_experiment
from tensorboard.data.experimental import utils as experimental_utils
from tensorboard.uploader import auth
from tensorboard.uploader import util
from tensorboard.uploader import server_info as server_info_lib
from tensorboard.uploader.proto import export_service_pb2
from tensorboard.uploader.proto import export_service_pb2_grpc
from tensorboard.uploader.proto import server_info_pb2
from tensorboard.util import grpc_util


DEFAULT_ORIGIN = "https://tensorboard.dev"


def import_pandas():
    """Import pandas, guarded by a user-friendly error message on failure."""
    try:
        import pandas
    except ImportError:
        raise ImportError(
            "The get_scalars() feature requires the pandas package, "
            "which does not seem to be available in your Python "
            "environment. You can install it with command:\n\n"
            "  pip install pandas\n"
        )
    return pandas


class ExperimentFromDev(base_experiment.BaseExperiment):
    """Implementation of BaseExperiment, specialized for tensorboard.dev."""

    def __init__(self, experiment_id, api_endpoint=None):
        """Constructor of ExperimentFromDev.

        Args:
          experiment_id: String ID of the experiment on tensorboard.dev (e.g.,
            "AdYd1TgeTlaLWXx6I8JUbA").
          api_endpoint: Optional override value for API endpoint. Used for
            development only.
        """
        super(ExperimentFromDev, self).__init__()
        self._experiment_id = experiment_id
        self._api_client = get_api_client(api_endpoint=api_endpoint)

    def get_scalars(
        self,
        runs_filter=None,
        tags_filter=None,
        pivot=False,
        include_wall_time=False,
    ):
        # NOTE(#3650): Import pandas early in this method, so if the
        # Python environment does not have pandas installed, an error can be
        # raised early, before any rpc call is made.
        pandas = import_pandas()
        if runs_filter is not None:
            raise NotImplementedError(
                "runs_filter support for get_scalars() is not implemented yet."
            )
        if tags_filter is not None:
            raise NotImplementedError(
                "tags_filter support for get_scalars() is not implemented yet."
            )

        request = export_service_pb2.StreamExperimentDataRequest()
        request.experiment_id = self._experiment_id
        read_time = time.time()
        util.set_timestamp(request.read_timestamp, read_time)
        # TODO(cais, wchargin): Use another rpc to check for staleness and avoid
        # a new StreamExperimentData rpc request if data is not stale.
        stream = self._api_client.StreamExperimentData(
            request, metadata=grpc_util.version_metadata()
        )

        runs = []
        tags = []
        steps = []
        wall_times = []
        values = []
        for response in stream:
            # TODO(cais, wchargin): Display progress bar during data loading.
            num_values = len(response.points.values)
            runs.extend([response.run_name] * num_values)
            tags.extend([response.tag_name] * num_values)
            steps.extend(list(response.points.steps))
            wall_times.extend(
                [t.ToNanoseconds() / 1e9 for t in response.points.wall_times]
            )
            values.extend(list(response.points.values))

        data = {
            "run": runs,
            "tag": tags,
            "step": steps,
            "value": values,
        }
        if include_wall_time:
            data["wall_time"] = wall_times
        dataframe = pandas.DataFrame(data)
        if pivot:
            dataframe = experimental_utils.pivot_dataframe(dataframe)
        return dataframe


def get_api_client(api_endpoint=None):
    server_info = _get_server_info(api_endpoint=api_endpoint)
    _handle_server_info(server_info)
    channel_creds = grpc.ssl_channel_credentials()
    credentials = auth.CredentialsStore().read_credentials()
    if credentials:
        channel_creds = grpc.composite_channel_credentials(
            channel_creds, auth.id_token_call_credentials(credentials)
        )
    channel = grpc.secure_channel(
        server_info.api_server.endpoint, channel_creds
    )
    return export_service_pb2_grpc.TensorBoardExporterServiceStub(channel)


def _get_server_info(api_endpoint=None):
    # TODO(cais): Add more plugins to the list when more plugin/data types
    # are supported
    plugins = ["scalars"]
    if api_endpoint:
        return server_info_lib.create_server_info(
            DEFAULT_ORIGIN, api_endpoint, plugins
        )
    return server_info_lib.fetch_server_info(DEFAULT_ORIGIN, plugins)


def _handle_server_info(info):
    compat = info.compatibility
    if compat.verdict == server_info_pb2.VERDICT_WARN:
        sys.stderr.write("Warning [from server]: %s\n" % compat.details)
        sys.stderr.flush()
    elif compat.verdict == server_info_pb2.VERDICT_ERROR:
        raise ValueError("Error [from server]: %s" % compat.details)
