# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Beta distribution class."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Beta",
    "BetaWithSoftplusConcentration",
]


_beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in
`[0, 1]`. It must have a shape compatible with `self.batch_shape`."""


@tf_export(v1=["distributions.Beta"])
class Beta(distribution.Distribution):
  """Beta distribution.

  The Beta distribution is defined over the `(0, 1)` interval using parameters
  `concentration1` (aka "alpha") and `concentration0` (aka "beta").

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
  Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
  ```

  where:

  * `concentration1 = alpha`,
  * `concentration0 = beta`,
  * `Z` is the normalization constant, and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The concentration parameters represent the mean total counts of `1` and `0`
  outcomes, i.e.,

  ```none
  concentration1 = alpha = mean * total_concentration
  concentration0 = beta  = (1. - mean) * total_concentration
  ```

  where `mean` is in `(0, 1)` and `total_concentration` is a positive real
  number representing the mean `total_count = concentration1 + concentration0`.
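
  For example (a hypothetical illustration; `tfd` stands for
  `tfp.distributions`, as in the examples below), a mean of `0.25` with
  `total_concentration = 4.` gives `alpha = 1.` and `beta = 3.`:

  ```python
  mean, total_concentration = 0.25, 4.
  alpha = mean * total_concentration        # 1.
  beta = (1. - mean) * total_concentration  # 3.
  dist = tfd.Beta(alpha, beta)              # dist.mean() == 0.25
  ```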

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  Warning: The samples can be zero due to finite precision. This happens more
  often when some of the concentration parameters are very small. Make sure to
  clip zero-valued samples up to `np.finfo(dtype).tiny` before computing the
  density, as in the sketch below.
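
  A minimal sketch of that guard (assuming `import numpy as np`,
  `import tensorflow as tf`, and a `float32` `Beta` instance `dist`):

  ```python
  samples = dist.sample(100)
  tiny = np.finfo(np.float32).tiny
  safe_samples = tf.maximum(samples, tiny)  # Nudge exact zeros up to `tiny`.
  log_prob = dist.log_prob(safe_samples)
  ```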

  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Create a batch of three Beta distributions.
  alpha = [1, 2, 3]
  beta = [1, 2, 3]
  dist = tfd.Beta(alpha, beta)

  dist.sample([4, 5])  # Shape [4, 5, 3]

  # `x` contains two samples for each of the three batch entries.
  x = [[.1, .4, .5],
       [.2, .3, .5]]
  # Calculate the probability of each pair of samples under the corresponding
  # distribution in `dist`.
  dist.prob(x)         # Shape [2, 3]
  ```

  ```python
  # Create batch_shape=[2, 3] via parameter broadcast:
  alpha = [[1.], [2]]      # Shape [2, 1]
  beta = [3., 4, 5]        # Shape [3]
  dist = tfd.Beta(alpha, beta)

  # alpha broadcast as: [[1., 1, 1],
  #                      [2, 2, 2]]
  # beta broadcast as:  [[3., 4, 5],
  #                      [3, 4, 5]]
  # batch_shape [2, 3]
  dist.sample([4, 5])  # Shape [4, 5, 2, 3]

  x = [.2, .3, .5]
  # x will be broadcast as [[.2, .3, .5],
  #                         [.2, .3, .5]],
  # thus matching batch_shape [2, 3].
  dist.prob(x)         # Shape [2, 3]
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  alpha = tf.constant(1.0)
  beta = tf.constant(2.0)
  dist = tfd.Beta(alpha, beta)
  samples = dist.sample(5)  # Shape [5]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, [alpha, beta])
  ```

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               concentration1=None,
               concentration0=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Beta"):
    """Initialize a batch of Beta distributions.

    Args:
      concentration1: Positive floating-point `Tensor` indicating mean
        number of successes; aka "alpha". Implies `self.dtype` and
        `self.batch_shape`, i.e.,
        `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
      concentration0: Positive floating-point `Tensor` indicating mean
        number of failures; aka "beta". Otherwise has same semantics as
        `concentration1`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration1, concentration0]) as name:
      self._concentration1 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration1, name="concentration1"),
          validate_args)
      self._concentration0 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration0, name="concentration0"),
          validate_args)
      check_ops.assert_same_float_dtype([
          self._concentration1, self._concentration0])
      self._total_concentration = self._concentration1 + self._concentration0
    super(Beta, self).__init__(
        dtype=self._total_concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration1,
                       self._concentration0,
                       self._total_concentration],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    return dict(zip(
        ["concentration1", "concentration0"],
        [ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2))

  @property
  def concentration1(self):
    """Concentration parameter associated with a `1` outcome."""
    return self._concentration1

  @property
  def concentration0(self):
    """Concentration parameter associated with a `0` outcome."""
    return self._concentration0

  @property
  def total_concentration(self):
    """Sum of concentration parameters."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
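    # Sample via the Gamma-ratio construction: if G1 ~ Gamma(concentration1, 1)
    # and G2 ~ Gamma(concentration0, 1) are independent, then
    # G1 / (G1 + G2) ~ Beta(concentration1, concentration0).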
    expanded_concentration1 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration1
    expanded_concentration0 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration0
    gamma1_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration1,
        dtype=self.dtype,
        seed=seed)
    gamma2_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration0,
        dtype=self.dtype,
        seed=distribution_util.gen_new_seed(seed, "beta"))
    beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
    return beta_sample

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_cdf(self, x):
    return math_ops.log(self._cdf(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _cdf(self, x):
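    # The Beta CDF is the regularized incomplete beta function I_x(a, b).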
    return math_ops.betainc(self.concentration1, self.concentration0, x)

  def _log_unnormalized_prob(self, x):
    x = self._maybe_assert_valid_sample(x)
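    # (concentration1 - 1) * log(x) + (concentration0 - 1) * log(1 - x);
    # xlogy and log1p keep the computation accurate near x == 0.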
    return (math_ops.xlogy(self.concentration1 - 1., x) +
            (self.concentration0 - 1.) * math_ops.log1p(-x))  # pylint: disable=invalid-unary-operand-type

  def _log_normalization(self):
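    # log Z = lgamma(a) + lgamma(b) - lgamma(a + b), the log of the beta
    # function B(a, b).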
    return (math_ops.lgamma(self.concentration1)
            + math_ops.lgamma(self.concentration0)
            - math_ops.lgamma(self.total_concentration))

  def _entropy(self):
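    # H = log B(a, b) - (a - 1) digamma(a) - (b - 1) digamma(b)
    #     + (a + b - 2) digamma(a + b).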
    return (
        self._log_normalization()
        - (self.concentration1 - 1.) * math_ops.digamma(self.concentration1)
        - (self.concentration0 - 1.) * math_ops.digamma(self.concentration0)
        + ((self.total_concentration - 2.) *
           math_ops.digamma(self.total_concentration)))

  def _mean(self):
    return self._concentration1 / self._total_concentration

  def _variance(self):
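    # Var = mean * (1 - mean) / (1 + total_concentration), equivalently
    # a * b / ((a + b)**2 * (a + b + 1)).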
    return self._mean() * (1. - self._mean()) / (1. + self.total_concentration)

  @distribution_util.AppendDocstring(
      """Note: The mode is undefined when `concentration1 <= 1` or
      `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN`
      is used for undefined modes. If `self.allow_nan_stats` is `False` an
      exception is raised when one or more modes are undefined.""")
  def _mode(self):
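    # Mode = (a - 1) / (a + b - 2), defined only when a > 1 and b > 1.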
    mode = (self.concentration1 - 1.) / (self.total_concentration - 2.)
    if self.allow_nan_stats:
      nan = array_ops.fill(
          self.batch_shape_tensor(),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
          name="nan")
      is_defined = math_ops.logical_and(self.concentration1 > 1.,
                                        self.concentration0 > 1.)
      return array_ops.where_v2(is_defined, mode, nan)
    return control_flow_ops.with_dependencies([
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration1,
            message="Mode undefined for concentration1 <= 1."),
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration0,
            message="Mode undefined for concentration0 <= 1.")
    ], mode)

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of a concentration parameter."""
    if not validate_args:
      return concentration
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)

  def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample."""
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x, message="sample must be positive"),
        check_ops.assert_less(
            x,
            array_ops.ones([], self.dtype),
            message="sample must be less than `1`."),
    ], x)


class BetaWithSoftplusConcentration(Beta):
  """Beta with softplus transform of `concentration1` and `concentration0`."""

  @deprecation.deprecated(
      "2019-01-01",
      "Use `tfd.Beta(tf.nn.softplus(concentration1), "
      "tf.nn.softplus(concentration2))` instead.",
      warn_once=True)
  def __init__(self,
               concentration1,
               concentration0,
               validate_args=False,
               allow_nan_stats=True,
               name="BetaWithSoftplusConcentration"):
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration1,
                                      concentration0]) as name:
      super(BetaWithSoftplusConcentration, self).__init__(
          concentration1=nn.softplus(concentration1,
                                     name="softplus_concentration1"),
          concentration0=nn.softplus(concentration0,
                                     name="softplus_concentration0"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters


@kullback_leibler.RegisterKL(Beta, Beta)
def _kl_beta_beta(d1, d2, name=None):
  """Calculate the batchwise KL divergence KL(d1 || d2) with d1 and d2 Beta.

  Args:
    d1: instance of a Beta distribution object.
    d2: instance of a Beta distribution object.
    name: (optional) Name to use for created operations.
      Defaults to "kl_beta_beta".

  Returns:
    Batchwise KL(d1 || d2)
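
  A minimal usage sketch (KL dispatch is via
  `kullback_leibler.kl_divergence`):

  ```python
  d1 = Beta(concentration1=1., concentration0=2.)
  d2 = Beta(concentration1=3., concentration0=4.)
  kl = kullback_leibler.kl_divergence(d1, d2)  # Scalar batch, shape [].
  ```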
  """
  def delta(fn, is_property=True):
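    # Returns fn(d2) - fn(d1); `fn` names a property (default) or a
    # zero-argument method when `is_property=False`.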
    fn1 = getattr(d1, fn)
    fn2 = getattr(d2, fn)
    return (fn2 - fn1) if is_property else (fn2() - fn1())
  with ops.name_scope(name, "kl_beta_beta", values=[
      d1.concentration1,
      d1.concentration0,
      d1.total_concentration,
      d2.concentration1,
      d2.concentration0,
      d2.total_concentration,
  ]):
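    # KL(d1 || d2) = log B(a2, b2) - log B(a1, b1)
    #     - (a2 - a1) digamma(a1) - (b2 - b1) digamma(b1)
    #     + ((a2 + b2) - (a1 + b1)) digamma(a1 + b1),
    # where B is the beta function.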
    return (delta("_log_normalization", is_property=False)
            - math_ops.digamma(d1.concentration1) * delta("concentration1")
            - math_ops.digamma(d1.concentration0) * delta("concentration0")
            + (math_ops.digamma(d1.total_concentration)
               * delta("total_concentration")))
