Merge pull request #23109 from frreiss:issue-assert-refactor-2

PiperOrigin-RevId: 261230566
This commit is contained in:
TensorFlower Gardener 2019-08-01 19:15:57 -07:00
commit c52b412821
16 changed files with 553 additions and 547 deletions

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@ -461,13 +462,14 @@ class AffineBijectorTest(test.TestCase):
def testNoBatchMultivariateRaisesWhenSingular(self):
with self.cached_session():
mu = [1., -1]
bijector = Affine(
shift=mu,
# Has zero on the diagonal.
scale_diag=[0., 1],
validate_args=True)
with self.assertRaisesOpError("diagonal part must be non-zero"):
bijector.forward([1., 1.]).eval()
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"diagonal part must be non-zero"):
_ = Affine(
shift=mu,
# Has zero on the diagonal.
scale_diag=[0., 1],
validate_args=True)
# Error detected statically; don't need to run the op.
def _makeScale(self,
x,

View File

@ -22,6 +22,7 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
@ -150,6 +151,27 @@ class _ReshapeBijectorTest(object):
with self.assertRaisesError(expected_error_message):
sess.run(bijector.forward_event_shape_tensor(shape_in),
feed_dict=feed_dict)
def _testInvalidDimensionsStatic(self, expected_error_message):
"""Version of _testInvalidDimensionsOpError for errors detected statically.
Statically means at graph construction time.
Args:
expected_error_message: String that should be present in the error
message that `Reshape` raises for invalid shapes.
"""
shape_in, shape_out, _ = self.build_shapes([2, 3], [
1,
2,
-2,
])
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
expected_error_message):
_ = Reshape(
event_shape_out=shape_out,
event_shape_in=shape_in,
validate_args=True)
# pylint: enable=invalid-name
def testValidButNonMatchingInputOpError(self):
@ -300,9 +322,9 @@ class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest):
assert_bijective_and_finite(
bijector, x, y, event_ndims=2, rtol=1e-6, atol=0)
def testInvalidDimensionsOpError(self):
self._testInvalidDimensionsOpError(
"Invalid value in tensor used for shape: -2")
def testInvalidDimensionsStatic(self):
self._testInvalidDimensionsStatic(
"elements must be either positive integers or `-1`")
def testInputOutputMismatchOpError(self):
self._testInputOutputMismatchOpError("Cannot reshape a tensor with")

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
from tensorflow.python.framework import errors
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
@ -43,9 +44,10 @@ class SoftplusBijectorTest(test.TestCase):
def testHingeSoftnessZeroRaises(self):
with self.cached_session():
bijector = Softplus(hinge_softness=0., validate_args=True)
with self.assertRaisesOpError("must be non-zero"):
bijector.forward([1., 1.]).eval()
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
"must be non-zero"):
_ = Softplus(hinge_softness=0., validate_args=True)
# Error detected statically; don't need to run op.
def testBijectorForwardInverseEventDimsZero(self):
with self.cached_session():

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops import cauchy as cauchy_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@ -400,9 +401,10 @@ class CauchyTest(test.TestCase):
def testCauchyNegativeLocFails(self):
with self.cached_session():
cauchy = cauchy_lib.Cauchy(loc=[1.], scale=[-5.], validate_args=True)
with self.assertRaisesOpError("Condition x > 0 did not hold"):
cauchy.mode().eval()
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
"Condition x > 0 did not hold"):
_ = cauchy_lib.Cauchy(loc=[1.], scale=[-5.], validate_args=True)
# Error detected statically; no need for _.mode().eval()
def testCauchyShape(self):
with self.cached_session():

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@ -40,11 +41,10 @@ class DeterministicTest(test.TestCase):
def testInvalidTolRaises(self):
loc = rng.rand(2, 3, 4).astype(np.float32)
deterministic = deterministic_lib.Deterministic(
loc, atol=-1, validate_args=True)
with self.cached_session():
with self.assertRaisesOpError("Condition x >= 0"):
deterministic.prob(0.).eval()
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Condition x >= 0"):
_ = deterministic_lib.Deterministic(loc, atol=-1, validate_args=True)
# Error detected statically; no need for _.prob(0.).eval()
def testProbWithNoBatchDimsIntegerType(self):
deterministic = deterministic_lib.Deterministic(0)
@ -195,16 +195,16 @@ class VectorDeterministicTest(test.TestCase):
def testInvalidTolRaises(self):
loc = rng.rand(2, 3, 4).astype(np.float32)
deterministic = deterministic_lib.VectorDeterministic(
loc, atol=-1, validate_args=True)
with self.cached_session():
with self.assertRaisesOpError("Condition x >= 0"):
deterministic.prob(loc).eval()
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Condition x >= 0"):
_ = deterministic_lib.VectorDeterministic(
loc, atol=-1, validate_args=True)
# Error detected statically; no need for _.prob(loc).eval()
def testInvalidXRaises(self):
loc = rng.rand(2, 3, 4).astype(np.float32)
deterministic = deterministic_lib.VectorDeterministic(
loc, atol=-1, validate_args=True)
loc, atol=None, validate_args=True)
with self.cached_session():
with self.assertRaisesRegexp(ValueError, "must have rank at least 1"):
deterministic.prob(0.).eval()

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops import half_normal as hn_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@ -41,6 +42,7 @@ def try_import(name): # pylint: disable=invalid-name
tf_logging.warning("Could not import %s: %s" % (name, str(e)))
return module
stats = try_import("scipy.stats")
@ -288,9 +290,10 @@ class HalfNormalTest(test.TestCase):
def testNegativeSigmaFails(self):
with self.cached_session():
halfnorm = hn_lib.HalfNormal(scale=[-5.], validate_args=True, name="G")
with self.assertRaisesOpError("Condition x > 0 did not hold"):
halfnorm.mean().eval()
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
"Condition x > 0 did not hold"):
_ = hn_lib.HalfNormal(scale=[-5.], validate_args=True, name="G")
# Error detected statically; no need for _.mean().eval()
def testHalfNormalShape(self):
with self.cached_session():

View File

