# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for computing default gradients."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops


def get_zeros_dtype(t):
  """Return the dtype to use for the default gradient of a Tensor.

  Args:
    t: Tensor whose default-gradient dtype is wanted.

  Returns:
    `t.dtype` when `t` is not a resource tensor. For a resource tensor, the
    `dtype` field of the single shape-and-type entry in its handle data.

  Raises:
    ValueError: If `t` is a resource tensor whose handle data is missing,
      not set, or does not hold exactly one shape-and-type entry.
  """
  is_resource = t.dtype == dtypes.resource
  if not is_resource:
    return t.dtype

  handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
  # The handle data is usable only when it exists, is marked as set, and
  # describes exactly one (shape, dtype) pair.
  has_single_entry = (handle_data is not None and handle_data.is_set and
                      len(handle_data.shape_and_type) == 1)
  if not has_single_entry:
    raise ValueError("Internal error: Tried to take gradients (or similar) "
                     "of a variable without handle data:\n%s" % str(t))

  # NOTE(review): the value is returned as stored in the handle data; unlike
  # `shape_and_dtype` in this module, no `dtypes.as_dtype` conversion is
  # applied here -- confirm callers expect the unconverted value.
  return handle_data.shape_and_type[0].dtype


def shape_and_dtype(t):
  """Return the shape and dtype to use for the default gradient of a Tensor.

  Args:
    t: Tensor whose default-gradient shape and dtype are wanted.

  Returns:
    A `(shape, dtype)` tuple. For a non-resource tensor this is
    `(t.shape, t.dtype)`. For a resource tensor both values come from the
    single shape-and-type entry in its handle data, converted with
    `tensor_shape.TensorShape` and `dtypes.as_dtype` respectively.

  Raises:
    ValueError: If `t` is a resource tensor whose handle data is missing,
      not set, or does not hold exactly one shape-and-type entry.
  """
  is_resource = t.dtype == dtypes.resource
  if not is_resource:
    return t.shape, t.dtype

  handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
  # The handle data is usable only when it exists, is marked as set, and
  # describes exactly one (shape, dtype) pair.
  has_single_entry = (handle_data is not None and handle_data.is_set and
                      len(handle_data.shape_and_type) == 1)
  if not has_single_entry:
    raise ValueError("Internal error: Tried to take gradients (or similar) "
                     "of a variable without handle data:\n%s" % str(t))

  entry = handle_data.shape_and_type[0]
  grad_shape = tensor_shape.TensorShape(entry.shape)
  grad_dtype = dtypes.as_dtype(entry.dtype)
  return (grad_shape, grad_dtype)


def zeros_like(t):
  """Like array_ops.zeros_like, but respects resource handles.

  Args:
    t: Tensor to build a zeros tensor for.

  Returns:
    For a resource tensor, `array_ops.zeros` called with the shape and dtype
    reported by `shape_and_dtype(t)`; otherwise `array_ops.zeros_like(t)`.
  """
  if t.dtype == dtypes.resource:
    # Use the shape/dtype recorded for the resource rather than the handle
    # tensor itself.
    shape, dtype = shape_and_dtype(t)
    return array_ops.zeros(shape, dtype)
  return array_ops.zeros_like(t)


def ones_like(t):
  """Like array_ops.ones_like, but respects resource handles.

  Args:
    t: Tensor to build a ones tensor for.

  Returns:
    For a resource tensor, `array_ops.ones` called with the shape and dtype
    reported by `shape_and_dtype(t)`; otherwise `array_ops.ones_like(t)`.
  """
  is_resource = t.dtype == dtypes.resource
  if not is_resource:
    return array_ops.ones_like(t)
  # Use the shape/dtype recorded for the resource rather than the handle
  # tensor itself.
  shape, dtype = shape_and_dtype(t)
  return array_ops.ones(shape, dtype)


def supports_default_grad(t):
  """Whether tensor `t` supports creating a default gradient.

  This function assumes that `t` is of a trainable type.

  A non-resource tensor is always reported as supported. A resource tensor is
  supported only when its handle data exists, is set, and holds exactly one
  shape-and-type entry.

  Args:
    t: Tensor

  Returns:
    Bool
  """
  is_resource = t.dtype == dtypes.resource
  if not is_resource:
    return True

  handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
  if handle_data is None:
    return False
  if not handle_data.is_set:
    return False
  # Exactly one (shape, dtype) entry is required.
  return len(handle_data.shape_and_type) == 1