1011 lines
35 KiB
Python
1011 lines
35 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Tests for tensorflow.ops.test_util."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import collections
|
|
import copy
|
|
import random
|
|
import threading
|
|
import unittest
|
|
import weakref
|
|
|
|
from absl.testing import parameterized
|
|
import numpy as np
|
|
|
|
from google.protobuf import text_format
|
|
|
|
from tensorflow.core.framework import graph_pb2
|
|
from tensorflow.core.protobuf import meta_graph_pb2
|
|
from tensorflow.python.compat import compat
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.eager import def_function
|
|
from tensorflow.python.framework import combinations
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import errors
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import test_ops # pylint: disable=unused-import
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.ops import control_flow_ops
|
|
from tensorflow.python.ops import lookup_ops
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.ops import random_ops
|
|
from tensorflow.python.ops import resource_variable_ops
|
|
from tensorflow.python.ops import variable_scope
|
|
from tensorflow.python.ops import variables
|
|
from tensorflow.python.platform import googletest
|
|
|
|
|
|
class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|
|
|
def test_assert_ops_in_graph(self):
|
|
with ops.Graph().as_default():
|
|
constant_op.constant(["hello", "taffy"], name="hello")
|
|
test_util.assert_ops_in_graph({"hello": "Const"}, ops.get_default_graph())
|
|
|
|
self.assertRaises(ValueError, test_util.assert_ops_in_graph,
|
|
{"bye": "Const"}, ops.get_default_graph())
|
|
|
|
self.assertRaises(ValueError, test_util.assert_ops_in_graph,
|
|
{"hello": "Variable"}, ops.get_default_graph())
|
|
|
|
@test_util.run_deprecated_v1
|
|
def test_session_functions(self):
|
|
with self.test_session() as sess:
|
|
sess_ref = weakref.ref(sess)
|
|
with self.cached_session(graph=None, config=None) as sess2:
|
|
# We make sure that sess2 is sess.
|
|
assert sess2 is sess
|
|
# We make sure we raise an exception if we use cached_session with
|
|
# different values.
|
|
with self.assertRaises(ValueError):
|
|
with self.cached_session(graph=ops.Graph()) as sess2:
|
|
pass
|
|
with self.assertRaises(ValueError):
|
|
with self.cached_session(force_gpu=True) as sess2:
|
|
pass
|
|
# We make sure that test_session will cache the session even after the
|
|
# with scope.
|
|
assert not sess_ref()._closed
|
|
with self.session() as unique_sess:
|
|
unique_sess_ref = weakref.ref(unique_sess)
|
|
with self.session() as sess2:
|
|
assert sess2 is not unique_sess
|
|
# We make sure the session is closed when we leave the with statement.
|
|
assert unique_sess_ref()._closed
|
|
|
|
def test_assert_equal_graph_def(self):
|
|
with ops.Graph().as_default() as g:
|
|
def_empty = g.as_graph_def()
|
|
constant_op.constant(5, name="five")
|
|
constant_op.constant(7, name="seven")
|
|
def_57 = g.as_graph_def()
|
|
with ops.Graph().as_default() as g:
|
|
constant_op.constant(7, name="seven")
|
|
constant_op.constant(5, name="five")
|
|
def_75 = g.as_graph_def()
|
|
# Comparing strings is order dependent
|
|
self.assertNotEqual(str(def_57), str(def_75))
|
|
# assert_equal_graph_def doesn't care about order
|
|
test_util.assert_equal_graph_def(def_57, def_75)
|
|
# Compare two unequal graphs
|
|
with self.assertRaisesRegex(AssertionError,
|
|
r"^Found unexpected node '{{node seven}}"):
|
|
test_util.assert_equal_graph_def(def_57, def_empty)
|
|
|
|
def test_assert_equal_graph_def_hash_table(self):
|
|
def get_graph_def():
|
|
with ops.Graph().as_default() as g:
|
|
x = constant_op.constant([2, 9], name="x")
|
|
keys = constant_op.constant([1, 2], name="keys")
|
|
values = constant_op.constant([3, 4], name="values")
|
|
default = constant_op.constant(-1, name="default")
|
|
table = lookup_ops.StaticHashTable(
|
|
lookup_ops.KeyValueTensorInitializer(keys, values), default)
|
|
_ = table.lookup(x)
|
|
return g.as_graph_def()
|
|
def_1 = get_graph_def()
|
|
def_2 = get_graph_def()
|
|
# The unique shared_name of each table makes the graph unequal.
|
|
with self.assertRaisesRegex(AssertionError, "hash_table_"):
|
|
test_util.assert_equal_graph_def(def_1, def_2,
|
|
hash_table_shared_name=False)
|
|
# That can be ignored. (NOTE: modifies GraphDefs in-place.)
|
|
test_util.assert_equal_graph_def(def_1, def_2,
|
|
hash_table_shared_name=True)
|
|
|
|
def testIsGoogleCudaEnabled(self):
|
|
# The test doesn't assert anything. It ensures the py wrapper
|
|
# function is generated correctly.
|
|
if test_util.IsGoogleCudaEnabled():
|
|
print("GoogleCuda is enabled")
|
|
else:
|
|
print("GoogleCuda is disabled")
|
|
|
|
def testIsMklEnabled(self):
|
|
# This test doesn't assert anything.
|
|
# It ensures the py wrapper function is generated correctly.
|
|
if test_util.IsMklEnabled():
|
|
print("MKL is enabled")
|
|
else:
|
|
print("MKL is disabled")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertProtoEqualsStr(self):
|
|
|
|
graph_str = "node { name: 'w1' op: 'params' }"
|
|
graph_def = graph_pb2.GraphDef()
|
|
text_format.Merge(graph_str, graph_def)
|
|
|
|
# test string based comparison
|
|
self.assertProtoEquals(graph_str, graph_def)
|
|
|
|
# test original comparison
|
|
self.assertProtoEquals(graph_def, graph_def)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertProtoEqualsAny(self):
|
|
# Test assertProtoEquals with a protobuf.Any field.
|
|
meta_graph_def_str = """
|
|
meta_info_def {
|
|
meta_graph_version: "outer"
|
|
any_info {
|
|
[type.googleapis.com/tensorflow.MetaGraphDef] {
|
|
meta_info_def {
|
|
meta_graph_version: "inner"
|
|
}
|
|
}
|
|
}
|
|
}
|
|
"""
|
|
meta_graph_def_outer = meta_graph_pb2.MetaGraphDef()
|
|
meta_graph_def_outer.meta_info_def.meta_graph_version = "outer"
|
|
meta_graph_def_inner = meta_graph_pb2.MetaGraphDef()
|
|
meta_graph_def_inner.meta_info_def.meta_graph_version = "inner"
|
|
meta_graph_def_outer.meta_info_def.any_info.Pack(meta_graph_def_inner)
|
|
self.assertProtoEquals(meta_graph_def_str, meta_graph_def_outer)
|
|
self.assertProtoEquals(meta_graph_def_outer, meta_graph_def_outer)
|
|
|
|
# Check if the assertion failure message contains the content of
|
|
# the inner proto.
|
|
with self.assertRaisesRegex(AssertionError, r'meta_graph_version: "inner"'):
|
|
self.assertProtoEquals("", meta_graph_def_outer)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testNDArrayNear(self):
|
|
a1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
|
a2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
|
a3 = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]])
|
|
self.assertTrue(self._NDArrayNear(a1, a2, 1e-5))
|
|
self.assertFalse(self._NDArrayNear(a1, a3, 1e-5))
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testCheckedThreadSucceeds(self):
|
|
|
|
def noop(ev):
|
|
ev.set()
|
|
|
|
event_arg = threading.Event()
|
|
|
|
self.assertFalse(event_arg.is_set())
|
|
t = self.checkedThread(target=noop, args=(event_arg,))
|
|
t.start()
|
|
t.join()
|
|
self.assertTrue(event_arg.is_set())
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testCheckedThreadFails(self):
|
|
|
|
def err_func():
|
|
return 1 // 0
|
|
|
|
t = self.checkedThread(target=err_func)
|
|
t.start()
|
|
with self.assertRaises(self.failureException) as fe:
|
|
t.join()
|
|
self.assertTrue("integer division or modulo by zero" in str(fe.exception))
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testCheckedThreadWithWrongAssertionFails(self):
|
|
x = 37
|
|
|
|
def err_func():
|
|
self.assertTrue(x < 10)
|
|
|
|
t = self.checkedThread(target=err_func)
|
|
t.start()
|
|
with self.assertRaises(self.failureException) as fe:
|
|
t.join()
|
|
self.assertTrue("False is not true" in str(fe.exception))
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testMultipleThreadsWithOneFailure(self):
|
|
|
|
def err_func(i):
|
|
self.assertTrue(i != 7)
|
|
|
|
threads = [
|
|
self.checkedThread(
|
|
target=err_func, args=(i,)) for i in range(10)
|
|
]
|
|
for t in threads:
|
|
t.start()
|
|
for i, t in enumerate(threads):
|
|
if i == 7:
|
|
with self.assertRaises(self.failureException):
|
|
t.join()
|
|
else:
|
|
t.join()
|
|
|
|
def _WeMustGoDeeper(self, msg):
|
|
with self.assertRaisesOpError(msg):
|
|
with ops.Graph().as_default():
|
|
node_def = ops._NodeDef("IntOutput", "name")
|
|
node_def_orig = ops._NodeDef("IntOutput", "orig")
|
|
op_orig = ops.Operation(node_def_orig, ops.get_default_graph())
|
|
op = ops.Operation(node_def, ops.get_default_graph(),
|
|
original_op=op_orig)
|
|
raise errors.UnauthenticatedError(node_def, op, "true_err")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertRaisesOpErrorDoesNotPassMessageDueToLeakedStack(self):
|
|
with self.assertRaises(AssertionError):
|
|
self._WeMustGoDeeper("this_is_not_the_error_you_are_looking_for")
|
|
|
|
self._WeMustGoDeeper("true_err")
|
|
self._WeMustGoDeeper("name")
|
|
self._WeMustGoDeeper("orig")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAllCloseTensors(self):
|
|
a_raw_data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
|
|
a = constant_op.constant(a_raw_data)
|
|
b = math_ops.add(1, constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8]]))
|
|
self.assertAllClose(a, b)
|
|
self.assertAllClose(a, a_raw_data)
|
|
|
|
a_dict = {"key": a}
|
|
b_dict = {"key": b}
|
|
self.assertAllClose(a_dict, b_dict)
|
|
|
|
x_list = [a, b]
|
|
y_list = [a_raw_data, b]
|
|
self.assertAllClose(x_list, y_list)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAllCloseScalars(self):
|
|
self.assertAllClose(7, 7 + 1e-8)
|
|
with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"):
|
|
self.assertAllClose(7, 7 + 1e-5)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAllCloseList(self):
|
|
with self.assertRaisesRegex(AssertionError, r"not close dif"):
|
|
self.assertAllClose([0], [1])
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAllCloseDictToNonDict(self):
|
|
with self.assertRaisesRegex(ValueError, r"Can't compare dict to non-dict"):
|
|
self.assertAllClose(1, {"a": 1})
|
|
with self.assertRaisesRegex(ValueError, r"Can't compare dict to non-dict"):
|
|
self.assertAllClose({"a": 1}, 1)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAllCloseNamedtuples(self):
|
|
a = 7
|
|
b = (2., 3.)
|
|
c = np.ones((3, 2, 4)) * 7.
|
|
expected = {"a": a, "b": b, "c": c}
|
|
my_named_tuple = collections.namedtuple("MyNamedTuple", ["a", "b", "c"])
|
|
|
|
# Identity.
|
|
self.assertAllClose(expected, my_named_tuple(a=a, b=b, c=c))
|
|
self.assertAllClose(
|
|
my_named_tuple(a=a, b=b, c=c), my_named_tuple(a=a, b=b, c=c))
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAllCloseDicts(self):
|
|
a = 7
|
|
b = (2., 3.)
|
|
c = np.ones((3, 2, 4)) * 7.
|
|
expected = {"a": a, "b": b, "c": c}
|
|
|
|
# Identity.
|
|
self.assertAllClose(expected, expected)
|
|
self.assertAllClose(expected, dict(expected))
|
|
|
|
# With each item removed.
|
|
for k in expected:
|
|
actual = dict(expected)
|
|
del actual[k]
|
|
with self.assertRaisesRegex(AssertionError, r"mismatched keys"):
|
|
self.assertAllClose(expected, actual)
|
|
|
|
# With each item changed.
|
|
with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"):
|
|
self.assertAllClose(expected, {"a": a + 1e-5, "b": b, "c": c})
|
|
with self.assertRaisesRegex(AssertionError, r"Shape mismatch"):
|
|
self.assertAllClose(expected, {"a": a, "b": b + (4.,), "c": c})
|
|
c_copy = np.array(c)
|
|
c_copy[1, 1, 1] += 1e-5
|
|
with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"):
|
|
self.assertAllClose(expected, {"a": a, "b": b, "c": c_copy})
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAllCloseListOfNamedtuples(self):
|
|
my_named_tuple = collections.namedtuple("MyNamedTuple", ["x", "y"])
|
|
l1 = [
|
|
my_named_tuple(x=np.array([[2.3, 2.5]]), y=np.array([[0.97, 0.96]])),
|
|
my_named_tuple(x=np.array([[3.3, 3.5]]), y=np.array([[0.98, 0.99]]))
|
|
]
|
|
l2 = [
|
|
([[2.3, 2.5]], [[0.97, 0.96]]),
|
|
([[3.3, 3.5]], [[0.98, 0.99]]),
|
|
]
|
|
self.assertAllClose(l1, l2)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAllCloseNestedStructure(self):
|
|
a = {"x": np.ones((3, 2, 4)) * 7, "y": (2, [{"nested": {"m": 3, "n": 4}}])}
|
|
self.assertAllClose(a, a)
|
|
|
|
b = copy.deepcopy(a)
|
|
self.assertAllClose(a, b)
|
|
|
|
# Test mismatched values
|
|
b["y"][1][0]["nested"]["n"] = 4.2
|
|
with self.assertRaisesRegex(AssertionError,
|
|
r"\[y\]\[1\]\[0\]\[nested\]\[n\]"):
|
|
self.assertAllClose(a, b)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testArrayNear(self):
|
|
a = [1, 2]
|
|
b = [1, 2, 5]
|
|
with self.assertRaises(AssertionError):
|
|
self.assertArrayNear(a, b, 0.001)
|
|
a = [1, 2]
|
|
b = [[1, 2], [3, 4]]
|
|
with self.assertRaises(TypeError):
|
|
self.assertArrayNear(a, b, 0.001)
|
|
a = [1, 2]
|
|
b = [1, 2]
|
|
self.assertArrayNear(a, b, 0.001)
|
|
|
|
@test_util.skip_if(True) # b/117665998
|
|
def testForceGPU(self):
|
|
with self.assertRaises(errors.InvalidArgumentError):
|
|
with self.test_session(force_gpu=True):
|
|
# this relies on us not having a GPU implementation for assert, which
|
|
# seems sensible
|
|
x = constant_op.constant(True)
|
|
y = [15]
|
|
control_flow_ops.Assert(x, y).run()
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertAllCloseAccordingToType(self):
|
|
# test plain int
|
|
self.assertAllCloseAccordingToType(1, 1, rtol=1e-8, atol=1e-8)
|
|
|
|
# test float64
|
|
self.assertAllCloseAccordingToType(
|
|
np.asarray([1e-8], dtype=np.float64),
|
|
np.asarray([2e-8], dtype=np.float64),
|
|
rtol=1e-8, atol=1e-8
|
|
)
|
|
|
|
self.assertAllCloseAccordingToType(
|
|
constant_op.constant([1e-8], dtype=dtypes.float64),
|
|
constant_op.constant([2e-8], dtype=dtypes.float64),
|
|
rtol=1e-8,
|
|
atol=1e-8)
|
|
|
|
with (self.assertRaises(AssertionError)):
|
|
self.assertAllCloseAccordingToType(
|
|
np.asarray([1e-7], dtype=np.float64),
|
|
np.asarray([2e-7], dtype=np.float64),
|
|
rtol=1e-8, atol=1e-8
|
|
)
|
|
|
|
# test float32
|
|
self.assertAllCloseAccordingToType(
|
|
np.asarray([1e-7], dtype=np.float32),
|
|
np.asarray([2e-7], dtype=np.float32),
|
|
rtol=1e-8, atol=1e-8,
|
|
float_rtol=1e-7, float_atol=1e-7
|
|
)
|
|
|
|
self.assertAllCloseAccordingToType(
|
|
constant_op.constant([1e-7], dtype=dtypes.float32),
|
|
constant_op.constant([2e-7], dtype=dtypes.float32),
|
|
rtol=1e-8,
|
|
atol=1e-8,
|
|
float_rtol=1e-7,
|
|
float_atol=1e-7)
|
|
|
|
with (self.assertRaises(AssertionError)):
|
|
self.assertAllCloseAccordingToType(
|
|
np.asarray([1e-6], dtype=np.float32),
|
|
np.asarray([2e-6], dtype=np.float32),
|
|
rtol=1e-8, atol=1e-8,
|
|
float_rtol=1e-7, float_atol=1e-7
|
|
)
|
|
|
|
# test float16
|
|
self.assertAllCloseAccordingToType(
|
|
np.asarray([1e-4], dtype=np.float16),
|
|
np.asarray([2e-4], dtype=np.float16),
|
|
rtol=1e-8, atol=1e-8,
|
|
float_rtol=1e-7, float_atol=1e-7,
|
|
half_rtol=1e-4, half_atol=1e-4
|
|
)
|
|
|
|
self.assertAllCloseAccordingToType(
|
|
constant_op.constant([1e-4], dtype=dtypes.float16),
|
|
constant_op.constant([2e-4], dtype=dtypes.float16),
|
|
rtol=1e-8,
|
|
atol=1e-8,
|
|
float_rtol=1e-7,
|
|
float_atol=1e-7,
|
|
half_rtol=1e-4,
|
|
half_atol=1e-4)
|
|
|
|
with (self.assertRaises(AssertionError)):
|
|
self.assertAllCloseAccordingToType(
|
|
np.asarray([1e-3], dtype=np.float16),
|
|
np.asarray([2e-3], dtype=np.float16),
|
|
rtol=1e-8, atol=1e-8,
|
|
float_rtol=1e-7, float_atol=1e-7,
|
|
half_rtol=1e-4, half_atol=1e-4
|
|
)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertAllEqual(self):
|
|
i = variables.Variable([100] * 3, dtype=dtypes.int32, name="i")
|
|
j = constant_op.constant([20] * 3, dtype=dtypes.int32, name="j")
|
|
k = math_ops.add(i, j, name="k")
|
|
|
|
self.evaluate(variables.global_variables_initializer())
|
|
self.assertAllEqual([100] * 3, i)
|
|
self.assertAllEqual([120] * 3, k)
|
|
self.assertAllEqual([20] * 3, j)
|
|
|
|
with self.assertRaisesRegex(AssertionError, r"not equal lhs"):
|
|
self.assertAllEqual([0] * 3, k)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertNotAllEqual(self):
|
|
i = variables.Variable([100], dtype=dtypes.int32, name="i")
|
|
j = constant_op.constant([20], dtype=dtypes.int32, name="j")
|
|
k = math_ops.add(i, j, name="k")
|
|
|
|
self.evaluate(variables.global_variables_initializer())
|
|
self.assertNotAllEqual([100] * 3, i)
|
|
self.assertNotAllEqual([120] * 3, k)
|
|
self.assertNotAllEqual([20] * 3, j)
|
|
|
|
with self.assertRaisesRegex(
|
|
AssertionError, r"two values are equal at all elements.*extra message"):
|
|
self.assertNotAllEqual([120], k, msg="extra message")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertNotAllClose(self):
|
|
# Test with arrays
|
|
self.assertNotAllClose([0.1], [0.2])
|
|
with self.assertRaises(AssertionError):
|
|
self.assertNotAllClose([-1.0, 2.0], [-1.0, 2.0])
|
|
|
|
# Test with tensors
|
|
x = constant_op.constant([1.0, 1.0], name="x")
|
|
y = math_ops.add(x, x)
|
|
|
|
self.assertAllClose([2.0, 2.0], y)
|
|
self.assertNotAllClose([0.9, 1.0], x)
|
|
|
|
with self.assertRaises(AssertionError):
|
|
self.assertNotAllClose([1.0, 1.0], x)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertNotAllCloseRTol(self):
|
|
# Test with arrays
|
|
with self.assertRaises(AssertionError):
|
|
self.assertNotAllClose([1.1, 2.1], [1.0, 2.0], rtol=0.2)
|
|
|
|
# Test with tensors
|
|
x = constant_op.constant([1.0, 1.0], name="x")
|
|
y = math_ops.add(x, x)
|
|
|
|
self.assertAllClose([2.0, 2.0], y)
|
|
|
|
with self.assertRaises(AssertionError):
|
|
self.assertNotAllClose([0.9, 1.0], x, rtol=0.2)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertNotAllCloseATol(self):
|
|
# Test with arrays
|
|
with self.assertRaises(AssertionError):
|
|
self.assertNotAllClose([1.1, 2.1], [1.0, 2.0], atol=0.2)
|
|
|
|
# Test with tensors
|
|
x = constant_op.constant([1.0, 1.0], name="x")
|
|
y = math_ops.add(x, x)
|
|
|
|
self.assertAllClose([2.0, 2.0], y)
|
|
|
|
with self.assertRaises(AssertionError):
|
|
self.assertNotAllClose([0.9, 1.0], x, atol=0.2)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertAllGreaterLess(self):
|
|
x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32)
|
|
y = constant_op.constant([10.0] * 3, dtype=dtypes.float32)
|
|
z = math_ops.add(x, y)
|
|
|
|
self.assertAllClose([110.0, 120.0, 130.0], z)
|
|
|
|
self.assertAllGreater(x, 95.0)
|
|
self.assertAllLess(x, 125.0)
|
|
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllGreater(x, 105.0)
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllGreater(x, 125.0)
|
|
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllLess(x, 115.0)
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllLess(x, 95.0)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertAllGreaterLessEqual(self):
|
|
x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32)
|
|
y = constant_op.constant([10.0] * 3, dtype=dtypes.float32)
|
|
z = math_ops.add(x, y)
|
|
|
|
self.assertAllEqual([110.0, 120.0, 130.0], z)
|
|
|
|
self.assertAllGreaterEqual(x, 95.0)
|
|
self.assertAllLessEqual(x, 125.0)
|
|
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllGreaterEqual(x, 105.0)
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllGreaterEqual(x, 125.0)
|
|
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllLessEqual(x, 115.0)
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllLessEqual(x, 95.0)
|
|
|
|
def testAssertAllInRangeWithNonNumericValuesFails(self):
|
|
s1 = constant_op.constant("Hello, ", name="s1")
|
|
c = constant_op.constant([1 + 2j, -3 + 5j], name="c")
|
|
b = constant_op.constant([False, True], name="b")
|
|
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllInRange(s1, 0.0, 1.0)
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllInRange(c, 0.0, 1.0)
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllInRange(b, 0, 1)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertAllInRange(self):
|
|
x = constant_op.constant([10.0, 15.0], name="x")
|
|
self.assertAllInRange(x, 10, 15)
|
|
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllInRange(x, 10, 15, open_lower_bound=True)
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllInRange(x, 10, 15, open_upper_bound=True)
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllInRange(
|
|
x, 10, 15, open_lower_bound=True, open_upper_bound=True)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertAllInRangeScalar(self):
|
|
x = constant_op.constant(10.0, name="x")
|
|
nan = constant_op.constant(np.nan, name="nan")
|
|
self.assertAllInRange(x, 5, 15)
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllInRange(nan, 5, 15)
|
|
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllInRange(x, 10, 15, open_lower_bound=True)
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllInRange(x, 1, 2)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertAllInRangeErrorMessageEllipses(self):
|
|
x_init = np.array([[10.0, 15.0]] * 12)
|
|
x = constant_op.constant(x_init, name="x")
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllInRange(x, 5, 10)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertAllInRangeDetectsNaNs(self):
|
|
x = constant_op.constant(
|
|
[[np.nan, 0.0], [np.nan, np.inf], [np.inf, np.nan]], name="x")
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllInRange(x, 0.0, 2.0)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertAllInRangeWithInfinities(self):
|
|
x = constant_op.constant([10.0, np.inf], name="x")
|
|
self.assertAllInRange(x, 10, np.inf)
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllInRange(x, 10, np.inf, open_upper_bound=True)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testAssertAllInSet(self):
|
|
b = constant_op.constant([True, False], name="b")
|
|
x = constant_op.constant([13, 37], name="x")
|
|
|
|
self.assertAllInSet(b, [False, True])
|
|
self.assertAllInSet(b, (False, True))
|
|
self.assertAllInSet(b, {False, True})
|
|
self.assertAllInSet(x, [0, 13, 37, 42])
|
|
self.assertAllInSet(x, (0, 13, 37, 42))
|
|
self.assertAllInSet(x, {0, 13, 37, 42})
|
|
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllInSet(b, [False])
|
|
with self.assertRaises(AssertionError):
|
|
self.assertAllInSet(x, (42,))
|
|
|
|
def testRandomSeed(self):
|
|
# Call setUp again for WithCApi case (since it makes a new default graph
|
|
# after setup).
|
|
# TODO(skyewm): remove this when C API is permanently enabled.
|
|
with context.eager_mode():
|
|
self.setUp()
|
|
a = random.randint(1, 1000)
|
|
a_np_rand = np.random.rand(1)
|
|
a_rand = random_ops.random_normal([1])
|
|
# ensure that randomness in multiple testCases is deterministic.
|
|
self.setUp()
|
|
b = random.randint(1, 1000)
|
|
b_np_rand = np.random.rand(1)
|
|
b_rand = random_ops.random_normal([1])
|
|
self.assertEqual(a, b)
|
|
self.assertEqual(a_np_rand, b_np_rand)
|
|
self.assertAllEqual(a_rand, b_rand)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def test_callable_evaluate(self):
|
|
def model():
|
|
return resource_variable_ops.ResourceVariable(
|
|
name="same_name",
|
|
initial_value=1) + 1
|
|
with context.eager_mode():
|
|
self.assertEqual(2, self.evaluate(model))
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def test_nested_tensors_evaluate(self):
|
|
expected = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}}
|
|
nested = {"a": constant_op.constant(1),
|
|
"b": constant_op.constant(2),
|
|
"nested": {"d": constant_op.constant(3),
|
|
"e": constant_op.constant(4)}}
|
|
|
|
self.assertEqual(expected, self.evaluate(nested))
|
|
|
|
def test_run_in_graph_and_eager_modes(self):
|
|
l = []
|
|
def inc(self, with_brackets):
|
|
del self # self argument is required by run_in_graph_and_eager_modes.
|
|
mode = "eager" if context.executing_eagerly() else "graph"
|
|
with_brackets = "with_brackets" if with_brackets else "without_brackets"
|
|
l.append((with_brackets, mode))
|
|
|
|
f = test_util.run_in_graph_and_eager_modes(inc)
|
|
f(self, with_brackets=False)
|
|
f = test_util.run_in_graph_and_eager_modes()(inc)
|
|
f(self, with_brackets=True)
|
|
|
|
self.assertEqual(len(l), 4)
|
|
self.assertEqual(set(l), {
|
|
("with_brackets", "graph"),
|
|
("with_brackets", "eager"),
|
|
("without_brackets", "graph"),
|
|
("without_brackets", "eager"),
|
|
})
|
|
|
|
def test_get_node_def_from_graph(self):
|
|
graph_def = graph_pb2.GraphDef()
|
|
node_foo = graph_def.node.add()
|
|
node_foo.name = "foo"
|
|
self.assertIs(test_util.get_node_def_from_graph("foo", graph_def), node_foo)
|
|
self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def))
|
|
|
|
def test_run_in_eager_and_graph_modes_test_class(self):
|
|
msg = "`run_in_graph_and_eager_modes` only supports test methods.*"
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
|
|
@test_util.run_in_graph_and_eager_modes()
|
|
class Foo(object):
|
|
pass
|
|
del Foo # Make pylint unused happy.
|
|
|
|
def test_run_in_eager_and_graph_modes_skip_graph_runs_eager(self):
|
|
modes = []
|
|
def _test(self):
|
|
if not context.executing_eagerly():
|
|
self.skipTest("Skipping in graph mode")
|
|
modes.append("eager" if context.executing_eagerly() else "graph")
|
|
test_util.run_in_graph_and_eager_modes(_test)(self)
|
|
self.assertEqual(modes, ["eager"])
|
|
|
|
def test_run_in_eager_and_graph_modes_skip_eager_runs_graph(self):
|
|
modes = []
|
|
def _test(self):
|
|
if context.executing_eagerly():
|
|
self.skipTest("Skipping in eager mode")
|
|
modes.append("eager" if context.executing_eagerly() else "graph")
|
|
test_util.run_in_graph_and_eager_modes(_test)(self)
|
|
self.assertEqual(modes, ["graph"])
|
|
|
|
def test_run_in_graph_and_eager_modes_setup_in_same_mode(self):
|
|
modes = []
|
|
mode_name = lambda: "eager" if context.executing_eagerly() else "graph"
|
|
|
|
class ExampleTest(test_util.TensorFlowTestCase):
|
|
|
|
def runTest(self):
|
|
pass
|
|
|
|
def setUp(self):
|
|
modes.append("setup_" + mode_name())
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testBody(self):
|
|
modes.append("run_" + mode_name())
|
|
|
|
e = ExampleTest()
|
|
e.setUp()
|
|
e.testBody()
|
|
|
|
self.assertEqual(modes[1:2], ["run_graph"])
|
|
self.assertEqual(modes[2:], ["setup_eager", "run_eager"])
|
|
|
|
@parameterized.named_parameters(dict(testcase_name="argument",
|
|
arg=True))
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def test_run_in_graph_and_eager_works_with_parameterized_keyword(self, arg):
|
|
self.assertEqual(arg, True)
|
|
|
|
@combinations.generate(combinations.combine(arg=True))
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def test_run_in_graph_and_eager_works_with_combinations(self, arg):
|
|
self.assertEqual(arg, True)
|
|
|
|
def test_build_as_function_and_v1_graph(self):
|
|
|
|
class GraphModeAndFunctionTest(parameterized.TestCase):
|
|
|
|
def __init__(inner_self): # pylint: disable=no-self-argument
|
|
super(GraphModeAndFunctionTest, inner_self).__init__()
|
|
inner_self.graph_mode_tested = False
|
|
inner_self.inside_function_tested = False
|
|
|
|
def runTest(self):
|
|
del self
|
|
|
|
@test_util.build_as_function_and_v1_graph
|
|
def test_modes(inner_self): # pylint: disable=no-self-argument
|
|
if ops.inside_function():
|
|
self.assertFalse(inner_self.inside_function_tested)
|
|
inner_self.inside_function_tested = True
|
|
else:
|
|
self.assertFalse(inner_self.graph_mode_tested)
|
|
inner_self.graph_mode_tested = True
|
|
|
|
test_object = GraphModeAndFunctionTest()
|
|
test_object.test_modes_v1_graph()
|
|
test_object.test_modes_function()
|
|
self.assertTrue(test_object.graph_mode_tested)
|
|
self.assertTrue(test_object.inside_function_tested)
|
|
|
|
def test_with_forward_compatibility_horizons(self):
|
|
|
|
tested_codepaths = set()
|
|
def some_function_with_forward_compat_behavior():
|
|
if compat.forward_compatible(2050, 1, 1):
|
|
tested_codepaths.add("future")
|
|
else:
|
|
tested_codepaths.add("present")
|
|
|
|
@test_util.with_forward_compatibility_horizons(None, [2051, 1, 1])
|
|
def some_test(self):
|
|
del self # unused
|
|
some_function_with_forward_compat_behavior()
|
|
|
|
some_test(None)
|
|
self.assertEqual(tested_codepaths, set(["present", "future"]))
|
|
|
|
|
|
class SkipTestTest(test_util.TensorFlowTestCase):
|
|
|
|
def _verify_test_in_set_up_or_tear_down(self):
|
|
with self.assertRaises(unittest.SkipTest):
|
|
with test_util.skip_if_error(self, ValueError,
|
|
["foo bar", "test message"]):
|
|
raise ValueError("test message")
|
|
try:
|
|
with self.assertRaisesRegex(ValueError, "foo bar"):
|
|
with test_util.skip_if_error(self, ValueError, "test message"):
|
|
raise ValueError("foo bar")
|
|
except unittest.SkipTest:
|
|
raise RuntimeError("Test is not supposed to skip.")
|
|
|
|
def setUp(self):
|
|
super(SkipTestTest, self).setUp()
|
|
self._verify_test_in_set_up_or_tear_down()
|
|
|
|
def tearDown(self):
|
|
super(SkipTestTest, self).tearDown()
|
|
self._verify_test_in_set_up_or_tear_down()
|
|
|
|
def test_skip_if_error_should_skip(self):
|
|
with self.assertRaises(unittest.SkipTest):
|
|
with test_util.skip_if_error(self, ValueError, "test message"):
|
|
raise ValueError("test message")
|
|
|
|
def test_skip_if_error_should_skip_with_list(self):
|
|
with self.assertRaises(unittest.SkipTest):
|
|
with test_util.skip_if_error(self, ValueError,
|
|
["foo bar", "test message"]):
|
|
raise ValueError("test message")
|
|
|
|
def test_skip_if_error_should_skip_without_expected_message(self):
|
|
with self.assertRaises(unittest.SkipTest):
|
|
with test_util.skip_if_error(self, ValueError):
|
|
raise ValueError("test message")
|
|
|
|
def test_skip_if_error_should_skip_without_error_message(self):
|
|
with self.assertRaises(unittest.SkipTest):
|
|
with test_util.skip_if_error(self, ValueError):
|
|
raise ValueError()
|
|
|
|
def test_skip_if_error_should_raise_message_mismatch(self):
|
|
try:
|
|
with self.assertRaisesRegex(ValueError, "foo bar"):
|
|
with test_util.skip_if_error(self, ValueError, "test message"):
|
|
raise ValueError("foo bar")
|
|
except unittest.SkipTest:
|
|
raise RuntimeError("Test is not supposed to skip.")
|
|
|
|
def test_skip_if_error_should_raise_no_message(self):
|
|
try:
|
|
with self.assertRaisesRegex(ValueError, ""):
|
|
with test_util.skip_if_error(self, ValueError, "test message"):
|
|
raise ValueError()
|
|
except unittest.SkipTest:
|
|
raise RuntimeError("Test is not supposed to skip.")
|
|
|
|
|
|
# Its own test case to reproduce variable sharing issues which only pop up when
|
|
# setUp() is overridden and super() is not called.
|
|
class GraphAndEagerNoVariableSharing(test_util.TensorFlowTestCase):
|
|
|
|
def setUp(self):
|
|
pass # Intentionally does not call TensorFlowTestCase's super()
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def test_no_variable_sharing(self):
|
|
variable_scope.get_variable(
|
|
name="step_size",
|
|
initializer=np.array(1e-5, np.float32),
|
|
use_resource=True,
|
|
trainable=False)
|
|
|
|
|
|
class GarbageCollectionTest(test_util.TensorFlowTestCase):
|
|
|
|
def test_no_reference_cycle_decorator(self):
|
|
|
|
class ReferenceCycleTest(object):
|
|
|
|
def __init__(inner_self): # pylint: disable=no-self-argument
|
|
inner_self.assertEqual = self.assertEqual # pylint: disable=invalid-name
|
|
|
|
@test_util.assert_no_garbage_created
|
|
def test_has_cycle(self):
|
|
a = []
|
|
a.append(a)
|
|
|
|
@test_util.assert_no_garbage_created
|
|
def test_has_no_cycle(self):
|
|
pass
|
|
|
|
with self.assertRaises(AssertionError):
|
|
ReferenceCycleTest().test_has_cycle()
|
|
|
|
ReferenceCycleTest().test_has_no_cycle()
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def test_no_leaked_tensor_decorator(self):
|
|
|
|
class LeakedTensorTest(object):
|
|
|
|
def __init__(inner_self): # pylint: disable=no-self-argument
|
|
inner_self.assertEqual = self.assertEqual # pylint: disable=invalid-name
|
|
|
|
@test_util.assert_no_new_tensors
|
|
def test_has_leak(self):
|
|
self.a = constant_op.constant([3.], name="leak")
|
|
|
|
@test_util.assert_no_new_tensors
|
|
def test_has_no_leak(self):
|
|
constant_op.constant([3.], name="no-leak")
|
|
|
|
with self.assertRaisesRegex(AssertionError, "Tensors not deallocated"):
|
|
LeakedTensorTest().test_has_leak()
|
|
|
|
LeakedTensorTest().test_has_no_leak()
|
|
|
|
def test_no_new_objects_decorator(self):
|
|
|
|
class LeakedObjectTest(unittest.TestCase):
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
super(LeakedObjectTest, self).__init__(*args, **kwargs)
|
|
self.accumulation = []
|
|
|
|
@unittest.expectedFailure
|
|
@test_util.assert_no_new_pyobjects_executing_eagerly
|
|
def test_has_leak(self):
|
|
self.accumulation.append([1.])
|
|
|
|
@test_util.assert_no_new_pyobjects_executing_eagerly
|
|
def test_has_no_leak(self):
|
|
self.not_accumulating = [1.]
|
|
|
|
self.assertTrue(LeakedObjectTest("test_has_leak").run().wasSuccessful())
|
|
self.assertTrue(LeakedObjectTest("test_has_no_leak").run().wasSuccessful())
|
|
|
|
|
|
class RunFunctionsEagerlyInV2Test(test_util.TensorFlowTestCase,
|
|
parameterized.TestCase):
|
|
@parameterized.named_parameters(
|
|
[("_RunEagerly", True), ("_RunGraph", False)])
|
|
def test_run_functions_eagerly(self, run_eagerly): # pylint: disable=g-wrong-blank-lines
|
|
results = []
|
|
|
|
@def_function.function
|
|
def add_two(x):
|
|
for _ in range(5):
|
|
x += 2
|
|
results.append(x)
|
|
return x
|
|
|
|
with test_util.run_functions_eagerly(run_eagerly):
|
|
add_two(constant_op.constant(2.))
|
|
if context.executing_eagerly():
|
|
if run_eagerly:
|
|
self.assertTrue(isinstance(t, ops.EagerTensor) for t in results)
|
|
else:
|
|
self.assertTrue(isinstance(t, ops.Tensor) for t in results)
|
|
else:
|
|
self.assertTrue(isinstance(t, ops.Tensor) for t in results)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
googletest.main()
|