@ -22,6 +22,7 @@ from scipy import stats
from tensorflow.contrib.distributions.python.ops import inverse_gamma
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
@ -249,7 +250,8 @@ class InverseGammaTest(test.TestCase):
fails += 0 if self._kstest(a, b, s) else 1
self.assertLess(fails, trials * 0.03)
def _kstest(self, alpha, beta, samples):
@staticmethod
def _kstest(alpha, beta, samples):
# Uses the Kolmogorov-Smirnov test for goodness of fit.
ks, _ = stats.kstest(samples, stats.invgamma(alpha, scale=beta).cdf)
# Return True when the test passes.
@ -295,16 +297,18 @@ class InverseGammaTest(test.TestCase):
with self.cached_session():
alpha_v = constant_op.constant(0.0, name="alpha")
beta_v = constant_op.constant(1.0, name="beta")
inv_gamma = inverse_gamma.InverseGamma(
concentration=alpha_v, rate=beta_v, validate_args=True)
with self.assertRaisesOpError("alpha"):
inv_gamma.mean().eval()
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
"alpha"):
_ = inverse_gamma.InverseGamma(
concentration=alpha_v, rate=beta_v, validate_args=True)
# Error detected statically; no need for _.mean().eval()
alpha_v = constant_op.constant(1.0, name="alpha")
beta_v = constant_op.constant(0.0, name="beta")
inv_gamma = inverse_gamma.InverseGamma(
concentration=alpha_v, rate=beta_v, validate_args=True)
with self.assertRaisesOpError("beta"):
inv_gamma.mean().eval()
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
"beta"):
_ = inverse_gamma.InverseGamma(
concentration=alpha_v, rate=beta_v, validate_args=True)
# Error detected statically; no need for _.mean().eval()
def testInverseGammaWithSoftplusConcentrationRate(self):
with self.cached_session():

View File

@ -21,6 +21,7 @@ import numpy as np
from scipy import stats
from tensorflow.contrib import distributions as distributions_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
@ -361,15 +362,14 @@ class QuantizedDistributionTest(test.TestCase):
def testLowerCutoffMustBeBelowUpperCutoffOrWeRaise(self):
with self.cached_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(loc=0., scale=1.),
low=1., # not strictly less than high.
high=1.,
validate_args=True)
self.assertTrue(qdist.validate_args) # Default is True.
with self.assertRaisesOpError("must be strictly less"):
qdist.sample().eval()
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
"must be strictly less"):
_ = distributions.QuantizedDistribution(
distribution=distributions.Normal(loc=0., scale=1.),
low=1., # not strictly less than high.
high=1.,
validate_args=True)
# Error detected statically; no need for _.sample().eval()
def testCutoffsMustBeIntegerValuedIfValidateArgsTrue(self):
with self.cached_session():

View File

@ -94,12 +94,11 @@ class RelaxedBernoulliTest(test.TestCase):
"""If validate_args, raises InvalidArgumentError when temperature is 0."""
temperature = constant_op.constant(0.0)
p = constant_op.constant([0.1, 0.4])
dist = relaxed_bernoulli.RelaxedBernoulli(temperature, probs=p,
validate_args=True)
with self.cached_session():
sample = dist.sample()
with self.assertRaises(errors_impl.InvalidArgumentError):
sample.eval()
with self.assertRaisesWithPredicateMatch(errors_impl.InvalidArgumentError,
"x > 0 did not hold"):
_ = relaxed_bernoulli.RelaxedBernoulli(
temperature, probs=p, validate_args=True)
# Error detected statically; no need to run the op.
def testDtype(self):
temperature = constant_op.constant(1.0, dtype=dtypes.float32)

View File

@ -1735,9 +1735,10 @@ class StreamingAUCTest(test.TestCase):
predictions = constant_op.constant(
[1, -1, 1, -1], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
_, update_op = metrics.streaming_auc(predictions, labels)
sess.run(variables.local_variables_initializer())
self.assertRaises(errors_impl.InvalidArgumentError, update_op.eval)
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
r'predictions must be in \[0, 1\]'):
_, _ = metrics.streaming_auc(predictions, labels)
# Error detected statically; no need to run the op.
def testAllCorrect(self):
self.allCorrectAsExpected('ROC')

View File

@ -2195,7 +2195,7 @@ class _LazyBuilder(object):
if rank is not None:
if rank == 0:
raise ValueError(
'Feature (key: {}) cannot have rank 0. Give: {}'.format(
'Feature (key: {}) cannot have rank 0. Given: {}'.format(
key, feature_tensor))
return feature_tensor if rank != 1 else expand_dims(feature_tensor)
@ -2880,10 +2880,18 @@ class _IdentityCategoricalColumn(
if self.default_value is None:
# Fail if values are out-of-range.
assert_less = check_ops.assert_less(
values, num_buckets, data=(values, num_buckets),
values,
num_buckets,
data=(values, num_buckets),
message='Bucket index for categorical column '
'"{}" exceeds number of buckets'.format(self.name),
name='assert_less_than_num_buckets')
assert_greater = check_ops.assert_greater_equal(
values, zero, data=(values,),
values,
zero,
data=(values,),
message='Negative bucket index for categorical column "{}"'.format(
self.name),
name='assert_greater_or_equal_0')
with ops.control_dependencies((assert_less, assert_greater)):
values = array_ops.identity(values)

View File

@ -4391,11 +4391,10 @@ class IdentityCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=(1, -1, 0),
dense_shape=(2, 2))
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
with self.assertRaisesRegexp(errors.OpError, 'assert'):
id_weight_pair.id_tensor.eval()
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
'Negative bucket index for categorical column "aaa"'):
column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
@test_util.run_deprecated_v1
def test_get_sparse_tensors_with_inputs_too_big(self):
@ -4404,11 +4403,10 @@ class IdentityCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=(1, 99, 0),
dense_shape=(2, 2))
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
with self.assertRaisesRegexp(errors.OpError, 'assert'):
id_weight_pair.id_tensor.eval()
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
'Bucket index for categorical column "aaa" exceeds number of buckets'):
column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
@test_util.run_deprecated_v1
def test_get_sparse_tensors_with_default_value(self):

View File

