diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index abec3bf58b9..5c78ce3390a 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -20,7 +20,6 @@ from __future__ import division from __future__ import print_function from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.compat import compat as forward_compat from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function @@ -1425,6 +1424,4 @@ def _has_node_with_op(run_metadata, op_type): if __name__ == "__main__": - # Forward compat date for StatelessIf. - with forward_compat.forward_compatibility_horizon(2019, 7, 23): - test.main() + test.main() diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py index ceb83048817..fefeb594bea 100644 --- a/tensorflow/python/kernel_tests/while_v2_test.py +++ b/tensorflow/python/kernel_tests/while_v2_test.py @@ -22,7 +22,6 @@ from absl.testing import parameterized from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 -from tensorflow.python.compat import compat from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function @@ -241,113 +240,110 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): self.assertSequenceEqual(self.evaluate(grad_grad), [48.]) def testMultipleWhileLoopsWithFunc(self): - if compat.forward_compatible(2019, 8, 23): - x = constant_op.constant(2.) + x = constant_op.constant(2.) - @def_function.function - def Fn(): - ret1 = while_loop_v2( - lambda v: v < 4., - lambda v: v * v, [x], - return_same_structure=False, - name="while_1") # x**2 - ret2 = while_loop_v2( - lambda v: v < 16., - lambda v: v * v, [x], - return_same_structure=False, - name="while_2") # x**4 - return ret1, ret2 + @def_function.function + def Fn(): + ret1 = while_loop_v2( + lambda v: v < 4., + lambda v: v * v, [x], + return_same_structure=False, + name="while_1") # x**2 + ret2 = while_loop_v2( + lambda v: v < 16., + lambda v: v * v, [x], + return_same_structure=False, + name="while_2") # x**4 + return ret1, ret2 - concrete_fn = Fn.get_concrete_function() - while_1 = concrete_fn.graph.get_operation_by_name("while_1") - while_2 = concrete_fn.graph.get_operation_by_name("while_2") - self.assertEqual(while_1.type, "StatelessWhile") - self.assertEqual(while_2.type, "StatelessWhile") - self.assertEmpty(while_1.control_inputs) - self.assertEmpty(while_2.control_inputs) + concrete_fn = Fn.get_concrete_function() + while_1 = concrete_fn.graph.get_operation_by_name("while_1") + while_2 = concrete_fn.graph.get_operation_by_name("while_2") + self.assertEqual(while_1.type, "StatelessWhile") + self.assertEqual(while_2.type, "StatelessWhile") + self.assertEmpty(while_1.control_inputs) + self.assertEmpty(while_2.control_inputs) def testMultipleWhileLoopsWithDeps(self): - if compat.forward_compatible(2019, 8, 23): - x = variables.Variable(2.) - c = constant_op.constant(2.) + x = variables.Variable(2.) + c = constant_op.constant(2.) - @def_function.function - def Fn(): - ret1 = while_loop_v2( - lambda v: v < 4., - lambda v: v * x, [c], - return_same_structure=False, - name="while_1") # 2x - ret2 = while_loop_v2( - lambda v: v < 16., - lambda v: v * x * x, [c], - return_same_structure=False, - name="while_2") # 4x - return ret1, ret2 + @def_function.function + def Fn(): + ret1 = while_loop_v2( + lambda v: v < 4., + lambda v: v * x, [c], + return_same_structure=False, + name="while_1") # 2x + ret2 = while_loop_v2( + lambda v: v < 16., + lambda v: v * x * x, [c], + return_same_structure=False, + name="while_2") # 4x + return ret1, ret2 - concrete_fn = Fn.get_concrete_function() - while_1 = concrete_fn.graph.get_operation_by_name("while_1") - while_2 = concrete_fn.graph.get_operation_by_name("while_2") - self.assertEqual(while_1.type, "While") - self.assertEqual(while_2.type, "While") - self.assertEmpty(while_1.control_inputs) - self.assertLen(while_2.control_inputs, 1) - self.assertIs(while_2.control_inputs[0], while_1) + concrete_fn = Fn.get_concrete_function() + while_1 = concrete_fn.graph.get_operation_by_name("while_1") + while_2 = concrete_fn.graph.get_operation_by_name("while_2") + self.assertEqual(while_1.type, "While") + self.assertEqual(while_2.type, "While") + self.assertEmpty(while_1.control_inputs) + self.assertLen(while_2.control_inputs, 1) + self.assertIs(while_2.control_inputs[0], while_1) def testMultipleWhileLoopsWithVarsDeps(self): - if compat.forward_compatible(2019, 8, 23): - x1 = variables.Variable(2.) - x2 = variables.Variable(3.) - c = constant_op.constant(2.) + x1 = variables.Variable(2.) + x2 = variables.Variable(3.) + c = constant_op.constant(2.) - @def_function.function - def Fn(): - ret1 = while_loop_v2( - lambda v: v < 4., - lambda v: v * x1, [c], - return_same_structure=False, - name="while_1") # 2x - ret2 = while_loop_v2( - lambda v: v < 16., - lambda v: v * x1 * x1, [c], - return_same_structure=False, - name="while_2") # 4x - ret3 = while_loop_v2( - lambda v: v < 4., - lambda v: v * x2, [c], - return_same_structure=False, - name="while_3") # 3x - ret4 = while_loop_v2( - lambda v: v < 16., - lambda v: v * x2 * x2, [c], - return_same_structure=False, - name="while_4") # 9x - ret5 = while_loop_v2( - lambda v: v < 16., - lambda v: v * v, [c], - return_same_structure=False, - name="while_stateless") # x**2 - return ret1, ret2, ret3, ret4, ret5 + @def_function.function + def Fn(): + ret1 = while_loop_v2( + lambda v: v < 4., + lambda v: v * x1, [c], + return_same_structure=False, + name="while_1") # 2x + ret2 = while_loop_v2( + lambda v: v < 16., + lambda v: v * x1 * x1, [c], + return_same_structure=False, + name="while_2") # 4x + ret3 = while_loop_v2( + lambda v: v < 4., + lambda v: v * x2, [c], + return_same_structure=False, + name="while_3") # 3x + ret4 = while_loop_v2( + lambda v: v < 16., + lambda v: v * x2 * x2, [c], + return_same_structure=False, + name="while_4") # 9x + ret5 = while_loop_v2( + lambda v: v < 16., + lambda v: v * v, [c], + return_same_structure=False, + name="while_stateless") # x**2 + return ret1, ret2, ret3, ret4, ret5 - concrete_fn = Fn.get_concrete_function() - while_1 = concrete_fn.graph.get_operation_by_name("while_1") - while_2 = concrete_fn.graph.get_operation_by_name("while_2") - while_3 = concrete_fn.graph.get_operation_by_name("while_3") - while_4 = concrete_fn.graph.get_operation_by_name("while_4") - while_stateless = concrete_fn.graph.get_operation_by_name( - "while_stateless") - self.assertEqual(while_1.type, "While") - self.assertEqual(while_2.type, "While") - self.assertEqual(while_3.type, "While") - self.assertEqual(while_4.type, "While") - self.assertEqual(while_stateless.type, "StatelessWhile") - self.assertEmpty(while_1.control_inputs) - self.assertLen(while_2.control_inputs, 1) - self.assertIs(while_2.control_inputs[0], while_1) - self.assertEmpty(while_3.control_inputs) - self.assertLen(while_4.control_inputs, 1) - self.assertIs(while_4.control_inputs[0], while_3) - self.assertEmpty(while_stateless.control_inputs) + concrete_fn = Fn.get_concrete_function() + while_1 = concrete_fn.graph.get_operation_by_name("while_1") + while_2 = concrete_fn.graph.get_operation_by_name("while_2") + while_3 = concrete_fn.graph.get_operation_by_name("while_3") + while_4 = concrete_fn.graph.get_operation_by_name("while_4") + while_stateless = concrete_fn.graph.get_operation_by_name( + "while_stateless") + self.assertEqual(while_1.type, "While") + self.assertEqual(while_2.type, "While") + self.assertEqual(while_3.type, "While") + self.assertEqual(while_4.type, "While") + self.assertEqual(while_stateless.type, "StatelessWhile") + self.assertEmpty(while_1.control_inputs) + self.assertLen(while_2.control_inputs, 1) + self.assertIs(while_2.control_inputs[0], while_1) + self.assertEmpty(while_3.control_inputs) + self.assertLen(while_4.control_inputs, 1) + self.assertIs(while_4.control_inputs[0], while_3) + self.assertEmpty(while_stateless.control_inputs) @test_util.run_deprecated_v1 def testDoubleDerivative(self): @@ -804,8 +800,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): lambda i: i + 1, [constant_op.constant(0)], return_same_structure=False) while_op = output.op.inputs[0].op - if compat.forward_compatible(2019, 8, 23): - self.assertEqual(while_op.type, "StatelessWhile") + self.assertEqual(while_op.type, "StatelessWhile") return while_op def testDefaultName(self): @@ -886,24 +881,23 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): @test_util.run_deprecated_v1 def testForwardPassRewrite(self): - if compat.forward_compatible(2019, 8, 23): - x = constant_op.constant(1.0, name="x") - output = while_v2.while_loop(lambda x: x < 10.0, - lambda x: x * 2.0, - [x])[0] - while_op = output.op.inputs[0].op - self.assertEqual(while_op.type, "StatelessWhile") - # outputs = [loop_counter, max_iters, x] - self.assertLen(while_op.outputs, 3) + x = constant_op.constant(1.0, name="x") + output = while_v2.while_loop(lambda x: x < 10.0, + lambda x: x * 2.0, + [x])[0] + while_op = output.op.inputs[0].op + self.assertEqual(while_op.type, "StatelessWhile") + # outputs = [loop_counter, max_iters, x] + self.assertLen(while_op.outputs, 3) - gradients_impl.gradients(output, x) - # while_op should have been rewritten to output intermediates. - # outputs = [loop_counter, max_iters, x, x_accumulator] - self.assertLen(while_op.outputs, 4) + gradients_impl.gradients(output, x) + # while_op should have been rewritten to output intermediates. + # outputs = [loop_counter, max_iters, x, x_accumulator] + self.assertLen(while_op.outputs, 4) - gradients_impl.gradients(output, x) - # Computing the gradient again shouldn't rewrite while_op again. - self.assertLen(while_op.outputs, 4) + gradients_impl.gradients(output, x) + # Computing the gradient again shouldn't rewrite while_op again. + self.assertLen(while_op.outputs, 4) @parameterized.named_parameters( ("RandomUniform", random_ops.random_uniform, [5, 3]), diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py index 14cb5d29f7f..65adf2288f1 100644 --- a/tensorflow/python/ops/cond_v2.py +++ b/tensorflow/python/ops/cond_v2.py @@ -25,7 +25,6 @@ from __future__ import print_function import collections -from tensorflow.python.compat import compat from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import function_def_to_graph @@ -256,11 +255,7 @@ def _build_cond(pred, false_stateful_ops = [ op for op in false_graph.get_operations() if op._is_stateful ] - # TODO(srbs): Remove this after July 22, 2019. This is required to abide by - # 3-week forward compat window of new TF python op generating code with - # stale runtime binaries. - if (true_stateful_ops or false_stateful_ops or - not compat.forward_compatible(2019, 7, 22)): + if (true_stateful_ops or false_stateful_ops): op_fn = gen_functional_ops._if else: op_fn = gen_functional_ops.stateless_if diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index 47508873009..396484f6bc8 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -24,7 +24,6 @@ from __future__ import division from __future__ import print_function from tensorflow.core.framework import attr_value_pb2 -from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph as func_graph_module @@ -277,11 +276,7 @@ def while_loop(cond, body_stateful_ops = [ op for op in body_graph.get_operations() if op._is_stateful ] - # TODO(yanhuasun): Remove this after Aug 23, 2019. This is required to - # abide by 3-week forward compat window of new TF python op generating - # code with stale runtime binaries. - if (cond_stateful_ops or body_stateful_ops or - not compat.forward_compatible(2019, 8, 23)): + if (cond_stateful_ops or body_stateful_ops): op_fn = gen_functional_ops._while else: op_fn = gen_functional_ops.stateless_while