# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""System for specifying garbage collection (GC) of path based data.

This framework allows for GC of data specified by path names, for example files
on disk.  gc.Path objects each represent a single item stored at a path and may
be a base directory,
  /tmp/exports/0/...
  /tmp/exports/1/...
  ...
or a fully qualified file,
  /tmp/train-1.ckpt
  /tmp/train-2.ckpt
  ...

A gc filter function takes and returns a list of gc.Path items.  Filter
functions are responsible for selecting Path items for preservation or deletion.
Note that functions should always return a sorted list.

For example,
  base_dir = "/tmp"
  # Create the directories.
  for e in xrange(10):
    os.mkdir("%s/%d" % (base_dir, e), 0o755)

  # Create a simple parser that pulls the export_version from the directory.
  path_regex = "^" + re.escape(base_dir) + "/(\\d+)$"
  def parser(path):
    match = re.match(path_regex, path.path)
    if not match:
      return None
    return path._replace(export_version=int(match.group(1)))

  path_list = gc._get_paths("/tmp", parser)  # contains all ten Paths

  every_fifth = gc._mod_export_version(5)
  print(every_fifth(path_list))  # shows ["/tmp/0", "/tmp/5"]

  largest_three = gc.largest_export_versions(3)
  print(largest_three(all_paths))  # shows ["/tmp/7", "/tmp/8", "/tmp/9"]

  both = gc._union(every_fifth, largest_three)
  print(both(all_paths))  # shows ["/tmp/0", "/tmp/5",
                          #        "/tmp/7", "/tmp/8", "/tmp/9"]
  # Delete everything not in 'both'.
  to_delete = gc._negation(both)
  for p in to_delete(all_paths):
    gfile.DeleteRecursively(p.path)  # deletes:  "/tmp/1", "/tmp/2",
                                     # "/tmp/3", "/tmp/4", "/tmp/6",
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import heapq
import math
import os
import tensorflow as tf
from tensorflow.python.platform import gfile

Path = collections.namedtuple('Path', 'path export_version')


def _largest_export_versions(n):
  """Creates a filter that keeps the largest n export versions.

  Args:
    n: number of versions to keep.

  Returns:
    A filter function that keeps the n largest paths.
  """

  def keep(paths):
    heap = []
    for idx, path in enumerate(paths):
      if path.export_version is not None:
        heapq.heappush(heap, (path.export_version, idx))
    keepers = [paths[i] for _, i in heapq.nlargest(n, heap)]
    return sorted(keepers)

  return keep


def _one_of_every_n_export_versions(n):
  """Creates a filter that keeps one of every n export versions.

  Args:
    n: interval size.

  Returns:
    A filter function that keeps exactly one path from each interval
    [0, n], (n, 2n], (2n, 3n], etc...  If more than one path exists in an
    interval the largest is kept.
  """

  def keep(paths):
    """A filter function that keeps exactly one out of every n paths."""

    keeper_map = {}  # map from interval to largest path seen in that interval
    for p in paths:
      if p.export_version is None:
        # Skip missing export_versions.
        continue
      # Find the interval (with a special case to map export_version = 0 to
      # interval 0.
      interval = math.floor(
          (p.export_version - 1) / n) if p.export_version else 0
      existing = keeper_map.get(interval, None)
      if (not existing) or (existing.export_version < p.export_version):
        keeper_map[interval] = p
    return sorted(keeper_map.values())

  return keep


def _mod_export_version(n):
  """Creates a filter that keeps every export that is a multiple of n.

  Args:
    n: step size.

  Returns:
    A filter function that keeps paths where export_version % n == 0.
  """

  def keep(paths):
    keepers = []
    for p in paths:
      if p.export_version % n == 0:
        keepers.append(p)
    return sorted(keepers)

  return keep


def _union(lf, rf):
  """Creates a filter that keeps the union of two filters.

  Args:
    lf: first filter
    rf: second filter

  Returns:
    A filter function that keeps the n largest paths.
  """

  def keep(paths):
    l = set(lf(paths))
    r = set(rf(paths))
    return sorted(list(l | r))

  return keep


def _negation(f):
  """Negate a filter.

  Args:
    f: filter function to invert

  Returns:
    A filter function that returns the negation of f.
  """

  def keep(paths):
    l = set(paths)
    r = set(f(paths))
    return sorted(list(l - r))

  return keep


def _get_paths(base_dir, parser):
  """Gets a list of Paths in a given directory.

  Args:
    base_dir: directory.
    parser: a function which gets the raw Path and can augment it with
      information such as the export_version, or ignore the path by returning
      None.  An example parser may extract the export version from a path such
      as "/tmp/exports/100" an another may extract from a full file name such as
      "/tmp/checkpoint-99.out".

  Returns:
    A list of Paths contained in the base directory with the parsing function
    applied.
    By default the following fields are populated,
      - Path.path
    The parsing function is responsible for populating,
      - Path.export_version
  """
  # We are mocking this in the test, hence we should not use public API
  raw_paths = gfile.ListDirectory(base_dir)
  paths = []
  for r in raw_paths:
    # ListDirectory() return paths with "/" at the last if base_dir was GCS URL
    r = tf.compat.as_str_any(r)
    if r[-1] == '/':
      r = r[0:len(r) - 1]
    p = parser(Path(os.path.join(tf.compat.as_str_any(base_dir), r), None))
    if p:
      paths.append(p)
  return sorted(paths)