@ -2631,7 +2631,7 @@ class FeatureTransformationCache(object):
if rank is not None:
if rank == 0:
raise ValueError(
'Feature (key: {}) cannot have rank 0. Give: {}'.format(
'Feature (key: {}) cannot have rank 0. Given: {}'.format(
key, feature_tensor))
return feature_tensor if rank != 1 else expand_dims(feature_tensor)
@ -3820,10 +3820,18 @@ class IdentityCategoricalColumn(
if self.default_value is None:
# Fail if values are out-of-range.
assert_less = check_ops.assert_less(
values, num_buckets, data=(values, num_buckets),
values,
num_buckets,
data=(values, num_buckets),
message='Bucket index for categorical column '
'"{}" exceeds number of buckets'.format(self.name),
name='assert_less_than_num_buckets')
assert_greater = check_ops.assert_greater_equal(
values, zero, data=(values,),
values,
zero,
data=(values,),
message='Negative bucket index for categorical column "{}"'.format(
self.name),
name='assert_greater_or_equal_0')
with ops.control_dependencies((assert_less, assert_greater)):
values = array_ops.identity(values)

View File

@ -5009,17 +5009,13 @@ class IdentityCategoricalColumnTest(test.TestCase):
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
inputs = sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)), values=(1, -1, 0), dense_shape=(2, 2))
id_weight_pair = column.get_sparse_tensors(
fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
self.evaluate(variables_lib.global_variables_initializer())
self.evaluate(lookup_ops.tables_initializer())
with self.assertRaisesRegexp(errors.OpError, 'assert'):
self.evaluate(id_weight_pair.id_tensor)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
'Negative bucket index for categorical column "aaa"'):
column.get_sparse_tensors(
fc.FeatureTransformationCache({
'aaa': inputs
}), None)
@test_util.run_deprecated_v1
def test_get_sparse_tensors_with_inputs_too_small(self):
@ -5034,17 +5030,13 @@ class IdentityCategoricalColumnTest(test.TestCase):
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
inputs = sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)), values=(1, 99, 0), dense_shape=(2, 2))
id_weight_pair = column.get_sparse_tensors(
fc.FeatureTransformationCache({
'aaa': inputs
}), None)
self.assertIsNone(id_weight_pair.weight_tensor)
self.evaluate(variables_lib.global_variables_initializer())
self.evaluate(lookup_ops.tables_initializer())
with self.assertRaisesRegexp(errors.OpError, 'assert'):
self.evaluate(id_weight_pair.id_tensor)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
'Bucket index for categorical column "aaa" exceeds number of buckets'):
column.get_sparse_tensors(
fc.FeatureTransformationCache({
'aaa': inputs
}), None)
@test_util.run_deprecated_v1
def test_get_sparse_tensors_with_inputs_too_big(self):

View File

