STT-tensorflow/tensorflow/python/kernel_tests/distributions/util_test.py
Xiao Yu f9529caee1 Fallback tf.function with compile_with_xla==true to current TF. And disable a few function tests that are not supported by tfrt.
PiperOrigin-RevId: 336956915
Change-Id: I550ee4736382ee97f43b3b9aeb74aea639adc328
2020-10-13 14:41:16 -07:00

1082 lines
39 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import importlib
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import util as distribution_util
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
# Short alias for the distributions utility module exercised by these tests.
du = distribution_util
def try_import(name):  # pylint: disable=invalid-name
  """Imports module `name`, logging a warning instead of raising on failure.

  Args:
    name: Module name handed to `importlib.import_module`.

  Returns:
    The imported module, or `None` if an `ImportError` was raised.
  """
  try:
    return importlib.import_module(name)
  except ImportError as import_error:
    tf_logging.warning(
        "Could not import %s: %s" % (name, str(import_error)))
    return None
# Optional dependency: `special` is None when `scipy.special` cannot be
# imported (see `try_import`).
special = try_import("scipy.special")
def _logit(x):
  """Returns `log(x) - log1p(-x)` computed elementwise with NumPy."""
  probs = np.asarray(x)
  log_p = np.log(probs)
  log_one_minus_p = np.log1p(-probs)
  return log_p - log_one_minus_p
class AssertCloseTest(test.TestCase):
  """Graph-mode checks for `du.assert_integer_form`."""

  @test_util.run_deprecated_v1
  def testAssertIntegerForm(self):
    """Integer-valued input passes; each non-integer input raises."""
    # Fed only integer-valued components, so the assertion should pass.
    integral = array_ops.placeholder(dtypes.float32)
    # Fed a first component of 1.1.
    off_by_tenth = array_ops.placeholder(dtypes.float32)
    # Fed a first component of 1.0001, i.e. the difference from an integer
    # isn't less than float32.eps = 1e-7.
    slightly_off = array_ops.placeholder(dtypes.float32)
    # Fed a first component of 1e-8; this shouldn't be detected as an integer.
    tiny = array_ops.placeholder(dtypes.float32)

    feeds = {
        integral: [1., 5, 10, 15, 20],
        off_by_tenth: [1.1, 5, 10, 15, 20],
        slightly_off: [1.0001, 5, 10, 15, 20],
        tiny: [1e-8, 5, 10, 15, 20],
    }

    with self.cached_session():
      with ops.control_dependencies([du.assert_integer_form(integral)]):
        array_ops.identity(integral).eval(feed_dict=feeds)

      for non_integral in (off_by_tenth, slightly_off, tiny):
        with self.assertRaisesOpError("has non-integer components"):
          with ops.control_dependencies(
              [du.assert_integer_form(non_integral)]):
            array_ops.identity(non_integral).eval(feed_dict=feeds)
class MaybeGetStaticTest(test.TestCase):
  """Tests `du.maybe_get_static_value` on Python, NumPy and tensor inputs."""

  @test_util.run_in_graph_and_eager_modes
  def testGetStaticInt(self):
    """A Python int yields an equal value, and 2.0 when cast to float64."""
    value = 2
    static_value = du.maybe_get_static_value(value)
    self.assertEqual(value, static_value)
    static_float = du.maybe_get_static_value(value, dtype=np.float64)
    self.assertAllClose(np.array(2.), static_float)

  @test_util.run_in_graph_and_eager_modes
  def testGetStaticNumpyArray(self):
    """A 0-d int32 NumPy array yields an equal value, with optional cast."""
    value = np.array(2, dtype=np.int32)
    static_value = du.maybe_get_static_value(value)
    self.assertEqual(value, static_value)
    static_float = du.maybe_get_static_value(value, dtype=np.float64)
    self.assertAllClose(np.array(2.), static_float)

  @test_util.run_in_graph_and_eager_modes
  def testGetStaticConstant(self):
    """A constant tensor yields its value as a NumPy-comparable result."""
    tensor = constant_op.constant(2, dtype=dtypes.int32)
    static_value = du.maybe_get_static_value(tensor)
    self.assertEqual(np.array(2, dtype=np.int32), static_value)
    static_float = du.maybe_get_static_value(tensor, dtype=np.float64)
    self.assertAllClose(np.array(2.), static_float)

  @test_util.run_deprecated_v1
  def testGetStaticPlaceholder(self):
    """A placeholder has no static value, with or without `dtype`."""
    placeholder = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
    static_value = du.maybe_get_static_value(placeholder)
    self.assertEqual(None, static_value)
    static_float = du.maybe_get_static_value(placeholder, dtype=np.float64)
    self.assertEqual(None, static_float)
