diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 71afbd24d8d..e0854b06321 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -428,20 +428,21 @@ class FunctionTest(test.TestCase, parameterized.TestCase): self.evaluate(variables.global_variables_initializer()) self.assertEqual(self.evaluate(value), 2.0) - @test_util.run_in_graph_and_eager_modes + @test_util.also_run_as_tf_function def testInitScopeTensorInitializationInFunction(self): @def_function.function def tensor_init(): with ops.init_scope(): const = constant_op.constant(2.0) + # Note: this variable bypasses tf.function's variable creation + # requirements by bypassing variable_creator_scope by using + # ResourceVariable instead of Variable. self.v = resource_variable_ops.ResourceVariable(const) return self.v.read_value() value = tensor_init() - if not context.executing_eagerly(): - self.evaluate(variables.global_variables_initializer()) - self.assertEqual(self.evaluate(value), 2.0) + self.assertAllEqual(value, 2.0) def testDefunShapeInferenceWithCapturedResourceVariable(self): v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index b0c3c9b5068..06316ce2e99 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -54,6 +54,7 @@ from tensorflow.python import tf2 from tensorflow.python.client import device_lib from tensorflow.python.client import session from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.eager import tape from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes @@ -67,6 +68,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import script_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -76,6 +78,7 @@ from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util import memory from tensorflow.python.util import nest +from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect from tensorflow.python.util.protobuf import compare from tensorflow.python.util.tf_export import tf_export @@ -1009,6 +1012,58 @@ def run_in_graph_and_eager_modes(func=None, return decorator +def py_func_if_in_function(f): + + def decorated(*args, **kwds): + if not ops.get_default_graph()._building_function: + return f(*args, **kwds) + + tensor_args, tensor_indices = zip( + *[(x, i) for i, x in enumerate(args) + if isinstance(x, (ops.Tensor, variables.Variable))]) + + def inner_f(*inner_tensor_args): + my_args = list(args) + for i, n in zip(tensor_indices, inner_tensor_args): + my_args[i] = n + return f(*my_args, **kwds) + + return script_ops.py_func(inner_f, tensor_args, []) + + return tf_decorator.make_decorator(f, decorated) + + +def also_run_as_tf_function(f): + """Runs the decorated test twice--once as is, once inside a tf.function. + + This allows you to run a test both in eager execution and inside a + tf.function, exercising the two execution modes supported in tf 2.0. The test + assertions are automatically done inside tf.py_funcs, and tf.function ensures + that they run in the proper order and with the proper side effects. + + Currently variable creation is not supported in tests annotated with this + decorator since it's tricky to ensure the variable doesn't get repeatedly + created when retracing the tf.function. + + Args: + f: the test method to be decorated + + Returns: + The decorated test method, which will run both in eager and inside a + tf.function. + """ + + def decorated(*args, **kwds): + with context.eager_mode(): + # Running in eager mode + f(*args, **kwds) + + defun_f = def_function.function(f) + defun_f(*args, **kwds) + + return decorated + + def run_deprecated_v1(func=None): """Execute the decorated test in graph mode. @@ -1783,8 +1838,8 @@ class TensorFlowTestCase(googletest.TestCase): return ret -# pylint: enable=invalid-name - + # pylint: enable=invalid-name + @py_func_if_in_function def assertNear(self, f1, f2, err, msg=None): """Asserts that two floats are near each other. @@ -1803,6 +1858,7 @@ class TensorFlowTestCase(googletest.TestCase): "%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg if msg is not None else "")) + @py_func_if_in_function def assertArrayNear(self, farray1, farray2, err, msg=None): """Asserts that two float arrays are near each other. @@ -1822,6 +1878,7 @@ class TensorFlowTestCase(googletest.TestCase): def _NDArrayNear(self, ndarray1, ndarray2, err): return np.linalg.norm(ndarray1 - ndarray2) < err + @py_func_if_in_function def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None): """Asserts that two numpy arrays have near values. @@ -1959,6 +2016,7 @@ class TensorFlowTestCase(googletest.TestCase): e.args = ((e.args[0] + " : " + msg,) + e.args[1:]) raise + @py_func_if_in_function def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): """Asserts that two structures of numpy arrays or Tensors, have near values. @@ -1984,6 +2042,7 @@ class TensorFlowTestCase(googletest.TestCase): """ self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg) + @py_func_if_in_function def assertAllCloseAccordingToType(self, a, b, @@ -2031,6 +2090,7 @@ class TensorFlowTestCase(googletest.TestCase): self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg) + @py_func_if_in_function def assertNotAllClose(self, a, b, **kwargs): """Assert that two numpy arrays, or or Tensors, do not have near values. @@ -2049,6 +2109,7 @@ class TensorFlowTestCase(googletest.TestCase): return raise AssertionError("The two values are close at all elements") + @py_func_if_in_function def assertAllEqual(self, a, b, msg=None): """Asserts that two numpy arrays or Tensors have the same values. @@ -2091,6 +2152,7 @@ class TensorFlowTestCase(googletest.TestCase): msgs.append("not equal rhs = {}".format(y)) np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs)) + @py_func_if_in_function def assertAllGreater(self, a, comparison_target): """Assert element values are all greater than a target value. @@ -2102,6 +2164,7 @@ class TensorFlowTestCase(googletest.TestCase): a = self._GetNdArray(a) self.assertGreater(np.min(a), comparison_target) + @py_func_if_in_function def assertAllLess(self, a, comparison_target): """Assert element values are all less than a target value. @@ -2113,6 +2176,7 @@ class TensorFlowTestCase(googletest.TestCase): a = self._GetNdArray(a) self.assertLess(np.max(a), comparison_target) + @py_func_if_in_function def assertAllGreaterEqual(self, a, comparison_target): """Assert element values are all greater than or equal to a target value. @@ -2124,6 +2188,7 @@ class TensorFlowTestCase(googletest.TestCase): a = self._GetNdArray(a) self.assertGreaterEqual(np.min(a), comparison_target) + @py_func_if_in_function def assertAllLessEqual(self, a, comparison_target): """Assert element values are all less than or equal to a target value. @@ -2166,6 +2231,7 @@ class TensorFlowTestCase(googletest.TestCase): lines.append(prefix + "...") return lines + @py_func_if_in_function def assertAllInRange(self, target, lower_bound, @@ -2224,6 +2290,7 @@ class TensorFlowTestCase(googletest.TestCase): "Subscript(s) and value(s) of the offending elements:\n" + "\n".join(self._format_subscripts(violation_subscripts, target))) + @py_func_if_in_function def assertAllInSet(self, target, expected_set): """Assert that elements of a Tensor are all in a given closed set. @@ -2245,6 +2312,7 @@ class TensorFlowTestCase(googletest.TestCase): raise AssertionError("%d unique element(s) are not in the set %s: %s" % (np.size(diff), expected_set, diff)) + @py_func_if_in_function def assertDTypeEqual(self, target, expected_dtype): """Assert ndarray data type is equal to expected. diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index 44d4bd5e30f..451eb385306 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -237,7 +237,8 @@ class VariableScopeTest(test.TestCase): _ = d2(x) self.assertEqual(len(d2.variables), 2) v3, v4 = d2.variables - self.assertAllEqual([v1, v2], [v3, v4]) + self.assertEqual(v1, v3) + self.assertEqual(v2, v4) f() # TODO(mihaimaruseac): Not converted to use wrap_function because of @@ -1684,7 +1685,7 @@ class VariableScopeWithCustomGetterTest(test.TestCase): with variable_scope.variable_creator_scope(creator_b): variable_scope.variable(1.0, name="one_name") - self.assertAllEqual(variable_names, ["forced_name"]) + self.assertEqual(variable_names[0], "forced_name") called = [False]