Remove compatibility check for if and while
PiperOrigin-RevId: 265757988
This commit is contained in:
parent
e9c7e5357a
commit
c7f4fe381a
@ -20,7 +20,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.compat import compat as forward_compat
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -1425,6 +1424,4 @@ def _has_node_with_op(run_metadata, op_type):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Forward compat date for StatelessIf.
|
||||
with forward_compat.forward_compatibility_horizon(2019, 7, 23):
|
||||
test.main()
|
||||
test.main()
|
||||
|
@ -22,7 +22,6 @@ from absl.testing import parameterized
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -241,113 +240,110 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
self.assertSequenceEqual(self.evaluate(grad_grad), [48.])
|
||||
|
||||
def testMultipleWhileLoopsWithFunc(self):
|
||||
if compat.forward_compatible(2019, 8, 23):
|
||||
x = constant_op.constant(2.)
|
||||
x = constant_op.constant(2.)
|
||||
|
||||
@def_function.function
|
||||
def Fn():
|
||||
ret1 = while_loop_v2(
|
||||
lambda v: v < 4.,
|
||||
lambda v: v * v, [x],
|
||||
return_same_structure=False,
|
||||
name="while_1") # x**2
|
||||
ret2 = while_loop_v2(
|
||||
lambda v: v < 16.,
|
||||
lambda v: v * v, [x],
|
||||
return_same_structure=False,
|
||||
name="while_2") # x**4
|
||||
return ret1, ret2
|
||||
@def_function.function
|
||||
def Fn():
|
||||
ret1 = while_loop_v2(
|
||||
lambda v: v < 4.,
|
||||
lambda v: v * v, [x],
|
||||
return_same_structure=False,
|
||||
name="while_1") # x**2
|
||||
ret2 = while_loop_v2(
|
||||
lambda v: v < 16.,
|
||||
lambda v: v * v, [x],
|
||||
return_same_structure=False,
|
||||
name="while_2") # x**4
|
||||
return ret1, ret2
|
||||
|
||||
concrete_fn = Fn.get_concrete_function()
|
||||
while_1 = concrete_fn.graph.get_operation_by_name("while_1")
|
||||
while_2 = concrete_fn.graph.get_operation_by_name("while_2")
|
||||
self.assertEqual(while_1.type, "StatelessWhile")
|
||||
self.assertEqual(while_2.type, "StatelessWhile")
|
||||
self.assertEmpty(while_1.control_inputs)
|
||||
self.assertEmpty(while_2.control_inputs)
|
||||
concrete_fn = Fn.get_concrete_function()
|
||||
while_1 = concrete_fn.graph.get_operation_by_name("while_1")
|
||||
while_2 = concrete_fn.graph.get_operation_by_name("while_2")
|
||||
self.assertEqual(while_1.type, "StatelessWhile")
|
||||
self.assertEqual(while_2.type, "StatelessWhile")
|
||||
self.assertEmpty(while_1.control_inputs)
|
||||
self.assertEmpty(while_2.control_inputs)
|
||||
|
||||
def testMultipleWhileLoopsWithDeps(self):
|
||||
if compat.forward_compatible(2019, 8, 23):
|
||||
x = variables.Variable(2.)
|
||||
c = constant_op.constant(2.)
|
||||
x = variables.Variable(2.)
|
||||
c = constant_op.constant(2.)
|
||||
|
||||
@def_function.function
|
||||
def Fn():
|
||||
ret1 = while_loop_v2(
|
||||
lambda v: v < 4.,
|
||||
lambda v: v * x, [c],
|
||||
return_same_structure=False,
|
||||
name="while_1") # 2x
|
||||
ret2 = while_loop_v2(
|
||||
lambda v: v < 16.,
|
||||
lambda v: v * x * x, [c],
|
||||
return_same_structure=False,
|
||||
name="while_2") # 4x
|
||||
return ret1, ret2
|
||||
@def_function.function
|
||||
def Fn():
|
||||
ret1 = while_loop_v2(
|
||||
lambda v: v < 4.,
|
||||
lambda v: v * x, [c],
|
||||
return_same_structure=False,
|
||||
name="while_1") # 2x
|
||||
ret2 = while_loop_v2(
|
||||
lambda v: v < 16.,
|
||||
lambda v: v * x * x, [c],
|
||||
return_same_structure=False,
|
||||
name="while_2") # 4x
|
||||
return ret1, ret2
|
||||
|
||||
concrete_fn = Fn.get_concrete_function()
|
||||
while_1 = concrete_fn.graph.get_operation_by_name("while_1")
|
||||
while_2 = concrete_fn.graph.get_operation_by_name("while_2")
|
||||
self.assertEqual(while_1.type, "While")
|
||||
self.assertEqual(while_2.type, "While")
|
||||
self.assertEmpty(while_1.control_inputs)
|
||||
self.assertLen(while_2.control_inputs, 1)
|
||||
self.assertIs(while_2.control_inputs[0], while_1)
|
||||
concrete_fn = Fn.get_concrete_function()
|
||||
while_1 = concrete_fn.graph.get_operation_by_name("while_1")
|
||||
while_2 = concrete_fn.graph.get_operation_by_name("while_2")
|
||||
self.assertEqual(while_1.type, "While")
|
||||
self.assertEqual(while_2.type, "While")
|
||||
self.assertEmpty(while_1.control_inputs)
|
||||
self.assertLen(while_2.control_inputs, 1)
|
||||
self.assertIs(while_2.control_inputs[0], while_1)
|
||||
|
||||
def testMultipleWhileLoopsWithVarsDeps(self):
|
||||
if compat.forward_compatible(2019, 8, 23):
|
||||
x1 = variables.Variable(2.)
|
||||
x2 = variables.Variable(3.)
|
||||
c = constant_op.constant(2.)
|
||||
x1 = variables.Variable(2.)
|
||||
x2 = variables.Variable(3.)
|
||||
c = constant_op.constant(2.)
|
||||
|
||||
@def_function.function
|
||||
def Fn():
|
||||
ret1 = while_loop_v2(
|
||||
lambda v: v < 4.,
|
||||
lambda v: v * x1, [c],
|
||||
return_same_structure=False,
|
||||
name="while_1") # 2x
|
||||
ret2 = while_loop_v2(
|
||||
lambda v: v < 16.,
|
||||
lambda v: v * x1 * x1, [c],
|
||||
return_same_structure=False,
|
||||
name="while_2") # 4x
|
||||
ret3 = while_loop_v2(
|
||||
lambda v: v < 4.,
|
||||
lambda v: v * x2, [c],
|
||||
return_same_structure=False,
|
||||
name="while_3") # 3x
|
||||
ret4 = while_loop_v2(
|
||||
lambda v: v < 16.,
|
||||
lambda v: v * x2 * x2, [c],
|
||||
return_same_structure=False,
|
||||
name="while_4") # 9x
|
||||
ret5 = while_loop_v2(
|
||||
lambda v: v < 16.,
|
||||
lambda v: v * v, [c],
|
||||
return_same_structure=False,
|
||||
name="while_stateless") # x**2
|
||||
return ret1, ret2, ret3, ret4, ret5
|
||||
@def_function.function
|
||||
def Fn():
|
||||
ret1 = while_loop_v2(
|
||||
lambda v: v < 4.,
|
||||
lambda v: v * x1, [c],
|
||||
return_same_structure=False,
|
||||
name="while_1") # 2x
|
||||
ret2 = while_loop_v2(
|
||||
lambda v: v < 16.,
|
||||
lambda v: v * x1 * x1, [c],
|
||||
return_same_structure=False,
|
||||
name="while_2") # 4x
|
||||
ret3 = while_loop_v2(
|
||||
lambda v: v < 4.,
|
||||
lambda v: v * x2, [c],
|
||||
return_same_structure=False,
|
||||
name="while_3") # 3x
|
||||
ret4 = while_loop_v2(
|
||||
lambda v: v < 16.,
|
||||
lambda v: v * x2 * x2, [c],
|
||||
return_same_structure=False,
|
||||
name="while_4") # 9x
|
||||
ret5 = while_loop_v2(
|
||||
lambda v: v < 16.,
|
||||
lambda v: v * v, [c],
|
||||
return_same_structure=False,
|
||||
name="while_stateless") # x**2
|
||||
return ret1, ret2, ret3, ret4, ret5
|
||||
|
||||
concrete_fn = Fn.get_concrete_function()
|
||||
while_1 = concrete_fn.graph.get_operation_by_name("while_1")
|
||||
while_2 = concrete_fn.graph.get_operation_by_name("while_2")
|
||||
while_3 = concrete_fn.graph.get_operation_by_name("while_3")
|
||||
while_4 = concrete_fn.graph.get_operation_by_name("while_4")
|
||||
while_stateless = concrete_fn.graph.get_operation_by_name(
|
||||
"while_stateless")
|
||||
self.assertEqual(while_1.type, "While")
|
||||
self.assertEqual(while_2.type, "While")
|
||||
self.assertEqual(while_3.type, "While")
|
||||
self.assertEqual(while_4.type, "While")
|
||||
self.assertEqual(while_stateless.type, "StatelessWhile")
|
||||
self.assertEmpty(while_1.control_inputs)
|
||||
self.assertLen(while_2.control_inputs, 1)
|
||||
self.assertIs(while_2.control_inputs[0], while_1)
|
||||
self.assertEmpty(while_3.control_inputs)
|
||||
self.assertLen(while_4.control_inputs, 1)
|
||||
self.assertIs(while_4.control_inputs[0], while_3)
|
||||
self.assertEmpty(while_stateless.control_inputs)
|
||||
concrete_fn = Fn.get_concrete_function()
|
||||
while_1 = concrete_fn.graph.get_operation_by_name("while_1")
|
||||
while_2 = concrete_fn.graph.get_operation_by_name("while_2")
|
||||
while_3 = concrete_fn.graph.get_operation_by_name("while_3")
|
||||
while_4 = concrete_fn.graph.get_operation_by_name("while_4")
|
||||
while_stateless = concrete_fn.graph.get_operation_by_name(
|
||||
"while_stateless")
|
||||
self.assertEqual(while_1.type, "While")
|
||||
self.assertEqual(while_2.type, "While")
|
||||
self.assertEqual(while_3.type, "While")
|
||||
self.assertEqual(while_4.type, "While")
|
||||
self.assertEqual(while_stateless.type, "StatelessWhile")
|
||||
self.assertEmpty(while_1.control_inputs)
|
||||
self.assertLen(while_2.control_inputs, 1)
|
||||
self.assertIs(while_2.control_inputs[0], while_1)
|
||||
self.assertEmpty(while_3.control_inputs)
|
||||
self.assertLen(while_4.control_inputs, 1)
|
||||
self.assertIs(while_4.control_inputs[0], while_3)
|
||||
self.assertEmpty(while_stateless.control_inputs)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testDoubleDerivative(self):
|
||||
@ -804,8 +800,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
lambda i: i + 1, [constant_op.constant(0)],
|
||||
return_same_structure=False)
|
||||
while_op = output.op.inputs[0].op
|
||||
if compat.forward_compatible(2019, 8, 23):
|
||||
self.assertEqual(while_op.type, "StatelessWhile")
|
||||
self.assertEqual(while_op.type, "StatelessWhile")
|
||||
return while_op
|
||||
|
||||
def testDefaultName(self):
|
||||
@ -886,24 +881,23 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testForwardPassRewrite(self):
|
||||
if compat.forward_compatible(2019, 8, 23):
|
||||
x = constant_op.constant(1.0, name="x")
|
||||
output = while_v2.while_loop(lambda x: x < 10.0,
|
||||
lambda x: x * 2.0,
|
||||
[x])[0]
|
||||
while_op = output.op.inputs[0].op
|
||||
self.assertEqual(while_op.type, "StatelessWhile")
|
||||
# outputs = [loop_counter, max_iters, x]
|
||||
self.assertLen(while_op.outputs, 3)
|
||||
x = constant_op.constant(1.0, name="x")
|
||||
output = while_v2.while_loop(lambda x: x < 10.0,
|
||||
lambda x: x * 2.0,
|
||||
[x])[0]
|
||||
while_op = output.op.inputs[0].op
|
||||
self.assertEqual(while_op.type, "StatelessWhile")
|
||||
# outputs = [loop_counter, max_iters, x]
|
||||
self.assertLen(while_op.outputs, 3)
|
||||
|
||||
gradients_impl.gradients(output, x)
|
||||
# while_op should have been rewritten to output intermediates.
|
||||
# outputs = [loop_counter, max_iters, x, x_accumulator]
|
||||
self.assertLen(while_op.outputs, 4)
|
||||
gradients_impl.gradients(output, x)
|
||||
# while_op should have been rewritten to output intermediates.
|
||||
# outputs = [loop_counter, max_iters, x, x_accumulator]
|
||||
self.assertLen(while_op.outputs, 4)
|
||||
|
||||
gradients_impl.gradients(output, x)
|
||||
# Computing the gradient again shouldn't rewrite while_op again.
|
||||
self.assertLen(while_op.outputs, 4)
|
||||
gradients_impl.gradients(output, x)
|
||||
# Computing the gradient again shouldn't rewrite while_op again.
|
||||
self.assertLen(while_op.outputs, 4)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("RandomUniform", random_ops.random_uniform, [5, 3]),
|
||||
|
@ -25,7 +25,6 @@ from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import func_graph as func_graph_module
|
||||
from tensorflow.python.framework import function_def_to_graph
|
||||
@ -256,11 +255,7 @@ def _build_cond(pred,
|
||||
false_stateful_ops = [
|
||||
op for op in false_graph.get_operations() if op._is_stateful
|
||||
]
|
||||
# TODO(srbs): Remove this after July 22, 2019. This is required to abide by
|
||||
# 3-week forward compat window of new TF python op generating code with
|
||||
# stale runtime binaries.
|
||||
if (true_stateful_ops or false_stateful_ops or
|
||||
not compat.forward_compatible(2019, 7, 22)):
|
||||
if (true_stateful_ops or false_stateful_ops):
|
||||
op_fn = gen_functional_ops._if
|
||||
else:
|
||||
op_fn = gen_functional_ops.stateless_if
|
||||
|
@ -24,7 +24,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import func_graph as func_graph_module
|
||||
@ -277,11 +276,7 @@ def while_loop(cond,
|
||||
body_stateful_ops = [
|
||||
op for op in body_graph.get_operations() if op._is_stateful
|
||||
]
|
||||
# TODO(yanhuasun): Remove this after Aug 23, 2019. This is required to
|
||||
# abide by 3-week forward compat window of new TF python op generating
|
||||
# code with stale runtime binaries.
|
||||
if (cond_stateful_ops or body_stateful_ops or
|
||||
not compat.forward_compatible(2019, 8, 23)):
|
||||
if (cond_stateful_ops or body_stateful_ops):
|
||||
op_fn = gen_functional_ops._while
|
||||
else:
|
||||
op_fn = gen_functional_ops.stateless_while
|
||||
|
Loading…
Reference in New Issue
Block a user