class GetLogitsAndProbsTest(test.TestCase):
  """Tests `du.get_logits_and_probs` conversions and argument validation."""

  @test_util.run_in_graph_and_eager_modes
  def testImproperArguments(self):
    """Raises `ValueError` when neither or both of logits/probs are given."""
    improper_kwargs = (
        dict(logits=None, probs=None),
        dict(logits=[0.1], probs=[0.1]),
    )
    for kwargs in improper_kwargs:
      with self.assertRaises(ValueError):
        du.get_logits_and_probs(**kwargs)

  @test_util.run_in_graph_and_eager_modes
  def testLogits(self):
    """Logits input round-trips to the matching probabilities."""
    probs = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
    logits = _logit(probs)

    out_logits, out_probs = du.get_logits_and_probs(
        logits=logits, validate_args=True)

    self.assertAllClose(
        probs, self.evaluate(out_probs), rtol=1e-5, atol=0.)
    self.assertAllClose(
        logits, self.evaluate(out_logits), rtol=1e-5, atol=0.)

  @test_util.run_in_graph_and_eager_modes
  def testLogitsMultidimensional(self):
    """With `multidimensional=True`, log-probabilities serve as logits."""
    probs = np.array([0.2, 0.3, 0.5], dtype=np.float32)
    logits = np.log(probs)

    out_logits, out_probs = du.get_logits_and_probs(
        logits=logits, multidimensional=True, validate_args=True)

    self.assertAllClose(self.evaluate(out_probs), probs)
    self.assertAllClose(self.evaluate(out_logits), logits)

  @test_util.run_in_graph_and_eager_modes
  def testProbability(self):
    """Probability input yields logits equal to the NumPy logit."""
    probs = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)

    out_logits, out_probs = du.get_logits_and_probs(
        probs=probs, validate_args=True)

    self.assertAllClose(_logit(probs), self.evaluate(out_logits))
    self.assertAllClose(probs, self.evaluate(out_probs))

  @test_util.run_in_graph_and_eager_modes
  def testProbabilityMultidimensional(self):
    """Multidimensional probability input yields `log(p)` as logits."""
    probs = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)

    out_logits, out_probs = du.get_logits_and_probs(
        probs=probs, multidimensional=True, validate_args=True)

    self.assertAllClose(np.log(probs), self.evaluate(out_logits))
    self.assertAllClose(probs, self.evaluate(out_probs))

  @test_util.run_in_graph_and_eager_modes
  def testProbabilityValidateArgs(self):
    """Out-of-range probabilities raise only when `validate_args=True`."""
    valid = [0.01, 0.2, 0.5, 0.7, .99]
    # Component less than 0.
    has_negative = [-1, 0.2, 0.5, 0.3, .2]
    # Component greater than 1.
    has_above_one = [2, 0.2, 0.5, 0.3, .2]

    def evaluate_probs(probs, validate_args):
      _, out_probs = du.get_logits_and_probs(
          probs=probs, validate_args=validate_args)
      self.evaluate(out_probs)

    evaluate_probs(valid, True)

    with self.assertRaisesOpError("Condition x >= 0"):
      evaluate_probs(has_negative, True)
    evaluate_probs(has_negative, False)

    with self.assertRaisesOpError("probs has components greater than 1"):
      evaluate_probs(has_above_one, True)
    evaluate_probs(has_above_one, False)

  @test_util.run_in_graph_and_eager_modes
  def testProbabilityValidateArgsMultidimensional(self):
    """Multidimensional probabilities are validated for range and sum."""
    valid = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
    # Component less than 0. Still sums to 1.
    has_negative = np.array(
        [[-.3, 0.4, 0.9], [0.1, 0.5, 0.4]], dtype=np.float32)
    # Component greater than 1. Does not sum to 1.
    has_above_one = np.array(
        [[1.3, 0.0, 0.0], [0.1, 0.5, 0.4]], dtype=np.float32)
    # Does not sum to 1.
    bad_sum = np.array(
        [[1.1, 0.3, 0.4], [0.1, 0.5, 0.4]], dtype=np.float32)

    def evaluate_probs(probs, **kwargs):
      _, out_probs = du.get_logits_and_probs(
          probs=probs, multidimensional=True, **kwargs)
      self.evaluate(out_probs)

    # Note: this first call leaves `validate_args` at its default.
    evaluate_probs(valid)

    with self.assertRaisesOpError("Condition x >= 0"):
      evaluate_probs(has_negative, validate_args=True)
    evaluate_probs(has_negative, validate_args=False)

    with self.assertRaisesOpError(
        "(probs has components greater than 1|probs does not sum to 1)"):
      evaluate_probs(has_above_one, validate_args=True)
    evaluate_probs(has_above_one, validate_args=False)

    with self.assertRaisesOpError("probs does not sum to 1"):
      evaluate_probs(bad_sum, validate_args=True)
    evaluate_probs(bad_sum, validate_args=False)

  @test_util.run_deprecated_v1
  def testProbsMultidimShape(self):
    """float16 probs with 2**11 + 1 classes are rejected."""
    num_classes = int(2**11+1)
    with self.cached_session():
      # Statically known shape: the call itself raises.
      with self.assertRaises(ValueError):
        static_probs = array_ops.ones([num_classes], dtype=np.float16)
        du.get_logits_and_probs(
            probs=static_probs, multidimensional=True, validate_args=True)

      # Shape only known at run time: the error surfaces on evaluation.
      with self.assertRaisesOpError(
          "Number of classes exceeds `dtype` precision"):
        dynamic_probs = array_ops.placeholder(dtype=dtypes.float16)
        _, out_probs = du.get_logits_and_probs(
            probs=dynamic_probs, multidimensional=True, validate_args=True)
        out_probs.eval(feed_dict={dynamic_probs: np.ones([num_classes])})

  @test_util.run_deprecated_v1
  def testLogitsMultidimShape(self):
    """float16 logits with 2**11 + 1 classes are rejected."""
    num_classes = int(2**11+1)
    with self.cached_session():
      # Statically known shape: the call itself raises.
      with self.assertRaises(ValueError):
        static_logits = array_ops.ones([num_classes], dtype=np.float16)
        du.get_logits_and_probs(
            logits=static_logits, multidimensional=True, validate_args=True)

      # Shape only known at run time: the error surfaces on evaluation.
      with self.assertRaisesOpError(
          "Number of classes exceeds `dtype` precision"):
        dynamic_logits = array_ops.placeholder(dtype=dtypes.float16)
        out_logits, _ = du.get_logits_and_probs(
            logits=dynamic_logits, multidimensional=True, validate_args=True)
        out_logits.eval(feed_dict={dynamic_logits: np.ones([num_classes])})
class EmbedCheckCategoricalEventShapeTest(test.TestCase):
  """Tests `du.embed_check_categorical_event_shape` size and dtype checks."""

  @test_util.run_deprecated_v1
  def testTooSmall(self):
    """A single-event parameter is rejected, statically and dynamically."""
    with self.cached_session():
      # Statically known shape: the call itself raises.
      with self.assertRaises(ValueError):
        static_param = array_ops.ones([1], dtype=np.float16)
        du.embed_check_categorical_event_shape(static_param)

      # Shape only known at run time: the error surfaces on evaluation.
      with self.assertRaisesOpError(
          "must have at least 2 events"):
        dynamic_param = array_ops.placeholder(dtype=dtypes.float16)
        checked = du.embed_check_categorical_event_shape(dynamic_param)
        checked.eval(feed_dict={dynamic_param: np.ones([1])})

  @test_util.run_deprecated_v1
  def testTooLarge(self):
    """A float16 parameter with 2**11 + 1 events is rejected."""
    num_events = int(2**11+1)
    with self.cached_session():
      # Statically known shape: the call itself raises.
      with self.assertRaises(ValueError):
        static_param = array_ops.ones([num_events], dtype=dtypes.float16)
        du.embed_check_categorical_event_shape(static_param)

      # Shape only known at run time: the error surfaces on evaluation.
      with self.assertRaisesOpError(
          "Number of classes exceeds `dtype` precision"):
        dynamic_param = array_ops.placeholder(dtype=dtypes.float16)
        checked = du.embed_check_categorical_event_shape(dynamic_param)
        checked.eval(feed_dict={dynamic_param: np.ones([num_events])})

  @test_util.disable_tfrt("b/169901260")
  @test_util.run_in_graph_and_eager_modes
  def testUnsupportedDtype(self):
    """A qint16 parameter raises `TypeError`."""
    qint16_ones = np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype)
    param = ops.convert_to_tensor(qint16_ones, dtype=dtypes.qint16)
    with self.assertRaises(TypeError):
      du.embed_check_categorical_event_shape(param)
