# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Progress tracker for uploader."""


import contextlib
from datetime import datetime
import sys
import time


def readable_time_string():
    """Get a human-readable time string for the present."""
    return datetime.now().strftime("%Y-%m-%dT%H:%M:%S")


def readable_bytes_string(bytes):
    """Get a human-readable string for number of bytes."""
    if bytes >= 2 ** 20:
        return "%.1f MB" % (float(bytes) / 2 ** 20)
    elif bytes >= 2 ** 10:
        return "%.1f kB" % (float(bytes) / 2 ** 10)
    else:
        return "%d B" % bytes


class UploadStats(object):
    """Statistics of uploading."""

    def __init__(self):
        self._last_summarized_timestamp = time.time()
        self._last_data_added_timestamp = 0
        self._num_scalars = 0
        self._num_tensors = 0
        self._num_tensors_skipped = 0
        self._tensor_bytes = 0
        self._tensor_bytes_skipped = 0
        self._num_blobs = 0
        self._num_blobs_skipped = 0
        self._blob_bytes = 0
        self._blob_bytes_skipped = 0
        self._plugin_names = set()

    def add_scalars(self, num_scalars):
        """Add a batch of scalars.

        Args:
          num_scalars: Number of scalars uploaded in this batch.
        """
        self._refresh_last_data_added_timestamp()
        self._num_scalars += num_scalars

    def add_tensors(
        self,
        num_tensors,
        num_tensors_skipped,
        tensor_bytes,
        tensor_bytes_skipped,
    ):
        """Add a batch of tensors.

        Args:
          num_tensors: Number of tensors encountered in this batch, including
            the ones skipped due to reasons such as large exceeding limit.
          num_tensors: Number of tensors skipped. This describes a subset of
            `num_tensors` and hence must be `<= num_tensors`.
          tensor_bytes: Total byte size of tensors encountered in this batch,
            including the skipped ones.
          tensor_bytes_skipped: Total byte size of the tensors skipped due to
            reasons such as size exceeding limit.
        """
        assert num_tensors_skipped <= num_tensors
        assert tensor_bytes_skipped <= tensor_bytes
        self._refresh_last_data_added_timestamp()
        self._num_tensors += num_tensors
        self._num_tensors_skipped += num_tensors_skipped
        self._tensor_bytes += tensor_bytes
        self._tensor_bytes_skipped = tensor_bytes_skipped

    def add_blob(self, blob_bytes, is_skipped):
        """Add a blob.

        Args:
          blob_bytes: Byte size of the blob.
          is_skipped: Whether the uploading of the blob is skipped due to
            reasons such as size exceeding limit.
        """
        self._refresh_last_data_added_timestamp()
        self._num_blobs += 1
        self._blob_bytes += blob_bytes
        if is_skipped:
            self._num_blobs_skipped += 1
            self._blob_bytes_skipped += blob_bytes

    def add_plugin(self, plugin_name):
        """Add a plugin.

        Args:
          plugin_name: Name of the plugin.
        """
        self._refresh_last_data_added_timestamp()
        self._plugin_names.add(plugin_name)

    @property
    def num_scalars(self):
        return self._num_scalars

    @property
    def num_tensors(self):
        return self._num_tensors

    @property
    def num_tensors_skipped(self):
        return self._num_tensors_skipped

    @property
    def tensor_bytes(self):
        return self._tensor_bytes

    @property
    def tensor_bytes_skipped(self):
        return self._tensor_bytes_skipped

    @property
    def num_blobs(self):
        return self._num_blobs

    @property
    def num_blobs_skipped(self):
        return self._num_blobs_skipped

    @property
    def blob_bytes(self):
        return self._blob_bytes

    @property
    def blob_bytes_skipped(self):
        return self._blob_bytes_skipped

    @property
    def plugin_names(self):
        return self._plugin_names

    def has_data(self):
        """Has any data been tracked by this instance.

        This counts the tensor and blob data that have been scanned
        but skipped.

        Returns:
          Whether this stats tracking object has tracked any data.
        """
        return (
            self._num_scalars > 0
            or self._num_tensors > 0
            or self._num_blobs > 0
        )

    def summarize(self):
        """Get a summary string for actually-uploaded and skipped data.

        Calling this property also marks the "last_summarized" timestamp, so that
        the has_new_data_since_last_summarize() will be able to report the correct value
        later.

        Returns:
          A tuple with two items:
          - A string summarizing all data uploaded so far.
          - If any data was skipped, a string for all skipped data. Else, `None`.
        """
        self._last_summarized_timestamp = time.time()
        string_pieces = []
        string_pieces.append("%d scalars" % self._num_scalars)
        uploaded_tensor_count = self._num_tensors - self._num_tensors_skipped
        uploaded_tensor_bytes = self._tensor_bytes - self._tensor_bytes_skipped
        string_pieces.append(
            "0 tensors"
            if not uploaded_tensor_count
            else (
                "%d tensors (%s)"
                % (
                    uploaded_tensor_count,
                    readable_bytes_string(uploaded_tensor_bytes),
                )
            )
        )
        uploaded_blob_count = self._num_blobs - self._num_blobs_skipped
        uploaded_blob_bytes = self._blob_bytes - self._blob_bytes_skipped
        string_pieces.append(
            "0 binary objects"
            if not uploaded_blob_count
            else (
                "%d binary objects (%s)"
                % (
                    uploaded_blob_count,
                    readable_bytes_string(uploaded_blob_bytes),
                )
            )
        )
        skipped_string = (
            self._skipped_summary() if self._skipped_any() else None
        )
        return ", ".join(string_pieces), skipped_string

    def _skipped_any(self):
        """Whether any data was skipped."""
        return self._num_tensors_skipped or self._num_blobs_skipped

    def has_new_data_since_last_summarize(self):
        return self._last_data_added_timestamp > self._last_summarized_timestamp

    def _skipped_summary(self):
        """Get a summary string for skipped data."""
        string_pieces = []
        if self._num_tensors_skipped:
            string_pieces.append(
                "%d tensors (%s)"
                % (
                    self._num_tensors_skipped,
                    readable_bytes_string(self._tensor_bytes_skipped),
                )
            )
        if self._num_blobs_skipped:
            string_pieces.append(
                "%d binary objects (%s)"
                % (
                    self._num_blobs_skipped,
                    readable_bytes_string(self._blob_bytes_skipped),
                )
            )
        return ", ".join(string_pieces)

    def _refresh_last_data_added_timestamp(self):
        self._last_data_added_timestamp = time.time()


