diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 4bf5095ae8b..5f90ed1f202 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -1965,17 +1965,17 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertEqual(len(add_executions), 2) @def_function.function - def fn(): - a = constant_op.constant(1) - b = constant_op.constant(2) + def fn(a, b): c = a + b - d = a + b + # These two AddV2 cannot use the same argument in tf.function since an + # optimization pass will remove duplicate ops and only run it once. + d = a + c return c, d with CaptureStderr() as log: - c, d = self.evaluate(fn()) + c, d = self.evaluate(fn(constant_op.constant(1), constant_op.constant(2))) self.assertEqual(c, 3) - self.assertEqual(d, 3) + self.assertEqual(d, 4) # Ensure that we did log device placement. add_executions = [l for l in str(log).splitlines() if 'AddV2' in l] self.assertEqual(len(add_executions), 2)