class EmbedCheckIntegerCastingClosedTest(test.TestCase):
  """Tests the assertions added by `du.embed_check_integer_casting_closed`."""

  def _assert_op_error(self, message, source_dtype, values, **check_kwargs):
    """Feeds `values` through the checked placeholder, expecting `message`.

    Args:
      message: Regex expected in the raised op error.
      source_dtype: TF dtype of the placeholder being checked.
      values: NumPy array fed to the placeholder.
      **check_kwargs: Forwarded to `du.embed_check_integer_casting_closed`.
    """
    with self.cached_session():
      with self.assertRaisesOpError(message):
        x = array_ops.placeholder(dtype=source_dtype)
        x_checked = du.embed_check_integer_casting_closed(x, **check_kwargs)
        x_checked.eval(feed_dict={x: values})

  @test_util.run_deprecated_v1
  def testCorrectlyAssertsNonnegative(self):
    """A negative float16 element is rejected for an int16 target."""
    self._assert_op_error(
        "Elements must be non-negative",
        dtypes.float16,
        np.array([1, -1], dtype=np.float16),
        target_dtype=dtypes.int16)

  @test_util.run_deprecated_v1
  def testCorrectlyAssersIntegerForm(self):
    """A fractional float16 element is rejected for an int16 target."""
    self._assert_op_error(
        "Elements must be int16-equivalent.",
        dtypes.float16,
        np.array([1, 1.5], dtype=np.float16),
        target_dtype=dtypes.int16)

  @test_util.run_deprecated_v1
  def testCorrectlyAssertsLargestPossibleInteger(self):
    """An int32 element of 2**15 is rejected for an int16 target."""
    self._assert_op_error(
        "Elements cannot exceed 32767.",
        dtypes.int32,
        np.array([1, 2**15], dtype=np.int32),
        target_dtype=dtypes.int16)

  @test_util.run_deprecated_v1
  def testCorrectlyAssertsSmallestPossibleInteger(self):
    """With `assert_nonnegative=False`, -1 is rejected for a uint16 target."""
    self._assert_op_error(
        "Elements cannot be smaller than 0.",
        dtypes.int32,
        np.array([1, -1], dtype=np.int32),
        target_dtype=dtypes.uint16,
        assert_nonnegative=False)
@test_util.run_all_in_graph_and_eager_modes
class LogCombinationsTest(test.TestCase):
  """Tests `du.log_combinations` values and output shapes."""

  def testLogCombinationsBinomial(self):
    """Matches `log(scipy.special.binom(n, k))` for two-class counts."""
    # Report the test as skipped rather than silently passing when the
    # optional SciPy dependency could not be imported.
    if special is None:
      self.skipTest("scipy.special is not available.")

    n = [2, 5, 12, 15]
    k = [1, 2, 4, 11]
    log_combs = np.log(special.binom(n, k))

    n = np.array(n, dtype=np.float32)
    # Each row is [k, n - k] for the matching entry of `n`.
    counts = [[1., 1], [2., 3], [4., 8], [11, 4]]
    log_binom = du.log_combinations(n, counts)
    self.assertEqual([4], log_binom.get_shape())
    self.assertAllClose(log_combs, self.evaluate(log_binom))

  def testLogCombinationsShape(self):
    """The result shape is the shape of `n` (counts minus the last axis)."""
    # Shape [2, 2]
    n = np.array([[2, 5], [12, 15]], dtype=np.float32)
    # Shape [2, 2, 4]
    counts = [[[1., 1, 0, 0], [2., 2, 1, 0]], [[4., 4, 1, 3], [10, 1, 1, 4]]]
    log_binom = du.log_combinations(n, counts)
    self.assertEqual([2, 2], log_binom.get_shape())
class DynamicShapeTest(test.TestCase):
  """Tests `du.same_dynamic_shape` across scalar, vector and matrix inputs."""

  @test_util.run_deprecated_v1
  def testSameDynamicShape(self):
    """Evaluates to True only when both arguments have the same fed shape."""
    with self.cached_session():
      scalar = constant_op.constant(2.0)
      scalar1 = array_ops.placeholder(dtype=dtypes.float32)

      vector = [0.3, 0.4, 0.5]
      vector1 = array_ops.placeholder(dtype=dtypes.float32, shape=[None])
      vector2 = array_ops.placeholder(dtype=dtypes.float32, shape=[None])

      multidimensional = [[0.3, 0.4], [0.2, 0.6]]
      multidimensional1 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, None])
      multidimensional2 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, None])

      # Each case is (first argument, second argument, feed dict).
      same_shape_cases = [
          # Scalar
          (scalar, scalar1, {scalar1: 2.0}),
          # Vector
          (vector, vector1, {vector1: [2.0, 3.0, 4.0]}),
          (vector1, vector2, {
              vector1: [2.0, 3.0, 4.0],
              vector2: [2.0, 3.5, 6.0]
          }),
          # Multidimensional
          (multidimensional, multidimensional1, {
              multidimensional1: [[2.0, 3.0], [3.0, 4.0]]
          }),
          (multidimensional1, multidimensional2, {
              multidimensional1: [[2.0, 3.0], [3.0, 4.0]],
              multidimensional2: [[1.0, 3.5], [6.3, 2.3]]
          }),
      ]

      different_shape_cases = [
          # Scalar, X
          (scalar, vector1, {vector1: [2.0, 3.0, 4.0]}),
          (scalar1, vector1, {
              scalar1: 2.0,
              vector1: [2.0, 3.0, 4.0]
          }),
          (scalar, multidimensional1, {
              multidimensional1: [[2.0, 3.0], [3.0, 4.0]]
          }),
          (scalar1, multidimensional1, {
              scalar1: 2.0,
              multidimensional1: [[2.0, 3.0], [3.0, 4.0]]
          }),
          # Vector, X
          (vector, vector1, {vector1: [2.0, 3.0]}),
          (vector1, vector2, {
              vector1: [2.0, 3.0, 4.0],
              vector2: [6.0]
          }),
          (vector, multidimensional1, {
              multidimensional1: [[2.0, 3.0], [3.0, 4.0]]
          }),
          (vector1, multidimensional1, {
              vector1: [2.0, 3.0, 4.0],
              multidimensional1: [[2.0, 3.0], [3.0, 4.0]]
          }),
          # Multidimensional, X
          (multidimensional, multidimensional1, {
              multidimensional1: [[1.0, 3.5, 5.0], [6.3, 2.3, 7.1]]
          }),
          (multidimensional1, multidimensional2, {
              multidimensional1: [[2.0, 3.0], [3.0, 4.0]],
              multidimensional2: [[1.0, 3.5, 5.0], [6.3, 2.3, 7.1]]
          }),
      ]

      for first, second, feed in same_shape_cases:
        self.assertTrue(du.same_dynamic_shape(first, second).eval(feed))

      for first, second, feed in different_shape_cases:
        self.assertFalse(du.same_dynamic_shape(first, second).eval(feed))
