Add support for StatelessWhile.
PiperOrigin-RevId: 260768101
This commit is contained in:
parent
c7933ce9f3
commit
0b9d4fd92d
@ -33,7 +33,8 @@ from tensorflow.python.training.saver import export_meta_graph
|
||||
|
||||
|
||||
_CONDITIONAL_OPS = set(["If", "StatelessIf"])
|
||||
_CONTROL_FLOW_OPS = _CONDITIONAL_OPS.union(set(["While"]))
|
||||
_LOOP_OPS = set(["While", "StatelessWhile"])
|
||||
_CONTROL_FLOW_OPS = _CONDITIONAL_OPS.union(_LOOP_OPS)
|
||||
|
||||
|
||||
def disable_lower_using_switch_merge(graph_def):
|
||||
@ -202,10 +203,9 @@ def _get_control_flow_function_data(node_defs, tensor_data):
|
||||
|
||||
Creates a map from function name to a list of types and a list of shapes that
|
||||
correspond with the function arguments. The data is primarily determined from
|
||||
the corresponding "If", "StatelessIf", or "While" op. If the argument is a
|
||||
resource variable, then the type is determined from the type of the data
|
||||
contained within the Tensor. The shape data is only determined in the case of
|
||||
the "While" op.
|
||||
the corresponding "If" or "While" op. If the argument is a resource variable,
|
||||
then the type is determined from the type of the data contained within the
|
||||
Tensor. The shape data is only determined in the case of the "While" op.
|
||||
|
||||
`is_also_output_type` is used to identify the "While" bodies that require the
|
||||
output types to be updated at the same time the input types are updated.
|
||||
@ -249,7 +249,7 @@ def _get_control_flow_function_data(node_defs, tensor_data):
|
||||
|
||||
add_value(node.attr["then_branch"].func.name, arg_types, None, False)
|
||||
add_value(node.attr["else_branch"].func.name, arg_types, None, False)
|
||||
elif node.op == "While":
|
||||
elif node.op in _LOOP_OPS:
|
||||
arg_types = [dtype for dtype in node.attr["T"].list.type]
|
||||
output_shapes = [shape for shape in node.attr["output_shapes"].list.shape]
|
||||
|
||||
@ -298,7 +298,7 @@ def _populate_identity_op(output_node, input_node):
|
||||
|
||||
|
||||
def _populate_if_op(output_node, input_node, function_data):
|
||||
"""Updates the type attributes and the function names of If or StatelessIf.
|
||||
"""Updates the type attributes and function names of If or StatelessIf.
|
||||
|
||||
Args:
|
||||
output_node: TensorFlow NodeDef.
|
||||
@ -317,7 +317,7 @@ def _populate_if_op(output_node, input_node, function_data):
|
||||
|
||||
|
||||
def _populate_while_op(output_node, input_node, function_data):
|
||||
"""Updates the type attributes and the function names of the While op.
|
||||
"""Updates the type attributes and function names of While or StatelessWhile.
|
||||
|
||||
Args:
|
||||
output_node: TensorFlow NodeDef.
|
||||
@ -432,7 +432,7 @@ def convert_variables_to_constants_v2(func, lower_control_flow=True):
|
||||
if input_name in tensor_data:
|
||||
dtype = attr_value_pb2.AttrValue(type=arg_types[idx])
|
||||
_save_placeholder(_get_tensor_name(input_tensor), dtype)
|
||||
elif node.op == "While":
|
||||
elif node.op in _LOOP_OPS:
|
||||
# Get dtype and data for resource Placeholders.
|
||||
cond_func = node.attr["cond"].func.name
|
||||
arg_types = function_data[cond_func]["types"]
|
||||
@ -442,7 +442,7 @@ def convert_variables_to_constants_v2(func, lower_control_flow=True):
|
||||
dtype = attr_value_pb2.AttrValue(type=arg_types[idx])
|
||||
_save_placeholder(_get_tensor_name(input_tensor), dtype)
|
||||
elif (node.op == "Identity" and node.attr["T"].type == dtypes.resource and
|
||||
name_to_node[_get_tensor_name(node.input[0])].op == "While"):
|
||||
name_to_node[_get_tensor_name(node.input[0])].op in _LOOP_OPS):
|
||||
# Store the dtype for Identity resource ops that are outputs of While ops.
|
||||
while_node = name_to_node[_get_tensor_name(node.input[0])]
|
||||
body_func = while_node.attr["body"].func.name
|
||||
@ -502,7 +502,7 @@ def convert_variables_to_constants_v2(func, lower_control_flow=True):
|
||||
# Update the function names and argument types for the conditional ops.
|
||||
elif input_node.op in _CONDITIONAL_OPS:
|
||||
_populate_if_op(output_node, input_node, function_data)
|
||||
elif input_node.op == "While":
|
||||
elif input_node.op in _LOOP_OPS:
|
||||
_populate_while_op(output_node, input_node, function_data)
|
||||
else:
|
||||
output_node.CopyFrom(input_node)
|
||||
@ -553,7 +553,7 @@ def convert_variables_to_constants_v2(func, lower_control_flow=True):
|
||||
# Update the function names and argument types for the conditional ops.
|
||||
elif input_node.op in _CONDITIONAL_OPS:
|
||||
_populate_if_op(output_node, input_node, function_data)
|
||||
elif input_node.op == "While":
|
||||
elif input_node.op in _LOOP_OPS:
|
||||
_populate_while_op(output_node, input_node, function_data)
|
||||
else:
|
||||
output_node.CopyFrom(input_node)
|
||||
|
@ -37,6 +37,7 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell_impl
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops import while_v2
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import simple_save
|
||||
from tensorflow.python.saved_model.load import load
|
||||
@ -350,7 +351,7 @@ class VariablesToConstantsTest(test.TestCase):
|
||||
self._testConvertedFunction(root, root.f, output_func, input_data)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testLoop(self):
|
||||
def testWhile(self):
|
||||
"""Test a While loop."""
|
||||
input_data = {"x": constant_op.constant([1., 2., 3., 4.], shape=[2, 2])}
|
||||
|
||||
@ -371,6 +372,24 @@ class VariablesToConstantsTest(test.TestCase):
|
||||
root, output_func = self._freezeModel(model)
|
||||
self._testConvertedFunction(root, root.f, output_func, input_data)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testStatelessWhile(self):
|
||||
"""Test a StatelessWhile loop."""
|
||||
input_data = {"x": constant_op.constant(2.)}
|
||||
|
||||
@def_function.function(input_signature=[
|
||||
tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
|
||||
])
|
||||
def model(x):
|
||||
return while_v2.while_loop(
|
||||
lambda v: v < 4.,
|
||||
lambda v: v * v, [x],
|
||||
return_same_structure=False,
|
||||
name="while_1") # x**2
|
||||
|
||||
root, output_func = self._freezeModel(model)
|
||||
self._testConvertedFunction(root, root.f, output_func, input_data)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testDynamicRnn(self):
|
||||
"""Test a DynamicRnn containing While loops."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user