_STYLE_RESET = "\033[0m"
_STYLE_BOLD = "\033[1m"
_STYLE_GREEN = "\033[32m"
_STYLE_YELLOW = "\033[33m"
_STYLE_DARKGRAY = "\033[90m"
_STYLE_ERASE_LINE = "\033[2K"


class UploadTracker(object):
    """Tracker for uploader progress and status."""

    _SUPPORTED_VERBISITY_VALUES = (0, 1)

    def __init__(self, verbosity, one_shot=False):
        if verbosity not in self._SUPPORTED_VERBISITY_VALUES:
            raise ValueError(
                "Unsupported verbosity value %s (supported values: %s)"
                % (verbosity, self._SUPPORTED_VERBISITY_VALUES)
            )
        self._verbosity = verbosity
        self._stats = UploadStats()
        self._send_count = 0
        self._one_shot = one_shot

    def _dummy_generator(self):
        while True:
            # Yield an arbitrary value 0: The progress bar is indefinite.
            yield 0

    def _overwrite_line_message(self, message, color_code=_STYLE_GREEN):
        """Overwrite the current line with a stylized message."""
        if not self._verbosity:
            return
        message += "." * 3
        sys.stdout.write(
            _STYLE_ERASE_LINE + color_code + message + _STYLE_RESET + "\r"
        )
        sys.stdout.flush()

    def _single_line_message(self, message):
        """Write a timestamped single line, with newline, to stdout."""
        if not self._verbosity:
            return
        start_message = "%s[%s]%s %s\n" % (
            _STYLE_BOLD,
            readable_time_string(),
            _STYLE_RESET,
            message,
        )
        sys.stdout.write(start_message)
        sys.stdout.flush()

    def has_data(self):
        """Determine if any data has been uploaded under the tracker's watch."""
        return self._stats.has_data()

    def _update_cumulative_status(self):
        """Write an update summarizing the data uploaded since the start."""
        if not self._verbosity:
            return
        if not self._stats.has_new_data_since_last_summarize():
            return
        uploaded_str, skipped_str = self._stats.summarize()
        uploaded_message = "%s[%s]%s Total uploaded: %s\n" % (
            _STYLE_BOLD,
            readable_time_string(),
            _STYLE_RESET,
            uploaded_str,
        )
        sys.stdout.write(uploaded_message)
        if skipped_str:
            sys.stdout.write(
                "%sTotal skipped: %s\n%s"
                % (_STYLE_DARKGRAY, skipped_str, _STYLE_RESET)
            )
        sys.stdout.flush()
        # TODO(cais): Add summary of what plugins have been involved, once it's
        # clear how to get canonical plugin names.

    def add_plugin_name(self, plugin_name):
        self._stats.add_plugin(plugin_name)

    @contextlib.contextmanager
    def send_tracker(self):
        """Create a context manager for a round of data sending."""
        self._send_count += 1
        if self._send_count == 1:
            self._single_line_message("Started scanning logdir.")
        try:
            # self._reset_bars()
            self._overwrite_line_message("Data upload starting")
            yield
        finally:
            self._update_cumulative_status()
            if self._one_shot:
                self._single_line_message("Done scanning logdir.")
            else:
                self._overwrite_line_message(
                    "Listening for new data in logdir",
                    color_code=_STYLE_YELLOW,
                )

    @contextlib.contextmanager
    def scalars_tracker(self, num_scalars):
        """Create a context manager for tracking a scalar batch upload.

        Args:
          num_scalars: Number of scalars in the batch.
        """
        self._overwrite_line_message("Uploading %d scalars" % num_scalars)
        try:
            yield
        finally:
            self._stats.add_scalars(num_scalars)

    @contextlib.contextmanager
    def tensors_tracker(
        self,
        num_tensors,
        num_tensors_skipped,
        tensor_bytes,
        tensor_bytes_skipped,
    ):
        """Create a context manager for tracking a tensor batch upload.

        Args:
          num_tensors: Total number of tensors in the batch.
          num_tensors_skipped: Number of tensors skipped (a subset of
            `num_tensors`). Hence this must be `<= num_tensors`.
          tensor_bytes: Total byte size of the tensors in the batch.
          tensor_bytes_skipped: Byte size of skipped tensors in the batch (a
            subset of `tensor_bytes`). Must be `<= tensor_bytes`.
        """
        if num_tensors_skipped:
            message = "Uploading %d tensors (%s) (Skipping %d tensors, %s)" % (
                num_tensors - num_tensors_skipped,
                readable_bytes_string(tensor_bytes - tensor_bytes_skipped),
                num_tensors_skipped,
                readable_bytes_string(tensor_bytes_skipped),
            )
        else:
            message = "Uploading %d tensors (%s)" % (
                num_tensors,
                readable_bytes_string(tensor_bytes),
            )
        self._overwrite_line_message(message)
        try:
            yield
        finally:
            self._stats.add_tensors(
                num_tensors,
                num_tensors_skipped,
                tensor_bytes,
                tensor_bytes_skipped,
            )

    @contextlib.contextmanager
    def blob_tracker(self, blob_bytes):
        """Creates context manager tracker for uploading a blob.

        Args:
          blob_bytes: Total byte size of the blob being uploaded.
        """
        self._overwrite_line_message(
            "Uploading binary object (%s)" % readable_bytes_string(blob_bytes)
        )
        try:
            yield _BlobTracker(self._stats, blob_bytes)
        finally:
            pass


class _BlobTracker(object):
    def __init__(self, upload_stats, blob_bytes):
        self._upload_stats = upload_stats
        self._blob_bytes = blob_bytes

    def mark_uploaded(self, is_uploaded):
        self._upload_stats.add_blob(
            self._blob_bytes, is_skipped=(not is_uploaded)
        )