class RotateTransposeTest(test.TestCase):
  """Tests `du.rotate_transpose` against a NumPy reference."""

  def _np_rotate_transpose(self, x, shift):
    """NumPy reference: transposes `x` with its axis order rolled by `shift`.

    Args:
      x: Array-like input; non-ndarray inputs are converted with `np.array`.
      shift: Integer number of positions to roll the axis permutation.

    Returns:
      The transposed `np.ndarray`.
    """
    if not isinstance(x, np.ndarray):
      x = np.array(x)
    axis_order = np.roll(np.arange(len(x.shape)), shift)
    return np.transpose(x, axis_order)

  @test_util.run_in_graph_and_eager_modes
  def testRollStatic(self):
    """Static shifts match NumPy in both values and static shape."""
    # The error raised for a `None` input is worded differently per mode.
    if context.executing_eagerly():
      error_message = r"Attempt to convert a value \(None\)"
    else:
      error_message = "None values not supported."
    with self.assertRaisesRegex(ValueError, error_message):
      du.rotate_transpose(None, 1)

    for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))):
      for shift in np.arange(-5, 5):
        y = du.rotate_transpose(x, shift)
        self.assertAllEqual(
            self._np_rotate_transpose(x, shift), self.evaluate(y))
        self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list())

  @test_util.run_deprecated_v1
  def testRollDynamic(self):
    """Shifts fed at run time match the NumPy reference."""
    with self.cached_session() as sess:
      x = array_ops.placeholder(dtypes.float32)
      shift = array_ops.placeholder(dtypes.int32)
      # `as_numpy_dtype` is used as a plain attribute elsewhere in this file;
      # pass the NumPy type itself rather than an instance created by calling
      # it.
      np_dtype = x.dtype.as_numpy_dtype
      x_values = (
          np.ones(1, dtype=np_dtype),
          np.ones((2, 1), dtype=np_dtype),
          np.ones((3, 2, 1), dtype=np_dtype),
      )
      for x_value in x_values:
        for shift_value in np.arange(-5, 5):
          self.assertAllEqual(
              self._np_rotate_transpose(x_value, shift_value),
              sess.run(du.rotate_transpose(x, shift),
                       feed_dict={x: x_value,
                                  shift: shift_value}))
class PickVectorTest(test.TestCase):
  """Tests `du.pick_vector` with tensor and constant predicates."""

  @test_util.run_deprecated_v1
  def testCorrectlyPicksVector(self):
    """A true predicate selects the first vector, a false one the second."""
    with self.cached_session():
      if_true = np.arange(10, 12)
      if_false = np.arange(15, 18)

      picked = du.pick_vector(math_ops.less(0, 5), if_true, if_false)
      self.assertAllEqual(if_true, self.evaluate(picked))

      picked = du.pick_vector(math_ops.less(5, 0), if_true, if_false)
      self.assertAllEqual(if_false, self.evaluate(picked))

      # With constant predicates the results are compared directly.
      picked = du.pick_vector(constant_op.constant(True), if_true, if_false)
      self.assertAllEqual(if_true, picked)  # No eval.

      picked = du.pick_vector(constant_op.constant(False), if_true, if_false)
      self.assertAllEqual(if_false, picked)  # No eval.
class PreferStaticRankTest(test.TestCase):
  """Tests `du.prefer_static_rank` for static and run-time-known ranks."""

  def _dynamic_rank(self, dtype, fed_value):
    """Evaluates the rank of an unknown-shape placeholder fed `fed_value`."""
    placeholder = array_ops.placeholder(dtype, shape=None)
    rank = du.prefer_static_rank(placeholder)
    with self.cached_session():
      return rank.eval(feed_dict={placeholder: fed_value})

  @test_util.run_deprecated_v1
  def testNonEmptyConstantTensor(self):
    """A [2, 3, 4] tensor has static rank 3, returned as an ndarray."""
    static_rank = du.prefer_static_rank(array_ops.zeros((2, 3, 4)))
    self.assertIsInstance(static_rank, np.ndarray)
    self.assertEqual(3, static_rank)

  @test_util.run_deprecated_v1
  def testEmptyConstantTensor(self):
    """An empty vector constant has static rank 1."""
    static_rank = du.prefer_static_rank(constant_op.constant([]))
    self.assertIsInstance(static_rank, np.ndarray)
    self.assertEqual(1, static_rank)

  @test_util.run_deprecated_v1
  def testScalarTensor(self):
    """A scalar constant has static rank 0."""
    static_rank = du.prefer_static_rank(constant_op.constant(1.))
    self.assertIsInstance(static_rank, np.ndarray)
    self.assertEqual(0, static_rank)

  @test_util.run_deprecated_v1
  def testDynamicRankEndsUpBeingNonEmpty(self):
    """Feeding a (2, 3) array evaluates to rank 2."""
    self.assertAllEqual(2, self._dynamic_rank(np.float64, np.zeros((2, 3))))

  @test_util.run_deprecated_v1
  def testDynamicRankEndsUpBeingEmpty(self):
    """Feeding an empty list evaluates to rank 1."""
    self.assertAllEqual(1, self._dynamic_rank(np.int32, []))

  @test_util.run_deprecated_v1
  def testDynamicRankEndsUpBeingScalar(self):
    """Feeding a scalar evaluates to rank 0."""
    self.assertAllEqual(0, self._dynamic_rank(np.int32, 1))