@ -206,8 +206,7 @@ Corresponding y values:
First 6 elements of x:
\[2 2 3 3 6 6\]
First 6 elements of y:
\[20 2 3 30 60 6\]
"""
\[20 2 3 30 60 6\]"""
expected_error_msg_default = r"""big does not equal small
Condition x == y did not hold.
Indices of first 3 different values:
@ -221,8 +220,7 @@ Corresponding y values:
First 3 elements of x:
\[2 2 3\]
First 3 elements of y:
\[20 2 3\]
"""
\[20 2 3\]"""
expected_error_msg_short = r"""big does not equal small
Condition x == y did not hold.
Indices of first 2 different values:
@ -235,8 +233,7 @@ Corresponding y values:
First 2 elements of x:
\[2 2\]
First 2 elements of y:
\[20 2\]
"""
\[20 2\]"""
with context.eager_mode():
big = constant_op.constant([[2, 2], [3, 3], [6, 6]])
small = constant_op.constant([[20, 2], [3, 30], [60, 6]])
@ -380,27 +377,38 @@ class AssertNoneEqualTest(test.TestCase):
x = check_ops.assert_none_equal(t1, t2)
assert x is None
def test_static_check_in_graph_mode(self):
with ops.Graph().as_default():
with self.assertRaisesRegexp( # pylint:disable=g-error-prone-assert-raises
errors.InvalidArgumentError,
"Custom error message"):
check_ops.assert_none_equal(1, 1, message="Custom error message")
def test_error_message_eager(self):
# Note that the following three strings are regexes
expected_error_msg_full = r"""0.0, 1.0, 2.0, 3.0, 4.0, 5.0"""
expected_error_msg_default = r"""0.0, 1.0, 2.0, \.\.\."""
expected_error_msg_short = r"""0.0, 1.0, \.\.\."""
expected_error_msg_full = r"""\[ *0\. +1\. +2\. +3\. +4\. +5\.\]"""
expected_error_msg_default = r"""\[ *0\. +1\. +2\.\]"""
expected_error_msg_short = r"""\[ *0\. +1\.\]"""
with context.eager_mode():
t = constant_op.constant(
np.array(range(6)), shape=[2, 3], dtype=np.float32)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
expected_error_msg_full):
with self.assertRaisesRegexp( # pylint:disable=g-error-prone-assert-raises
errors.InvalidArgumentError,
expected_error_msg_full):
check_ops.assert_none_equal(
t, t, message="This is the error message.", summarize=10)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
expected_error_msg_full):
with self.assertRaisesRegexp( # pylint:disable=g-error-prone-assert-raises
errors.InvalidArgumentError,
expected_error_msg_full):
check_ops.assert_none_equal(
t, t, message="This is the error message.", summarize=-1)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
expected_error_msg_default):
with self.assertRaisesRegexp( # pylint:disable=g-error-prone-assert-raises
errors.InvalidArgumentError,
expected_error_msg_default):
check_ops.assert_none_equal(t, t, message="This is the error message.")
with self.assertRaisesRegexp(errors.InvalidArgumentError,
expected_error_msg_short):
with self.assertRaisesRegexp( # pylint:disable=g-error-prone-assert-raises
errors.InvalidArgumentError,
expected_error_msg_short):
check_ops.assert_none_equal(
t, t, message="This is the error message.", summarize=2)
@ -492,7 +500,8 @@ class AssertAllCloseTest(test.TestCase):
def test_raises_when_atol_violated(self):
x = constant_op.constant(10., name="x")
y = constant_op.constant(10.2, name="y")
with self.assertRaisesOpError("x and y not equal to tolerance"):
with self.assertRaisesOpError( # pylint:disable=g-error-prone-assert-raises
"x and y not equal to tolerance"):
with ops.control_dependencies(
[check_ops.assert_near(x, y, atol=0.1,
message="failure message")]):
@ -503,7 +512,8 @@ class AssertAllCloseTest(test.TestCase):
def test_raises_when_default_rtol_violated(self):
x = constant_op.constant(0.1, name="x")
y = constant_op.constant(0.0, name="y")
with self.assertRaisesOpError("x and y not equal to tolerance"):
with self.assertRaisesOpError( # pylint:disable=g-error-prone-assert-raises
"x and y not equal to tolerance"):
with ops.control_dependencies(
[check_ops.assert_near(x, y, message="failure message")]):
out = array_ops.identity(x)
@ -523,7 +533,8 @@ class AssertLessTest(test.TestCase):
@test_util.run_deprecated_v1
def test_raises_when_equal(self):
small = constant_op.constant([1, 2], name="small")
with self.assertRaisesOpError("failure message.*\n*.* x < y did not hold"):
with self.assertRaisesOpError( # pylint:disable=g-error-prone-assert-raises
"failure message.*\n*.* x < y did not hold"):
with ops.control_dependencies(
[check_ops.assert_less(
small, small, message="failure message")]):
@ -535,7 +546,8 @@ class AssertLessTest(test.TestCase):
def test_raises_when_greater(self):
small = constant_op.constant([1, 2], name="small")
big = constant_op.constant([3, 4], name="big")
with self.assertRaisesOpError("x < y did not hold"):
with self.assertRaisesOpError( # pylint:disable=g-error-prone-assert-raises
"x < y did not hold"):
with ops.control_dependencies([check_ops.assert_less(big, small)]):
out = array_ops.identity(small)
self.evaluate(out)
@ -563,7 +575,7 @@ class AssertLessTest(test.TestCase):
# The exception in eager and non-eager mode is different because
# eager mode relies on shape check done as part of the C++ op, while
# graph mode does shape checks when creating the `Operation` instance.
with self.assertRaisesRegexp(
with self.assertRaisesRegexp( # pylint:disable=g-error-prone-assert-raises
(ValueError, errors.InvalidArgumentError),
(r"Incompatible shapes: \[3\] vs. \[2\]|"
"Dimensions must be equal, but are 3 and 2")):
@ -586,6 +598,13 @@ class AssertLessTest(test.TestCase):
x = check_ops.assert_less(t1, t2)
assert x is None
def test_static_check_in_graph_mode(self):
with ops.Graph().as_default():
with self.assertRaisesRegexp( # pylint:disable=g-error-prone-assert-raises
errors.InvalidArgumentError,
"Custom error message"):
check_ops.assert_less(1, 1, message="Custom error message")
class AssertLessEqualTest(test.TestCase):
@ -602,7 +621,8 @@ class AssertLessEqualTest(test.TestCase):
def test_raises_when_greater(self):
small = constant_op.constant([1, 2], name="small")
big = constant_op.constant([3, 4], name="big")
with self.assertRaisesOpError("fail"):
with self.assertRaisesOpError( # pylint:disable=g-error-prone-assert-raises
"fail"):
with ops.control_dependencies(
[check_ops.assert_less_equal(
big, small, message="fail")]):
@ -632,7 +652,7 @@ class AssertLessEqualTest(test.TestCase):
# The exception in eager and non-eager mode is different because
# eager mode relies on shape check done as part of the C++ op, while
# graph mode does shape checks when creating the `Operation` instance.
with self.assertRaisesRegexp(
with self.assertRaisesRegexp( # pylint:disable=g-error-prone-assert-raises
(errors.InvalidArgumentError, ValueError),
(r"Incompatible shapes: \[2\] vs. \[3\]|"
r"Dimensions must be equal, but are 2 and 3")):
@ -650,6 +670,13 @@ class AssertLessEqualTest(test.TestCase):
out = array_ops.identity(larry)
self.evaluate(out)
def test_static_check_in_graph_mode(self):
with ops.Graph().as_default():
with self.assertRaisesRegexp( # pylint:disable=g-error-prone-assert-raises
errors.InvalidArgumentError,
"Custom error message"):
check_ops.assert_less_equal(1, 0, message="Custom error message")
class AssertGreaterTest(test.TestCase):
@ -657,7 +684,8 @@ class AssertGreaterTest(test.TestCase):
@test_util.run_deprecated_v1
def test_raises_when_equal(self):
small = constant_op.constant([1, 2], name="small")
with self.assertRaisesOpError("fail"):
with self.assertRaisesOpError( # pylint:disable=g-error-prone-assert-raises
"fail"):
with ops.control_dependencies(
[check_ops.assert_greater(
small, small, message="fail")]):
@ -669,7 +697,8 @@ class AssertGreaterTest(test.TestCase):
def test_raises_when_less(self):
small = constant_op.constant([1, 2], name="small")
big = constant_op.constant([3, 4], name="big")
with self.assertRaisesOpError("x > y did not hold"):
with self.assertRaisesOpError( # pylint:disable=g-error-prone-assert-raises
"x > y did not hold"):
with ops.control_dependencies([check_ops.assert_greater(small, big)]):
out = array_ops.identity(big)
self.evaluate(out)
@ -697,7 +726,7 @@ class AssertGreaterTest(test.TestCase):
# The exception in eager and non-eager mode is different because
# eager mode relies on shape check done as part of the C++ op, while
# graph mode does shape checks when creating the `Operation` instance.
with self.assertRaisesRegexp(
with self.assertRaisesRegexp( # pylint:disable=g-error-prone-assert-raises
(errors.InvalidArgumentError, ValueError),
(r"Incompatible shapes: \[2\] vs. \[3\]|"
r"Dimensions must be equal, but are 2 and 3")):
@ -713,6 +742,13 @@ class AssertGreaterTest(test.TestCase):
out = array_ops.identity(larry)
self.evaluate(out)
def test_static_check_in_graph_mode(self):
with ops.Graph().as_default():
with self.assertRaisesRegexp( # pylint:disable=g-error-prone-assert-raises
errors.InvalidArgumentError,
"Custom error message"):
check_ops.assert_greater(0, 1, message="Custom error message")
class AssertGreaterEqualTest(test.TestCase):
@ -729,7 +765,8 @@ class AssertGreaterEqualTest(test.TestCase):
def test_raises_when_less(self):
small = constant_op.constant([1, 2], name="small")
big = constant_op.constant([3, 4], name="big")
with self.assertRaisesOpError("fail"):
with self.assertRaisesOpError( # pylint:disable=g-error-prone-assert-raises
"fail"):
with ops.control_dependencies(
[check_ops.assert_greater_equal(
small, big, message="fail")]):
@ -761,7 +798,7 @@ class AssertGreaterEqualTest(test.TestCase):
# The exception in eager and non-eager mode is different because
# eager mode relies on shape check done as part of the C++ op, while
# graph mode does shape checks when creating the `Operation` instance.
with self.assertRaisesRegexp(
with self.assertRaisesRegexp( # pylint:disable=g-error-prone-assert-raises
(errors.InvalidArgumentError, ValueError),
(r"Incompatible shapes: \[2\] vs. \[3\]|"
r"Dimensions must be equal, but are 2 and 3")):
@ -779,6 +816,13 @@ class AssertGreaterEqualTest(test.TestCase):
out = array_ops.identity(larry)
self.evaluate(out)
def test_static_check_in_graph_mode(self):
with ops.Graph().as_default():
with self.assertRaisesRegexp( # pylint:disable=g-error-prone-assert-raises
errors.InvalidArgumentError,
"Custom error message"):
check_ops.assert_greater_equal(0, 1, message="Custom error message")
class AssertNegativeTest(test.TestCase):
@ -793,7 +837,8 @@ class AssertNegativeTest(test.TestCase):
@test_util.run_deprecated_v1
def test_raises_when_positive(self):
doug = constant_op.constant([1, 2], name="doug")
with self.assertRaisesOpError("fail"):
with self.assertRaisesOpError( # pylint:disable=g-error-prone-assert-raises
"fail"):
with ops.control_dependencies(
[check_ops.assert_negative(
doug, message="fail")]):
@ -804,7 +849,8 @@ class AssertNegativeTest(test.TestCase):
@test_util.run_deprecated_v1
def test_raises_when_zero(self):
claire = constant_op.constant([0], name="claire")
with self.assertRaisesOpError("x < 0 did not hold"):
with self.assertRaisesOpError( # pylint:disable=g-error-prone-assert-raises
"x < 0 did not hold"):
with ops.control_dependencies([check_ops.assert_negative(claire)]):
out = array_ops.identity(claire)
self.evaluate(out)
@ -820,7 +866,14 @@ class AssertNegativeTest(test.TestCase):
out = array_ops.identity(empty)
self.evaluate(out)
def test_static_check_in_graph_mode(self):
with ops.Graph().as_default():
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Custom error message"):
check_ops.assert_negative(1, message="Custom error message")
# pylint:disable=g-error-prone-assert-raises
class AssertPositiveTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
@ -861,6 +914,12 @@ class AssertPositiveTest(test.TestCase):
out = array_ops.identity(empty)
self.evaluate(out)
def test_static_check_in_graph_mode(self):
with ops.Graph().as_default():
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Custom error message"):
check_ops.assert_positive(-1, message="Custom error message")
class EnsureShapeTest(test.TestCase):
@ -1402,6 +1461,12 @@ class AssertNonNegativeTest(test.TestCase):
out = array_ops.identity(empty)
self.evaluate(out)
def test_static_check_in_graph_mode(self):
with ops.Graph().as_default():
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Custom error message"):
check_ops.assert_non_negative(-1, message="Custom error message")
class AssertNonPositiveTest(test.TestCase):
@ -1432,6 +1497,12 @@ class AssertNonPositiveTest(test.TestCase):
out = array_ops.identity(empty)
self.evaluate(out)
def test_static_check_in_graph_mode(self):
with ops.Graph().as_default():
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Custom error message"):
check_ops.assert_non_positive(1, message="Custom error message")
class AssertIntegerTest(test.TestCase):

View File

@ -90,6 +90,287 @@ def _shape_and_dtype_str(tensor):
return 'shape=%s dtype=%s' % (tensor.shape, tensor.dtype.name)
def _unary_assert_doc(sym, sym_name):
"""Common docstring for assert_* ops that evaluate a unary predicate over every element of a tensor.
Args:
sym: Mathematical symbol for the check performed on each element, i.e. "> 0"
sym_name: English-language name for the op described by sym
Returns:
Decorator that adds the appropriate docstring to the function for symbol
`sym`.
"""
def _decorator(func):
"""Generated decorator that adds the appropriate docstring to the function for symbol `sym`.
Args:
func: Function for a TensorFlow op
Returns:
Version of `func` with documentation attached.
"""
opname = func.__name__
cap_sym_name = sym_name.capitalize()
func.__doc__ = """
Assert the condition `x {sym}` holds element-wise.
When running in graph mode, you should add a dependency on this operation
to ensure that it runs. Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.debugging.{opname}(x, y)]):
output = tf.reduce_sum(x)
```
{sym_name} means, for every element `x[i]` of `x`, we have `x[i] {sym}`.
If `x` is empty this is trivially satisfied.
Args:
x: Numeric `Tensor`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional). Defaults to "{opname}".
Returns:
Op that raises `InvalidArgumentError` if `x {sym}` is False.
@compatibility(eager)
returns None
@end_compatibility
Raises:
InvalidArgumentError: if the check can be performed immediately and
`x {sym}` is False. The check can be performed immediately during
eager execution or if `x` is statically known.
""".format(
sym=sym, sym_name=cap_sym_name, opname=opname)
return func
return _decorator
def _binary_assert_doc(sym):
"""Common docstring for most of the v1 assert_* ops that compare two tensors element-wise.
Args:
sym: Binary operation symbol, i.e. "=="
Returns:
Decorator that adds the appropriate docstring to the function for
symbol `sym`.
"""
def _decorator(func):
"""Generated decorator that adds the appropriate docstring to the function for symbol `sym`.
Args:
func: Function for a TensorFlow op
Returns:
A version of `func` with documentation attached.
"""
opname = func.__name__
func.__doc__ = """
Assert the condition `x {sym} y` holds element-wise.
This condition holds if for every pair of (possibly broadcast) elements
`x[i]`, `y[i]`, we have `x[i] {sym} y[i]`.
If both `x` and `y` are empty, this is trivially satisfied.
When running in graph mode, you should add a dependency on this operation
to ensure that it runs. Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.compat.v1.{opname}(x, y)]):
output = tf.reduce_sum(x)
```
Args:
x: Numeric `Tensor`.
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`, `y`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional). Defaults to "{opname}".
Returns:
Op that raises `InvalidArgumentError` if `x {sym} y` is False.
@compatibility(eager)
returns None
@end_compatibility
Raises:
InvalidArgumentError: if the check can be performed immediately and
`x {sym} y` is False. The check can be performed immediately during
eager execution or if `x` and `y` are statically known.
""".format(
sym=sym, opname=opname)
return func
return _decorator
def _make_assert_msg_data(sym, x, y, summarize, test_op):
"""Subroutine of _binary_assert that generates the components of the default error message when running in eager mode.
Args:
sym: Mathematical symbol for the test to apply to pairs of tensor elements,
i.e. "=="
x: First input to the assertion after applying `convert_to_tensor()`
y: Second input to the assertion
summarize: Value of the "summarize" parameter to the original assert_* call;
tells how many elements of each tensor to print.
test_op: TensorFlow op that returns a Boolean tensor with True in each
position where the assertion is satisfied.
Returns:
List of tensors and scalars that, when stringified and concatenated,
will produce the error message string.
"""
# Prepare a message with first elements of x and y.
data = []
data.append('Condition x %s y did not hold.' % sym)
if summarize > 0:
if x.shape == y.shape and x.shape.as_list():
# If the shapes of x and y are the same (and not scalars),
# Get the values that actually differed and their indices.
# If shapes are different this information is more confusing
# than useful.
mask = math_ops.logical_not(test_op)
indices = array_ops.where(mask)
indices_np = indices.numpy()
x_vals = array_ops.boolean_mask(x, mask)
y_vals = array_ops.boolean_mask(y, mask)
num_vals = min(summarize, indices_np.shape[0])
data.append('Indices of first %d different values:' % num_vals)
data.append(indices_np[:num_vals])
data.append('Corresponding x values:')
data.append(x_vals.numpy().reshape((-1,))[:num_vals])
data.append('Corresponding y values:')
data.append(y_vals.numpy().reshape((-1,))[:num_vals])
# reshape((-1,)) is the fastest way to get a flat array view.
x_np = x.numpy().reshape((-1,))
y_np = y.numpy().reshape((-1,))
x_sum = min(x_np.size, summarize)
y_sum = min(y_np.size, summarize)
data.append('First %d elements of x:' % x_sum)
data.append(x_np[:x_sum])
data.append('First %d elements of y:' % y_sum)
data.append(y_np[:y_sum])
return data
def _pretty_print(data_item, summarize):
"""Format a data item for use in an error message in eager mode.
Args:
data_item: One of the items in the "data" argument to an assert_* function.
Can be a Tensor or a scalar value.
summarize: How many elements to retain of each tensor-valued entry in data.
Returns:
An appropriate string representation of data_item
"""
if isinstance(data_item, ops.Tensor):
arr = data_item.numpy()
if np.isscalar(arr):
# Tensor.numpy() returns a scalar for zero-dimensional tensors
return str(arr)
else:
flat = arr.reshape((-1,))
lst = [str(x) for x in flat[:summarize]]
if len(lst) < flat.size:
lst.append('...')
return str(lst)
else:
return str(data_item)
def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
message, name):
"""Generic binary elementwise assertion.
Implements the behavior described in _binary_assert_doc() above.
Args:
sym: Mathematical symbol for the test to apply to pairs of tensor elements,
i.e. "=="
opname: Name of the assert op in the public API, i.e. "assert_equal"
op_func: Function that, if passed the two Tensor inputs to the assertion (x
and y), will return the test to be passed to reduce_all() i.e.
static_func: Function that, if passed numpy ndarray versions of the two
inputs to the assertion, will return a Boolean ndarray with containing
True in all positions where the assertion PASSES.
i.e. lambda x,y: (x == y) for assert_equal()
x: Numeric `Tensor`.
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`, `y`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional). Defaults to the value of
`opname`.
Returns:
See docstring template in _binary_assert_doc().
"""
with ops.name_scope(name, opname, [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
if context.executing_eagerly():
test_op = op_func(x, y)
condition = math_ops.reduce_all(test_op)
if condition:
return
# If we get here, the assertion has failed.
# Default to printing 3 elements like control_flow_ops.Assert (used
# by graph mode) does. Also treat negative values as "print
# everything" for consistency with Tensor::SummarizeValue().
if summarize is None:
summarize = 3
elif summarize < 0:
summarize = 1e9 # Code below will find exact size of x and y.
if data is None:
data = _make_assert_msg_data(sym, x, y, summarize, test_op)
if message is not None:
data = [message] + list(data)
raise errors.InvalidArgumentError(
node_def=None,
op=None,
message=('\n'.join([_pretty_print(d, summarize) for d in data])))
else: # not context.executing_eagerly()
if data is None:
data = [
'Condition x %s y did not hold element-wise:' % sym,
'x (%s) = ' % x.name, x,
'y (%s) = ' % y.name, y
]
if message is not None:
data = [message] + list(data)
condition = math_ops.reduce_all(op_func(x, y))
x_static = tensor_util.constant_value(x)
y_static = tensor_util.constant_value(y)
if x_static is not None and y_static is not None:
condition_static = static_func(x_static, y_static).all()
_assert_static(condition_static, data)
return control_flow_ops.Assert(condition, data, summarize=summarize)
@tf_export(
'debugging.assert_proper_iterable',
v1=['debugging.assert_proper_iterable', 'assert_proper_iterable'])
@ -155,30 +436,8 @@ def assert_negative_v2(x, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_negative', 'assert_negative'])
@deprecation.deprecated_endpoints('assert_negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x < 0` holds element-wise.
Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.compat.v1.assert_negative(x)]):
output = tf.reduce_sum(x)
```
Negative means, for every element `x[i]` of `x`, we have `x[i] < 0`.
If `x` is empty this is trivially satisfied.
Args:
x: Numeric `Tensor`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional). Defaults to "assert_negative".
Returns:
Op raising `InvalidArgumentError` unless `x` is all negative.
"""
@_unary_assert_doc('< 0', 'negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
message = message or ''
with ops.name_scope(name, 'assert_negative', [x, data]):
x = ops.convert_to_tensor(x, name='x')
@ -229,30 +488,8 @@ def assert_positive_v2(x, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_positive', 'assert_positive'])
@deprecation.deprecated_endpoints('assert_positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x > 0` holds element-wise.
Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.compat.v1.assert_positive(x)]):
output = tf.reduce_sum(x)
```
Positive means, for every element `x[i]` of `x`, we have `x[i] > 0`.
If `x` is empty this is trivially satisfied.
Args:
x: Numeric `Tensor`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional). Defaults to "assert_positive".
Returns:
Op raising `InvalidArgumentError` unless `x` is all positive.
"""
@_unary_assert_doc('> 0', 'positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
message = message or ''
with ops.name_scope(name, 'assert_positive', [x, data]):
x = ops.convert_to_tensor(x, name='x')
@ -304,31 +541,8 @@ def assert_non_negative_v2(x, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative'])
@deprecation.deprecated_endpoints('assert_non_negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x >= 0` holds element-wise.
Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.compat.v1.assert_non_negative(x)]):
output = tf.reduce_sum(x)
```
Non-negative means, for every element `x[i]` of `x`, we have `x[i] >= 0`.
If `x` is empty this is trivially satisfied.
Args:
x: Numeric `Tensor`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional).
Defaults to "assert_non_negative".
Returns:
Op raising `InvalidArgumentError` unless `x` is all non-negative.
"""
@_unary_assert_doc('>= 0', 'non-negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
message = message or ''
with ops.name_scope(name, 'assert_non_negative', [x, data]):
x = ops.convert_to_tensor(x, name='x')
@ -381,31 +595,8 @@ def assert_non_positive_v2(x, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive'])
@deprecation.deprecated_endpoints('assert_non_positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x <= 0` holds element-wise.
Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.compat.v1.assert_non_positive(x)]):
output = tf.reduce_sum(x)
```
Non-positive means, for every element `x[i]` of `x`, we have `x[i] <= 0`.
If `x` is empty this is trivially satisfied.
Args:
x: Numeric `Tensor`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional).
Defaults to "assert_non_positive".
Returns:
Op raising `InvalidArgumentError` unless `x` is all non-positive.
"""
@_unary_assert_doc('<= 0', 'non-positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
message = message or ''
with ops.name_scope(name, 'assert_non_positive', [x, data]):
x = ops.convert_to_tensor(x, name='x')
@ -457,109 +648,15 @@ def assert_equal_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_equal', 'assert_equal'])
def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x == y` holds element-wise.
Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.compat.v1.assert_equal(x, y)]):
output = tf.reduce_sum(x)
```
This condition holds if for every pair of (possibly broadcast) elements
`x[i]`, `y[i]`, we have `x[i] == y[i]`.
If both `x` and `y` are empty, this is trivially satisfied.
Args:
x: Numeric `Tensor`.
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`, `y`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional). Defaults to "assert_equal".
Returns:
Op that raises `InvalidArgumentError` if `x == y` is False.
@compatibility(eager)
returns None
@end_compatibility
Raises:
InvalidArgumentError: if the check can be performed immediately and
`x == y` is False. The check can be performed immediately during eager
execution or if `x` and `y` are statically known.
"""
message = message or ''
@_binary_assert_doc('==')
def assert_equal(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
with ops.name_scope(name, 'assert_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
# Short-circuit if x and y are the same tensor.
if x is y:
return None if context.executing_eagerly() else control_flow_ops.no_op()
if context.executing_eagerly():
eq = math_ops.equal(x, y)
condition = math_ops.reduce_all(eq)
if not condition:
# Prepare a message with first elements of x and y.
summary_msg = ''
# Default to printing 3 elements like control_flow_ops.Assert (used
# by graph mode) does.
summarize = 3 if summarize is None else summarize
if summarize:
# reshape((-1,)) is the fastest way to get a flat array view.
x_np = x.numpy().reshape((-1,))
y_np = y.numpy().reshape((-1,))
x_sum = min(x_np.size, summarize)
y_sum = min(y_np.size, summarize)
summary_msg = ('First %d elements of x:\n%s\n'
'First %d elements of y:\n%s\n' %
(x_sum, x_np[:x_sum],
y_sum, y_np[:y_sum]))
index_and_values_str = ''
if x.shape == y.shape and x.shape.as_list():
# If the shapes of x and y are the same (and not scalars),
# Get the values that actually differed and their indices.
# If shapes are different this information is more confusing
# than useful.
mask = math_ops.logical_not(eq)
indices = array_ops.where(mask)
indices_np = indices.numpy()
x_vals = array_ops.boolean_mask(x, mask)
y_vals = array_ops.boolean_mask(y, mask)
summarize = min(summarize, indices_np.shape[0])
index_and_values_str = (
'Indices of first %s different values:\n%s\n'
'Corresponding x values:\n%s\n'
'Corresponding y values:\n%s\n' %
(summarize, indices_np[:summarize],
x_vals.numpy().reshape((-1,))[:summarize],
y_vals.numpy().reshape((-1,))[:summarize]))
raise errors.InvalidArgumentError(
node_def=None, op=None,
message=('%s\nCondition x == y did not hold.\n%s%s' %
(message or '', index_and_values_str, summary_msg)))
return
if data is None:
data = [
message,
'Condition x == y did not hold element-wise:',
'x (%s) = ' % x.name, x,
'y (%s) = ' % y.name, y
]
condition = math_ops.reduce_all(math_ops.equal(x, y))
x_static = tensor_util.constant_value(x)
y_static = tensor_util.constant_value(y)
if x_static is not None and y_static is not None:
condition_static = (x_static == y_static).all()
_assert_static(condition_static, data)
return control_flow_ops.Assert(condition, data, summarize=summarize)
return _binary_assert('==', 'assert_equal', math_ops.equal,
lambda x, y: (x == y),
x, y, data, summarize, message, name)
@tf_export('debugging.assert_none_equal', v1=[])
@ -602,54 +699,12 @@ def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal'])
@deprecation.deprecated_endpoints('assert_none_equal')
@_binary_assert_doc('!=')
def assert_none_equal(
x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x != y` holds for all elements.
Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.compat.v1.assert_none_equal(x, y)]):
output = tf.reduce_sum(x)
```
This condition holds if for every pair of (possibly broadcast) elements
`x[i]`, `y[i]`, we have `x[i] != y[i]`.
If both `x` and `y` are empty, this is trivially satisfied.
Args:
x: Numeric `Tensor`.
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`, `y`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional).
Defaults to "assert_none_equal".
Returns:
Op that raises `InvalidArgumentError` if `x != y` is ever False.
"""
message = message or ''
with ops.name_scope(name, 'assert_none_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
x_name = x.name
y_name = y.name
if data is None:
data = [
message,
'Condition x != y did not hold for every single element:',
'x (%s) = ' % x_name, x,
'y (%s) = ' % y_name, y
]
condition = math_ops.reduce_all(math_ops.not_equal(x, y))
return control_flow_ops.Assert(condition, data, summarize=summarize)
return _binary_assert('!=', 'assert_none_equal', math_ops.not_equal,
lambda x, y: (x != y), x, y, data, summarize, message,
name)
@tf_export('debugging.assert_near', v1=[])
@ -820,51 +875,10 @@ def assert_less_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_less', 'assert_less'])
@_binary_assert_doc('<')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x < y` holds element-wise.
Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.compat.v1.assert_less(x, y)]):
output = tf.reduce_sum(x)
```
This condition holds if for every pair of (possibly broadcast) elements
`x[i]`, `y[i]`, we have `x[i] < y[i]`.
If both `x` and `y` are empty, this is trivially satisfied.
Args:
x: Numeric `Tensor`.
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`, `y`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional). Defaults to "assert_less".
Returns:
Op that raises `InvalidArgumentError` if `x < y` is False.
"""
message = message or ''
with ops.name_scope(name, 'assert_less', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
x_name = x.name
y_name = y.name
if data is None:
data = [
message,
'Condition x < y did not hold element-wise:',
'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
]
condition = math_ops.reduce_all(math_ops.less(x, y))
return control_flow_ops.Assert(condition, data, summarize=summarize)
return _binary_assert('<', 'assert_less', math_ops.less, lambda x, y: (x < y),
x, y, data, summarize, message, name)
@tf_export('debugging.assert_less_equal', v1=[])
@ -905,51 +919,11 @@ def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal'])
@deprecation.deprecated_endpoints('assert_less_equal')
@_binary_assert_doc('<=')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x <= y` holds element-wise.
Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.compat.v1.assert_less_equal(x, y)]):
output = tf.reduce_sum(x)
```
This condition holds if for every pair of (possibly broadcast) elements
`x[i]`, `y[i]`, we have `x[i] <= y[i]`.
If both `x` and `y` are empty, this is trivially satisfied.
Args:
x: Numeric `Tensor`.
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`, `y`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional). Defaults to "assert_less_equal"
Returns:
Op that raises `InvalidArgumentError` if `x <= y` is False.
"""
message = message or ''
with ops.name_scope(name, 'assert_less_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
x_name = x.name
y_name = y.name
if data is None:
data = [
message,
'Condition x <= y did not hold element-wise:'
'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
]
condition = math_ops.reduce_all(math_ops.less_equal(x, y))
return control_flow_ops.Assert(condition, data, summarize=summarize)
return _binary_assert('<=', 'assert_less_equal', math_ops.less_equal,
lambda x, y: (x <= y), x, y, data, summarize, message,
name)
@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
@ -989,51 +963,11 @@ def assert_greater_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
def assert_greater(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x > y` holds element-wise.
Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.compat.v1.assert_greater(x, y)]):
output = tf.reduce_sum(x)
```
This condition holds if for every pair of (possibly broadcast) elements
`x[i]`, `y[i]`, we have `x[i] > y[i]`.
If both `x` and `y` are empty, this is trivially satisfied.
Args:
x: Numeric `Tensor`.
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`, `y`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional). Defaults to "assert_greater".
Returns:
Op that raises `InvalidArgumentError` if `x > y` is False.
"""
message = message or ''
with ops.name_scope(name, 'assert_greater', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
x_name = x.name
y_name = y.name
if data is None:
data = [
message,
'Condition x > y did not hold element-wise:'
'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
]
condition = math_ops.reduce_all(math_ops.greater(x, y))
return control_flow_ops.Assert(condition, data, summarize=summarize)
@_binary_assert_doc('>')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
return _binary_assert('>', 'assert_greater', math_ops.greater,
lambda x, y: (x > y),
x, y, data, summarize, message, name)
@tf_export('debugging.assert_greater_equal', v1=[])
@ -1075,53 +1009,12 @@ def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal'])
@deprecation.deprecated_endpoints('assert_greater_equal')
@_binary_assert_doc('>=')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
name=None):
"""Assert the condition `x >= y` holds element-wise.
Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.compat.v1.assert_greater_equal(x, y)]):
output = tf.reduce_sum(x)
```
This condition holds if for every pair of (possibly broadcast) elements
`x[i]`, `y[i]`, we have `x[i] >= y[i]`.
If both `x` and `y` are empty, this is trivially satisfied.
Args:
x: Numeric `Tensor`.
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`, `y`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional). Defaults to
"assert_greater_equal"
Returns:
Op that raises `InvalidArgumentError` if `x >= y` is False.
"""
message = message or ''
with ops.name_scope(name, 'assert_greater_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
x_name = x.name
y_name = y.name
if data is None:
data = [
message,
'Condition x >= y did not hold element-wise:'
'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
]
condition = math_ops.reduce_all(math_ops.greater_equal(x, y))
return control_flow_ops.Assert(condition, data, summarize=summarize)
return _binary_assert('>=', 'assert_greater_equal', math_ops.greater_equal,
lambda x, y: (x >= y), x, y, data, summarize, message,
name)
def _assert_rank_condition(
@ -2266,3 +2159,4 @@ def ensure_shape(x, shape, name=None):
def _ensure_shape_grad(op, grad):
del op # Unused.
return grad