Add support for StatelessWhile.

PiperOrigin-RevId: 260768101
This commit is contained in:
Nupur Garg 2019-07-30 12:12:05 -07:00 committed by TensorFlower Gardener
parent c7933ce9f3
commit 0b9d4fd92d
2 changed files with 32 additions and 13 deletions

View File

@ -33,7 +33,8 @@ from tensorflow.python.training.saver import export_meta_graph
_CONDITIONAL_OPS = set(["If", "StatelessIf"])
_CONTROL_FLOW_OPS = _CONDITIONAL_OPS.union(set(["While"]))
_LOOP_OPS = set(["While", "StatelessWhile"])
_CONTROL_FLOW_OPS = _CONDITIONAL_OPS.union(_LOOP_OPS)
def disable_lower_using_switch_merge(graph_def):
@ -202,10 +203,9 @@ def _get_control_flow_function_data(node_defs, tensor_data):
Creates a map from function name to a list of types and a list of shapes that
correspond with the function arguments. The data is primarily determined from
the corresponding "If", "StatelessIf", or "While" op. If the argument is a
resource variable, then the type is determined from the type of the data
contained within the Tensor. The shape data is only determined in the case of
the "While" op.
the corresponding "If" or "While" op. If the argument is a resource variable,
then the type is determined from the type of the data contained within the
Tensor. The shape data is only determined in the case of the "While" op.
`is_also_output_type` is used to identify the "While" bodies that require the
output types to be updated at the same time the input types are updated.
@ -249,7 +249,7 @@ def _get_control_flow_function_data(node_defs, tensor_data):
add_value(node.attr["then_branch"].func.name, arg_types, None, False)
add_value(node.attr["else_branch"].func.name, arg_types, None, False)
elif node.op == "While":
elif node.op in _LOOP_OPS:
arg_types = [dtype for dtype in node.attr["T"].list.type]
output_shapes = [shape for shape in node.attr["output_shapes"].list.shape]
@ -298,7 +298,7 @@ def _populate_identity_op(output_node, input_node):
def _populate_if_op(output_node, input_node, function_data):
"""Updates the type attributes and the function names of If or StatelessIf.
"""Updates the type attributes and function names of If or StatelessIf.
Args:
output_node: TensorFlow NodeDef.
@ -317,7 +317,7 @@ def _populate_if_op(output_node, input_node, function_data):
def _populate_while_op(output_node, input_node, function_data):
"""Updates the type attributes and the function names of the While op.
"""Updates the type attributes and function names of While or StatelessWhile.
Args:
output_node: TensorFlow NodeDef.
@ -432,7 +432,7 @@ def convert_variables_to_constants_v2(func, lower_control_flow=True):
if input_name in tensor_data:
dtype = attr_value_pb2.AttrValue(type=arg_types[idx])
_save_placeholder(_get_tensor_name(input_tensor), dtype)
elif node.op == "While":
elif node.op in _LOOP_OPS:
# Get dtype and data for resource Placeholders.
cond_func = node.attr["cond"].func.name
arg_types = function_data[cond_func]["types"]
@ -442,7 +442,7 @@ def convert_variables_to_constants_v2(func, lower_control_flow=True):
dtype = attr_value_pb2.AttrValue(type=arg_types[idx])
_save_placeholder(_get_tensor_name(input_tensor), dtype)
elif (node.op == "Identity" and node.attr["T"].type == dtypes.resource and
name_to_node[_get_tensor_name(node.input[0])].op == "While"):
name_to_node[_get_tensor_name(node.input[0])].op in _LOOP_OPS):
# Store the dtype for Identity resource ops that are outputs of While ops.
while_node = name_to_node[_get_tensor_name(node.input[0])]
body_func = while_node.attr["body"].func.name
@ -502,7 +502,7 @@ def convert_variables_to_constants_v2(func, lower_control_flow=True):
# Update the function names and argument types for the conditional ops.
elif input_node.op in _CONDITIONAL_OPS:
_populate_if_op(output_node, input_node, function_data)
elif input_node.op == "While":
elif input_node.op in _LOOP_OPS:
_populate_while_op(output_node, input_node, function_data)
else:
output_node.CopyFrom(input_node)
@ -553,7 +553,7 @@ def convert_variables_to_constants_v2(func, lower_control_flow=True):
# Update the function names and argument types for the conditional ops.
elif input_node.op in _CONDITIONAL_OPS:
_populate_if_op(output_node, input_node, function_data)
elif input_node.op == "While":
elif input_node.op in _LOOP_OPS:
_populate_while_op(output_node, input_node, function_data)
else:
output_node.CopyFrom(input_node)

View File

@ -37,6 +37,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_v2
from tensorflow.python.platform import test
from tensorflow.python.saved_model import simple_save
from tensorflow.python.saved_model.load import load
@ -350,7 +351,7 @@ class VariablesToConstantsTest(test.TestCase):
self._testConvertedFunction(root, root.f, output_func, input_data)
@test_util.run_v2_only
def testLoop(self):
def testWhile(self):
"""Test a While loop."""
input_data = {"x": constant_op.constant([1., 2., 3., 4.], shape=[2, 2])}
@ -371,6 +372,24 @@ class VariablesToConstantsTest(test.TestCase):
root, output_func = self._freezeModel(model)
self._testConvertedFunction(root, root.f, output_func, input_data)
@test_util.run_v2_only
def testStatelessWhile(self):
"""Test a StatelessWhile loop."""
input_data = {"x": constant_op.constant(2.)}
@def_function.function(input_signature=[
tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
])
def model(x):
return while_v2.while_loop(
lambda v: v < 4.,
lambda v: v * v, [x],
return_same_structure=False,
name="while_1") # x**2
root, output_func = self._freezeModel(model)
self._testConvertedFunction(root, root.f, output_func, input_data)
@test_util.run_v2_only
def testDynamicRnn(self):
"""Test a DynamicRnn containing While loops."""