class PreferStaticShapeTest(test.TestCase):
  """Tests `du.prefer_static_shape` for static and run-time-known shapes."""

  def _dynamic_shape(self, dtype, fed_value):
    """Evaluates the shape of an unknown-shape placeholder fed `fed_value`."""
    placeholder = array_ops.placeholder(dtype, shape=None)
    shape = du.prefer_static_shape(placeholder)
    with self.cached_session():
      return shape.eval(feed_dict={placeholder: fed_value})

  @test_util.run_deprecated_v1
  def testNonEmptyConstantTensor(self):
    """A [2, 3, 4] tensor yields its shape as an ndarray."""
    static_shape = du.prefer_static_shape(array_ops.zeros((2, 3, 4)))
    self.assertIsInstance(static_shape, np.ndarray)
    self.assertAllEqual(np.array([2, 3, 4]), static_shape)

  @test_util.run_deprecated_v1
  def testEmptyConstantTensor(self):
    """An empty vector constant has static shape [0]."""
    static_shape = du.prefer_static_shape(constant_op.constant([]))
    self.assertIsInstance(static_shape, np.ndarray)
    self.assertAllEqual(np.array([0]), static_shape)

  @test_util.run_deprecated_v1
  def testScalarTensor(self):
    """A scalar constant has an empty static shape."""
    static_shape = du.prefer_static_shape(constant_op.constant(1.))
    self.assertIsInstance(static_shape, np.ndarray)
    self.assertAllEqual(np.array([]), static_shape)

  @test_util.run_deprecated_v1
  def testDynamicShapeEndsUpBeingNonEmpty(self):
    """Feeding a (2, 3) array evaluates to shape (2, 3)."""
    self.assertAllEqual(
        (2, 3), self._dynamic_shape(np.float64, np.zeros((2, 3))))

  @test_util.run_deprecated_v1
  def testDynamicShapeEndsUpBeingEmpty(self):
    """Feeding an empty list evaluates to shape [0]."""
    self.assertAllEqual(np.array([0]), self._dynamic_shape(np.int32, []))

  @test_util.run_deprecated_v1
  def testDynamicShapeEndsUpBeingScalar(self):
    """Feeding a scalar evaluates to an empty shape."""
    self.assertAllEqual(np.array([]), self._dynamic_shape(np.int32, 1))
class PreferStaticValueTest(test.TestCase):
  """Tests `du.prefer_static_value` for static and run-time-known values."""

  def _dynamic_value(self, dtype, fed_value):
    """Evaluates the value of an unknown-shape placeholder fed `fed_value`."""
    placeholder = array_ops.placeholder(dtype, shape=None)
    value = du.prefer_static_value(placeholder)
    with self.cached_session():
      return value.eval(feed_dict={placeholder: fed_value})

  @test_util.run_deprecated_v1
  def testNonEmptyConstantTensor(self):
    """A zeros tensor yields an equal ndarray."""
    static_value = du.prefer_static_value(array_ops.zeros((2, 3, 4)))
    self.assertIsInstance(static_value, np.ndarray)
    self.assertAllEqual(np.zeros((2, 3, 4)), static_value)

  @test_util.run_deprecated_v1
  def testEmptyConstantTensor(self):
    """An empty constant yields an empty ndarray."""
    static_value = du.prefer_static_value(constant_op.constant([]))
    self.assertIsInstance(static_value, np.ndarray)
    self.assertAllEqual(np.array([]), static_value)

  @test_util.run_deprecated_v1
  def testScalarTensor(self):
    """A scalar constant yields a 0-d ndarray with the same value."""
    static_value = du.prefer_static_value(constant_op.constant(1.))
    self.assertIsInstance(static_value, np.ndarray)
    self.assertAllEqual(np.array(1.), static_value)

  @test_util.run_deprecated_v1
  def testDynamicValueEndsUpBeingNonEmpty(self):
    """Feeding a (2, 3) zeros array evaluates to that array."""
    self.assertAllEqual(
        np.zeros((2, 3)),
        self._dynamic_value(np.float64, np.zeros((2, 3))))

  @test_util.run_deprecated_v1
  def testDynamicValueEndsUpBeingEmpty(self):
    """Feeding an empty list evaluates to an empty array."""
    self.assertAllEqual(np.array([]), self._dynamic_value(np.int32, []))

  @test_util.run_deprecated_v1
  def testDynamicValueEndsUpBeingScalar(self):
    """Feeding a scalar evaluates to that scalar."""
    self.assertAllEqual(np.array(1), self._dynamic_value(np.int32, 1))
class FillTriangularTest(test.TestCase):
  """Tests `du.fill_triangular` values and gradients against NumPy."""

  def setUp(self):
    # Fixed seed so the random inputs are reproducible.
    self._rng = np.random.RandomState(42)

  def _fill_triangular(self, x, upper=False):
    """Numpy implementation of `fill_triangular`.

    Args:
      x: Array-like whose last dimension has length `n * (n + 1) / 2`.
      upper: If True build upper-triangular matrices, otherwise lower.

    Returns:
      Array with the last dimension of `x` replaced by an `[n, n]` triangle.

    Raises:
      ValueError: If the last dimension is not a triangular number.
    """
    x = np.asarray(x)
    # Formula derived by solving for n: m = n(n+1)/2.
    m = np.int32(x.shape[-1])
    n = np.sqrt(0.25 + 2. * m) - 0.5
    if n != np.floor(n):
      raise ValueError("Invalid shape.")
    n = np.int32(n)
    # We can't do: `x[..., -(n**2-m):]` because this doesn't correctly handle
    # `m == n == 1`. Hence, we do absolute indexing.
    x_tail = x[..., (m - (n * n - m)):]
    if upper:
      pieces = [x, x_tail[..., ::-1]]
    else:
      pieces = [x_tail, x[..., ::-1]]
    y = np.concatenate(pieces, axis=-1)
    full_shape = np.concatenate([
        np.int32(x.shape[:-1]),
        np.int32([n, n]),
    ], axis=0)
    y = y.reshape(full_shape)
    if upper:
      return np.triu(y)
    return np.tril(y)

  def _run_test(self, x_, use_deferred_shape=False, **kwargs):
    """Compares `du.fill_triangular` and its gradient with the NumPy version.

    Args:
      x_: Array-like input values.
      use_deferred_shape: If True the placeholder is built without a static
        shape.
      **kwargs: Forwarded to both `du.fill_triangular` and the NumPy
        reference.
    """
    x_ = np.asarray(x_)
    with self.cached_session() as sess:
      static_shape = None if use_deferred_shape else x_.shape
      x_pl = array_ops.placeholder_with_default(x_, shape=static_shape)
      # Add `zeros_like(x)` such that x's value and gradient are identical. We
      # do this so we can ensure each gradient value is mapped to the right
      # gradient location.  (Not doing this means the gradient wrt `x` is simple
      # `ones_like(x)`.)
      # Note:
      #   zeros_like_x_pl == zeros_like(x_pl)
      #   gradient(zeros_like_x_pl, x_pl) == x_pl - 1
      zeros_like_x_pl = (x_pl * array_ops.stop_gradient(x_pl - 1.)
                         - array_ops.stop_gradient(x_pl * (x_pl - 1.)))
      x = x_pl + zeros_like_x_pl
      actual = du.fill_triangular(x, **kwargs)
      grad_actual = gradients_impl.gradients(actual, x_pl)[0]
      fetches = [actual, grad_actual]
      actual_, grad_actual_ = sess.run(fetches, feed_dict={x_pl: x_})

    expected = self._fill_triangular(x_, **kwargs)
    if use_deferred_shape:
      # Deliberately an equality check against the shape object, not an
      # identity check.
      self.assertEqual(None, actual.shape)
    else:
      self.assertAllEqual(expected.shape, actual.shape)
    self.assertAllClose(expected, actual_, rtol=1e-8, atol=1e-9)
    self.assertAllClose(x_, grad_actual_, rtol=1e-8, atol=1e-9)

  @test_util.run_deprecated_v1
  def testCorrectlyMakes1x1TriLower(self):
    values = self._rng.randn(3, int(1*2/2))
    self._run_test(values)

  @test_util.run_deprecated_v1
  def testCorrectlyMakesNoBatchTriLower(self):
    values = self._rng.randn(int(4*5/2))
    self._run_test(values)

  @test_util.run_deprecated_v1
  def testCorrectlyMakesBatchTriLower(self):
    values = self._rng.randn(2, 3, int(3*4/2))
    self._run_test(values)

  @test_util.run_deprecated_v1
  def testCorrectlyMakesBatchTriLowerUnknownShape(self):
    values = self._rng.randn(2, 3, int(3*4/2))
    self._run_test(values, use_deferred_shape=True)

  @test_util.run_deprecated_v1
  def testCorrectlyMakesBatch7x7TriLowerUnknownShape(self):
    values = self._rng.randn(2, 3, int(7*8/2))
    self._run_test(values, use_deferred_shape=True)

  @test_util.run_deprecated_v1
  def testCorrectlyMakesBatch7x7TriLower(self):
    values = self._rng.randn(2, 3, int(7*8/2))
    self._run_test(values)

  @test_util.run_deprecated_v1
  def testCorrectlyMakes1x1TriUpper(self):
    values = self._rng.randn(3, int(1*2/2))
    self._run_test(values, upper=True)

  @test_util.run_deprecated_v1
  def testCorrectlyMakesNoBatchTriUpper(self):
    values = self._rng.randn(int(4*5/2))
    self._run_test(values, upper=True)

  @test_util.run_deprecated_v1
  def testCorrectlyMakesBatchTriUpper(self):
    values = self._rng.randn(2, 2, int(3*4/2))
    self._run_test(values, upper=True)

  @test_util.run_deprecated_v1
  def testCorrectlyMakesBatchTriUpperUnknownShape(self):
    values = self._rng.randn(2, 2, int(3*4/2))
    self._run_test(values, use_deferred_shape=True, upper=True)

  @test_util.run_deprecated_v1
  def testCorrectlyMakesBatch7x7TriUpperUnknownShape(self):
    values = self._rng.randn(2, 3, int(7*8/2))
    self._run_test(values, use_deferred_shape=True, upper=True)

  @test_util.run_deprecated_v1
  def testCorrectlyMakesBatch7x7TriUpper(self):
    values = self._rng.randn(2, 3, int(7*8/2))
    self._run_test(values, upper=True)
class FillTriangularInverseTest(FillTriangularTest):
  """Re-runs the inherited cases as a fill / inverse-fill round trip."""

  def _run_test(self, x_, use_deferred_shape=False, **kwargs):
    """Checks `fill_triangular_inverse(fill_triangular(x))` recovers `x`.

    Args:
      x_: Array-like input values.
      use_deferred_shape: If True the placeholder is built without a static
        shape.
      **kwargs: Forwarded to `du.fill_triangular` and its inverse.
    """
    x_ = np.asarray(x_)
    with self.cached_session() as sess:
      if use_deferred_shape:
        static_shape = None
      else:
        static_shape = x_.shape
      x_pl = array_ops.placeholder_with_default(x_, shape=static_shape)
      # Same zero-valued term as the parent class adds to its input.
      zeros_like_x_pl = (x_pl * array_ops.stop_gradient(x_pl - 1.)
                         - array_ops.stop_gradient(x_pl * (x_pl - 1.)))
      x = x_pl + zeros_like_x_pl
      filled = du.fill_triangular(x, **kwargs)
      inverse_actual = du.fill_triangular_inverse(filled, **kwargs)
      inverse_actual_ = sess.run(inverse_actual, feed_dict={x_pl: x_})

    if use_deferred_shape:
      # Deliberately an equality check against the shape object, not an
      # identity check.
      self.assertEqual(None, inverse_actual.shape)
    else:
      self.assertAllEqual(x_.shape, inverse_actual.shape)
    self.assertAllEqual(x_, inverse_actual_)
class ReduceWeightedLogSumExp(test.TestCase):
  """Tests `du.reduce_weighted_logsumexp` against NumPy and TF references."""

  def _reduce_weighted_logsumexp(self, logx, w, axis, keep_dims=False):
    """NumPy reference for the weighted log-sum-exp and its sign.

    Args:
      logx: Array of log-values.
      w: Array of weights multiplied with `exp(logx)`.
      axis: Axis to reduce over.
      keep_dims: Whether the reduced axis is kept in the outputs.

    Returns:
      Tuple `(log_abs_sum, sign)` of the weighted sum along `axis`.
    """
    # Subtract the maximum before exponentiating and add it back at the end.
    max_logx = np.max(logx, axis=axis, keepdims=True)
    weighted_sum = np.sum(
        w * np.exp(logx - max_logx), axis=axis, keepdims=keep_dims)
    sign = np.sign(weighted_sum)
    if not keep_dims:
      max_logx = np.squeeze(max_logx, axis=axis)
    return max_logx + np.log(sign * weighted_sum), sign

  @test_util.run_deprecated_v1
  def testNoWeights(self):
    """Without weights, value and gradient equal `reduce_logsumexp`."""
    logx_ = np.array([[0., -1, 1000.],
                      [0, 1, -1000.],
                      [-5, 0, 5]])
    with self.cached_session() as sess:
      logx = constant_op.constant(logx_)
      expected = math_ops.reduce_logsumexp(logx, axis=-1)
      grad_expected = gradients_impl.gradients(expected, logx)[0]
      actual, actual_sgn = du.reduce_weighted_logsumexp(
          logx, axis=-1, return_sign=True)
      grad_actual = gradients_impl.gradients(actual, logx)[0]
      fetches = [actual, actual_sgn, grad_actual, expected, grad_expected]
      (actual_, actual_sgn_, grad_actual_,
       expected_, grad_expected_) = sess.run(fetches)
    self.assertAllEqual(expected_, actual_)
    self.assertAllEqual(grad_expected_, grad_actual_)
    self.assertAllEqual([1., 1, 1], actual_sgn_)

  def testNegativeWeights(self):
    """Negative weights produce the matching magnitude and sign."""
    logx_ = np.array([[0., -1, 1000.],
                      [0, 1, -1000.],
                      [-5, 0, 5]])
    w_ = np.array([[1., 1, -1],
                   [1, -2, 1],
                   [1, 0, 1]])
    expected, _ = self._reduce_weighted_logsumexp(logx_, w_, axis=-1)
    with self.cached_session():
      logx = constant_op.constant(logx_)
      w = constant_op.constant(w_)
      actual, actual_sgn = du.reduce_weighted_logsumexp(
          logx, w, axis=-1, return_sign=True)
      actual_, actual_sgn_ = self.evaluate([actual, actual_sgn])
    self.assertAllEqual(expected, actual_)
    self.assertAllEqual([-1., -1, 1], actual_sgn_)

  def testKeepDims(self):
    """`keep_dims=True` keeps the reduced axis in value and sign."""
    logx_ = np.array([[0., -1, 1000.],
                      [0, 1, -1000.],
                      [-5, 0, 5]])
    w_ = np.array([[1., 1, -1],
                   [1, -2, 1],
                   [1, 0, 1]])
    expected, _ = self._reduce_weighted_logsumexp(
        logx_, w_, axis=-1, keep_dims=True)
    with self.cached_session():
      logx = constant_op.constant(logx_)
      w = constant_op.constant(w_)
      actual, actual_sgn = du.reduce_weighted_logsumexp(
          logx, w, axis=-1, return_sign=True, keep_dims=True)
      actual_, actual_sgn_ = self.evaluate([actual, actual_sgn])
    self.assertAllEqual(expected, actual_)
    self.assertAllEqual([[-1.], [-1], [1]], actual_sgn_)

  def testDocString(self):
    """This test verifies the correctness of the docstring examples."""
    with self.cached_session():
      x = constant_op.constant([[0., 0, 0],
                                [0, 0, 0]])
      w = constant_op.constant([[-1., 1, 1],
                                [1, 1, 1]])

      full_reduction = du.reduce_weighted_logsumexp(x, w)
      self.assertAllClose(np.log(4), self.evaluate(full_reduction))

      # `np.log(0)` in the expected value would otherwise emit a warning.
      with np.errstate(divide="ignore"):
        self.assertAllClose(
            np.log([0, 2, 2]),
            self.evaluate(du.reduce_weighted_logsumexp(x, w, axis=0)))

      over_rows = du.reduce_weighted_logsumexp(x, w, axis=1)
      self.assertAllClose(np.log([1, 3]), self.evaluate(over_rows))

      over_rows_kept = du.reduce_weighted_logsumexp(
          x, w, axis=1, keep_dims=True)
      self.assertAllClose(np.log([[1], [3]]), self.evaluate(over_rows_kept))

      over_both = du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
      self.assertAllClose(np.log(4), self.evaluate(over_both))
class GenNewSeedTest(test.TestCase):
  """Tests when `du.gen_new_seed` returns `None`."""

  def testOnlyNoneReturnsNone(self):
    """A seed of 0 yields a non-None result; a seed of None yields None."""
    from_zero_seed = du.gen_new_seed(0, "salt")
    self.assertIsNotNone(from_zero_seed)
    from_none_seed = du.gen_new_seed(None, "salt")
    self.assertIsNone(from_none_seed)
# TODO(jvdillon): Merge this test back into:
# tensorflow/python/kernel_tests/softplus_op_test.py
# once TF core is accepting new ops.
class SoftplusTest(test.TestCase):
  """Tests `nn_ops.softplus` together with `du.softplus_inverse`."""

  def _npSoftplus(self, np_features):
    """NumPy reference softplus, computed as `logaddexp(0, x)`."""
    np_features = np.asarray(np_features)
    zero = np.asarray(0).astype(np_features.dtype)
    return np.logaddexp(zero, np_features)

  def _testSoftplus(self, np_features, use_gpu=False):
    """Checks softplus against NumPy and that the inverse recovers the input.

    Args:
      np_features: Array-like input values.
      use_gpu: Passed to `self.session` to select the device.
    """
    np_features = np.asarray(np_features)
    np_softplus = self._npSoftplus(np_features)
    with self.session(use_gpu=use_gpu) as sess:
      softplus = nn_ops.softplus(np_features)
      softplus_inverse = du.softplus_inverse(softplus)
      [tf_softplus, tf_softplus_inverse] = sess.run([
          softplus, softplus_inverse])
    self.assertAllCloseAccordingToType(np_softplus, tf_softplus)
    rtol = {"float16": 0.07, "float32": 0.003, "float64": 0.002}.get(
        str(np_features.dtype), 1e-6)
    # This will test that we correctly computed the inverse by verifying we
    # recovered the original input.
    self.assertAllCloseAccordingToType(
        np_features, tf_softplus_inverse,
        atol=0., rtol=rtol)
    # `np.bool_` is used because the `np.bool` alias no longer exists in
    # current NumPy releases.
    self.assertAllEqual(np.ones_like(tf_softplus).astype(np.bool_),
                        tf_softplus > 0)

    self.assertShapeEqual(np_softplus, softplus)
    self.assertShapeEqual(np_softplus, softplus_inverse)

    self.assertAllEqual(np.ones_like(tf_softplus).astype(np.bool_),
                        np.isfinite(tf_softplus))
    self.assertAllEqual(np.ones_like(tf_softplus_inverse).astype(np.bool_),
                        np.isfinite(tf_softplus_inverse))

  @test_util.run_deprecated_v1
  def testNumbers(self):
    """Runs a value grid and epsilon-scale edge cases on CPU and GPU."""
    for t in [np.float16, np.float32, np.float64]:
      lower = {np.float16: -15, np.float32: -50, np.float64: -50}.get(t, -100)
      upper = {np.float16: 50, np.float32: 50, np.float64: 50}.get(t, 100)
      grid = np.array(np.linspace(lower, upper, int(1e3)).astype(t)).reshape(
          [2, -1])
      self._testSoftplus(grid, use_gpu=False)
      self._testSoftplus(grid, use_gpu=True)

      log_eps = np.log(np.finfo(t).eps)
      one = t(1)
      ten = t(10)
      # The same ten edge values are used for both devices. Previously the GPU
      # list was missing a comma, which merged `log_eps + ten` and `-log_eps`
      # into the single element `log_eps + ten - log_eps`.
      edge_values = [
          log_eps, log_eps - one, log_eps + one, log_eps - ten,
          log_eps + ten, -log_eps, -log_eps - one, -log_eps + one,
          -log_eps - ten, -log_eps + ten
      ]
      self._testSoftplus(edge_values, use_gpu=False)
      self._testSoftplus(edge_values, use_gpu=True)

  @test_util.run_deprecated_v1
  def testGradient(self):
    """The numeric and analytic softplus gradients agree to within 1e-4."""
    with self.cached_session():
      x = constant_op.constant(
          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
          shape=[2, 5],
          name="x")
      y = nn_ops.softplus(x, name="softplus")
      x_init = np.asarray(
          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
          dtype=np.float32,
          order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], y, [2, 5], x_init_value=x_init)
    # The message needs a placeholder for `err`; without one the extra
    # argument cannot be formatted into the log record.
    tf_logging.vlog(2, "softplus (float) gradient err = %s", err)
    self.assertLess(err, 1e-4)

  @test_util.run_deprecated_v1
  def testInverseSoftplusGradientNeverNan(self):
    """The inverse-softplus gradient contains no NaN over a wide range."""
    with self.cached_session():
      # Note that this range contains both zero and inf.
      x = constant_op.constant(np.logspace(-8, 6).astype(np.float16))
      y = du.softplus_inverse(x)
      grads = self.evaluate(gradients_impl.gradients(y, x)[0])
      # Equivalent to `assertAllFalse` (if it existed).
      self.assertAllEqual(
          np.zeros_like(grads).astype(np.bool_), np.isnan(grads))

  @test_util.run_deprecated_v1
  def testInverseSoftplusGradientFinite(self):
    """The inverse-softplus gradient is finite over an all-finite range."""
    with self.cached_session():
      # This range of x is all finite, and so is 1 / x.  So the
      # gradient and its approximations should be finite as well.
      x = constant_op.constant(np.logspace(-4.8, 4.5).astype(np.float16))
      y = du.softplus_inverse(x)
      grads = self.evaluate(gradients_impl.gradients(y, x)[0])
      # Equivalent to `assertAllTrue` (if it existed).
      self.assertAllEqual(
          np.ones_like(grads).astype(np.bool_), np.isfinite(grads))
@test_util.run_all_in_graph_and_eager_modes
class ArgumentsTest(test.TestCase):
  """Tests `du.parent_frame_arguments`, called from inside local functions.

  The expected dictionaries are keyed by the argument names of the enclosing
  local function, so those names are part of what is being tested.
  """

  def testNoArguments(self):
    """A function without arguments yields an empty dict."""
    def foo():
      return du.parent_frame_arguments()

    self.assertEqual({}, foo())

  def testPositionalArguments(self):
    """Positional arguments are returned by name; other locals are not."""
    def foo(a, b, c, d):  # pylint: disable=unused-argument
      return du.parent_frame_arguments()

    self.assertEqual({"a": 1, "b": 2, "c": 3, "d": 4}, foo(1, 2, 3, 4))

    # Tests that it does not matter where this function is called, and
    # no other local variables are returned back.
    def bar(a, b, c):
      unused_x = a * b
      unused_y = c * 3
      return du.parent_frame_arguments()

    self.assertEqual({"a": 1, "b": 2, "c": 3}, bar(1, 2, 3))

  def testOverloadedArgumentValues(self):
    """Arguments reassigned before the call are reported with new values."""
    def foo(a, b, c):  # pylint: disable=unused-argument
      a = 42
      b = 31
      c = 42
      return du.parent_frame_arguments()
    self.assertEqual({"a": 42, "b": 31, "c": 42}, foo(1, 2, 3))

  def testKeywordArguments(self):
    """Entries received through `**kwargs` appear as top-level keys."""
    def foo(**kwargs):  # pylint: disable=unused-argument
      return du.parent_frame_arguments()

    self.assertEqual({"a": 1, "b": 2, "c": 3, "d": 4}, foo(a=1, b=2, c=3, d=4))

  def testPositionalKeywordArgs(self):
    """Named parameters and extra keyword arguments are merged."""
    def foo(a, b, c, **kwargs):  # pylint: disable=unused-argument
      return du.parent_frame_arguments()

    self.assertEqual({"a": 1, "b": 2, "c": 3}, foo(a=1, b=2, c=3))
    self.assertEqual({"a": 1, "b": 2, "c": 3, "unicorn": None},
                     foo(a=1, b=2, c=3, unicorn=None))

  def testNoVarargs(self):
    """Values captured by `*varargs` are left out of the result."""
    def foo(a, b, c, *varargs, **kwargs):  # pylint: disable=unused-argument
      return du.parent_frame_arguments()

    self.assertEqual({"a": 1, "b": 2, "c": 3}, foo(a=1, b=2, c=3))
    self.assertEqual({"a": 1, "b": 2, "c": 3}, foo(1, 2, 3, *[1, 2, 3]))
    self.assertEqual({"a": 1, "b": 2, "c": 3, "unicorn": None},
                     foo(1, 2, 3, unicorn=None))
    self.assertEqual({"a": 1, "b": 2, "c": 3, "unicorn": None},
                     foo(1, 2, 3, *[1, 2, 3], unicorn=None))
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
  test.main()