In while_v2 emit a StatelessIf op if the body is stateless.
PiperOrigin-RevId: 260927755
This commit is contained in:
parent
719ad3bfde
commit
170a95de67
@ -318,7 +318,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node.type_string() == "While" &&
|
||||
if (node.IsWhileNode() &&
|
||||
!IsCompilableWhile(node, lib_runtime, stack_trace, uncompilable_nodes)) {
|
||||
LogNotCompilable(node, "unsupported while");
|
||||
return false;
|
||||
|
@ -440,7 +440,7 @@ Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) {
|
||||
n->ClearAttr(attr_name);
|
||||
n->AddAttr(attr_name, branch_func);
|
||||
}
|
||||
} else if (n->type_string() == "While") {
|
||||
} else if (n->IsWhileNode()) {
|
||||
for (const string& attr_name : std::vector<string>{"cond", "body"}) {
|
||||
NameAttrList branch_func;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
|
||||
@ -595,7 +595,7 @@ void ReplaceLiftedArgNodePlaceholderWithArg(
|
||||
Status PostprocessLiftedArgsForWhile(
|
||||
const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
|
||||
Graph* g, Node* n, FunctionLibraryDefinition* fld) {
|
||||
TF_RET_CHECK(n->type_string() == "While");
|
||||
TF_RET_CHECK(n->IsWhileNode());
|
||||
|
||||
// Check if there is any lifted args in body function.
|
||||
NameAttrList body_func;
|
||||
@ -936,7 +936,7 @@ Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (n->type_string() == "While") {
|
||||
if (n->IsWhileNode()) {
|
||||
TF_RETURN_IF_ERROR(PostprocessLiftedArgsForWhile(
|
||||
outside_compilation_attr_to_node, g, n, fld));
|
||||
}
|
||||
@ -1782,7 +1782,7 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
|
||||
for (Node* n : g->nodes()) {
|
||||
if (n->IsIfNode()) {
|
||||
if_nodes.push_back(n);
|
||||
} else if (n->type_string() == "While") {
|
||||
} else if (n->IsWhileNode()) {
|
||||
while_nodes.push_back(n);
|
||||
} else if (IsFunctionCall(*fld, *n)) {
|
||||
func_call_nodes.push_back(n);
|
||||
|
@ -970,8 +970,7 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
|
||||
int effective_cluster_size =
|
||||
(node->IsIdentity() || node->IsConstant()) ? 0 : 1;
|
||||
|
||||
bool has_functional_control_flow =
|
||||
node->type_string() == "While" || node->IsIfNode();
|
||||
bool has_functional_control_flow = node->IsWhileNode() || node->IsIfNode();
|
||||
|
||||
absl::optional<DeadnessPredicate> deadness_predicate;
|
||||
if (deadness_analysis_) {
|
||||
|
@ -91,7 +91,7 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel,
|
||||
FunctionLibraryRuntime* flib_runtime) {
|
||||
DCHECK(op_def != nullptr || op_kernel != nullptr);
|
||||
// TODO(b/124403063): Implement similar functionality for function call nodes.
|
||||
if (node.op() == "While") {
|
||||
if (node.op() == "While" || node.op() == "StatelessWhile") {
|
||||
// For While nodes, recurse into the body and cond graphs.
|
||||
const FunctionBody* fcond = nullptr;
|
||||
const FunctionBody* fbody = nullptr;
|
||||
|
@ -527,7 +527,7 @@ Status RearrangeFunctionArguments(
|
||||
|
||||
// Rewrite If/While nodes.
|
||||
for (Node* n : g->nodes()) {
|
||||
if (n->type_string() == "While") {
|
||||
if (n->IsWhileNode()) {
|
||||
bool node_rewritten;
|
||||
TF_RETURN_IF_ERROR(MaybeRewriteWhileNode(get_function_body_fn, g, n, fld,
|
||||
&node_rewritten));
|
||||
|
@ -50,7 +50,7 @@ Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) {
|
||||
node->ClearAttr(attr_name);
|
||||
node->AddAttr(attr_name, branch_func);
|
||||
}
|
||||
} else if (node->type_string() == "While") {
|
||||
} else if (node->IsWhileNode()) {
|
||||
AttrValue device_ordinal_value;
|
||||
device_ordinal_value.set_i(device_ordinal);
|
||||
for (const string& attr_name : std::vector<string>{"cond", "body"}) {
|
||||
|
@ -765,7 +765,7 @@ Status PropagateConstIntoFunctionalNodes(
|
||||
for (Node* n : g->op_nodes()) {
|
||||
if (n->IsIfNode()) {
|
||||
TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld));
|
||||
} else if (n->type_string() == "While") {
|
||||
} else if (n->IsWhileNode()) {
|
||||
TF_RETURN_IF_ERROR(PropagateConstIntoWhileNode(g, n, lookup_fld, fld));
|
||||
}
|
||||
}
|
||||
@ -796,7 +796,7 @@ Status RewriteTensorListWithConstElement(Graph* g,
|
||||
// Find the forward While op.
|
||||
std::vector<const Edge*> fwd_while_edges;
|
||||
for (const Edge* e : n->out_edges()) {
|
||||
if (!e->IsControlEdge() && e->dst()->type_string() == "While") {
|
||||
if (!e->IsControlEdge() && e->dst()->IsWhileNode()) {
|
||||
fwd_while_edges.push_back(e);
|
||||
}
|
||||
}
|
||||
@ -810,8 +810,7 @@ Status RewriteTensorListWithConstElement(Graph* g,
|
||||
int fwd_while_dst_input = fwd_while_edges[0]->dst_input();
|
||||
std::vector<const Edge*> bwd_while_edges;
|
||||
for (const Edge* e : fwd_while->out_edges()) {
|
||||
if (e->src_output() == fwd_while_dst_input &&
|
||||
e->dst()->type_string() == "While") {
|
||||
if (e->src_output() == fwd_while_dst_input && e->dst()->IsWhileNode()) {
|
||||
bwd_while_edges.push_back(e);
|
||||
}
|
||||
}
|
||||
|
@ -143,7 +143,7 @@ Status LowerFunctionalOpsPass::Run(
|
||||
} else if (n->type_string() == "Case") {
|
||||
TF_RETURN_IF_ERROR(
|
||||
RewriteCaseNode(n, g, *flib_def, keep_lowered_nodes_fetchable));
|
||||
} else if (n->type_string() == "While") {
|
||||
} else if (n->IsWhileNode()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
RewriteWhileNode(n, g, *flib_def, keep_lowered_nodes_fetchable));
|
||||
} else {
|
||||
|
@ -90,6 +90,8 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
|
||||
{"StatefulPartitionedCall", NC_PARTITIONED_CALL},
|
||||
{"If", NC_IF},
|
||||
{"StatelessIf", NC_IF},
|
||||
{"While", NC_WHILE},
|
||||
{"StatelessWhile", NC_WHILE},
|
||||
// Not using the constants defined in FunctionLibraryDefinition for the
|
||||
// 4 ops below because android inference library does not link
|
||||
// tf.function related files.
|
||||
@ -592,7 +594,7 @@ Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
|
||||
}
|
||||
|
||||
Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
|
||||
if (dst->type_string() != "While") {
|
||||
if (!dst->IsWhileNode()) {
|
||||
return errors::Internal(
|
||||
"dst argument to AddWhileEdgeHack should be a While op, got: ",
|
||||
dst->DebugString());
|
||||
|
@ -177,6 +177,7 @@ class Node {
|
||||
bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
|
||||
bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; }
|
||||
bool IsIfNode() const { return class_ == NC_IF; }
|
||||
bool IsWhileNode() const { return class_ == NC_WHILE; }
|
||||
// Is this node a function input
|
||||
bool IsArg() const { return class_ == NC_ARG; }
|
||||
// Is this node a function output
|
||||
@ -264,6 +265,7 @@ class Node {
|
||||
NC_FAKE_PARAM,
|
||||
NC_PARTITIONED_CALL,
|
||||
NC_IF,
|
||||
NC_WHILE,
|
||||
NC_ARG,
|
||||
NC_RETVAL,
|
||||
NC_OTHER // Not a special kind of node
|
||||
|
@ -1219,7 +1219,7 @@ Status InlineFunctionCalls(const GrapplerItem& item,
|
||||
TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), flib_def, false));
|
||||
} else if (n->type_string() == "Case") {
|
||||
TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), flib_def, false));
|
||||
} else if (n->type_string() == "While") {
|
||||
} else if (n->IsWhileNode()) {
|
||||
TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), flib_def, false));
|
||||
}
|
||||
continue;
|
||||
|
@ -195,6 +195,31 @@ body: A function that takes a list of tensors and returns another
|
||||
by T.
|
||||
)doc");
|
||||
|
||||
Status WhileShapeInferenceFn(shape_inference::InferenceContext* c) {
|
||||
std::vector<PartialTensorShape> output_shapes;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
|
||||
// If `output_shapes` attr is set use that as the shapes of the outputs
|
||||
// else use the input shapes.
|
||||
if (!output_shapes.empty()) {
|
||||
if (output_shapes.size() != c->num_outputs()) {
|
||||
return errors::InvalidArgument(
|
||||
"`output_shapes` must be the same length as num outputs (",
|
||||
output_shapes.size(), " vs. ", c->num_outputs());
|
||||
}
|
||||
for (size_t i = 0; i < output_shapes.size(); ++i) {
|
||||
shape_inference::ShapeHandle output_shape_handle;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
|
||||
output_shapes[i], &output_shape_handle));
|
||||
c->set_output(static_cast<int>(i), output_shape_handle);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < c->num_outputs(); ++i) {
|
||||
c->set_output(i, c->input(i));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_OP("While")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
@ -204,30 +229,7 @@ REGISTER_OP("While")
|
||||
.Attr("output_shapes: list(shape) = []")
|
||||
.Attr("parallel_iterations: int = 10")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
std::vector<PartialTensorShape> output_shapes;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
|
||||
// If `output_shapes` attr is set use that as the shapes of the outputs
|
||||
// else use the input shapes.
|
||||
if (!output_shapes.empty()) {
|
||||
if (output_shapes.size() != c->num_outputs()) {
|
||||
return errors::InvalidArgument(
|
||||
"`output_shapes` must be the same length as num outputs (",
|
||||
output_shapes.size(), " vs. ", c->num_outputs());
|
||||
}
|
||||
for (size_t i = 0; i < output_shapes.size(); ++i) {
|
||||
shape_inference::ShapeHandle output_shape_handle;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
|
||||
output_shapes[i], &output_shape_handle));
|
||||
c->set_output(static_cast<int>(i), output_shape_handle);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < c->num_outputs(); ++i) {
|
||||
c->set_output(i, c->input(i));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
});
|
||||
.SetShapeFn(WhileShapeInferenceFn);
|
||||
|
||||
REGISTER_OP("StatelessWhile")
|
||||
.Input("input: T")
|
||||
@ -235,12 +237,9 @@ REGISTER_OP("StatelessWhile")
|
||||
.Attr("T: list(type) >= 0")
|
||||
.Attr("cond: func")
|
||||
.Attr("body: func")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
for (int i = 0; i < c->num_outputs(); ++i) {
|
||||
c->set_output(i, c->input(i));
|
||||
}
|
||||
return Status::OK();
|
||||
});
|
||||
.Attr("output_shapes: list(shape) = []")
|
||||
.Attr("parallel_iterations: int = 10")
|
||||
.SetShapeFn(WhileShapeInferenceFn);
|
||||
|
||||
REGISTER_OP("For")
|
||||
.Input("start: int32")
|
||||
|
@ -80,7 +80,7 @@ class UtilTest(test_util.TensorFlowTestCase):
|
||||
sess.graph_def)
|
||||
lower_using_switch_merge_is_removed = False
|
||||
for node in new_graph_def.node:
|
||||
if node.op == "While":
|
||||
if node.op == "While" or node.op == "StatelessWhile":
|
||||
if not node.attr["_lower_using_switch_merge"].b:
|
||||
lower_using_switch_merge_is_removed = True
|
||||
self.assertEqual(lower_using_switch_merge_is_removed, True)
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.autograph.core import ag_ctx
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat as forward_compat
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -793,33 +794,34 @@ class OperationTest(test_util.TensorFlowTestCase):
|
||||
@test_util.enable_control_flow_v2
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testAddWhileInput(self):
|
||||
@eager_function.defun
|
||||
def test():
|
||||
output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1,
|
||||
[1])
|
||||
while_op = output.op.inputs[0].op
|
||||
self.assertEqual(while_op.type, "While")
|
||||
orig_num_inputs = len(while_op.inputs)
|
||||
if forward_compat.forward_compatible(2019, 8, 23):
|
||||
@eager_function.defun
|
||||
def test():
|
||||
output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1,
|
||||
[1])
|
||||
while_op = output.op.inputs[0].op
|
||||
self.assertEqual(while_op.type, "StatelessWhile")
|
||||
orig_num_inputs = len(while_op.inputs)
|
||||
|
||||
# Make sure we can handle the while op having a control input.
|
||||
while_op._add_control_input(constant_op.constant(0).op)
|
||||
# Make sure we can handle the while op having a control input.
|
||||
while_op._add_control_input(constant_op.constant(0).op)
|
||||
|
||||
new_input1 = constant_op.constant(1.0)
|
||||
new_input2 = constant_op.constant(True)
|
||||
new_input1 = constant_op.constant(1.0)
|
||||
new_input2 = constant_op.constant(True)
|
||||
|
||||
# Clear output shapes to bypass shape checking.
|
||||
while_op._set_shape_list_attr("output_shapes", [])
|
||||
while_op._set_type_list_attr("T",
|
||||
[t.dtype for t in while_op.inputs] +
|
||||
[new_input1.dtype, new_input2.dtype])
|
||||
# Clear output shapes to bypass shape checking.
|
||||
while_op._set_shape_list_attr("output_shapes", [])
|
||||
while_op._set_type_list_attr("T",
|
||||
[t.dtype for t in while_op.inputs] +
|
||||
[new_input1.dtype, new_input2.dtype])
|
||||
|
||||
while_op._add_while_inputs([new_input1, new_input2])
|
||||
# Can't add an edge beyond what's specified by "T"
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while_op._add_while_inputs([new_input2])
|
||||
self.assertEqual(len(while_op.inputs), orig_num_inputs + 2) # pylint: disable=g-deprecated-assert
|
||||
while_op._add_while_inputs([new_input1, new_input2])
|
||||
# Can't add an edge beyond what's specified by "T"
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while_op._add_while_inputs([new_input2])
|
||||
self.assertEqual(len(while_op.inputs), orig_num_inputs + 2) # pylint: disable=g-deprecated-assert
|
||||
|
||||
test()
|
||||
test()
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testOpDef(self):
|
||||
|
@ -22,10 +22,9 @@ from absl.testing import parameterized
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import control_flow_v2_toggles
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -36,6 +35,8 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.grappler import tf_optimizer
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import list_ops
|
||||
from tensorflow.python.ops import map_fn
|
||||
@ -196,6 +197,115 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
self.assertSequenceEqual(self.evaluate(grad), [32.])
|
||||
self.assertSequenceEqual(self.evaluate(grad_grad), [48.])
|
||||
|
||||
def testMultipleWhileLoopsWithFunc(self):
|
||||
if compat.forward_compatible(2019, 8, 23):
|
||||
x = constant_op.constant(2.)
|
||||
|
||||
@def_function.function
|
||||
def Fn():
|
||||
ret1 = while_loop_v2(
|
||||
lambda v: v < 4.,
|
||||
lambda v: v * v, [x],
|
||||
return_same_structure=False,
|
||||
name="while_1") # x**2
|
||||
ret2 = while_loop_v2(
|
||||
lambda v: v < 16.,
|
||||
lambda v: v * v, [x],
|
||||
return_same_structure=False,
|
||||
name="while_2") # x**4
|
||||
return ret1, ret2
|
||||
|
||||
concrete_fn = Fn.get_concrete_function()
|
||||
while_1 = concrete_fn.graph.get_operation_by_name("while_1")
|
||||
while_2 = concrete_fn.graph.get_operation_by_name("while_2")
|
||||
self.assertEqual(while_1.type, "StatelessWhile")
|
||||
self.assertEqual(while_2.type, "StatelessWhile")
|
||||
self.assertEmpty(while_1.control_inputs)
|
||||
self.assertEmpty(while_2.control_inputs)
|
||||
|
||||
def testMultipleWhileLoopsWithDeps(self):
|
||||
if compat.forward_compatible(2019, 8, 23):
|
||||
x = variables.Variable(2.)
|
||||
c = constant_op.constant(2.)
|
||||
|
||||
@def_function.function
|
||||
def Fn():
|
||||
ret1 = while_loop_v2(
|
||||
lambda v: v < 4.,
|
||||
lambda v: v * x, [c],
|
||||
return_same_structure=False,
|
||||
name="while_1") # 2x
|
||||
ret2 = while_loop_v2(
|
||||
lambda v: v < 16.,
|
||||
lambda v: v * x * x, [c],
|
||||
return_same_structure=False,
|
||||
name="while_2") # 4x
|
||||
return ret1, ret2
|
||||
|
||||
concrete_fn = Fn.get_concrete_function()
|
||||
while_1 = concrete_fn.graph.get_operation_by_name("while_1")
|
||||
while_2 = concrete_fn.graph.get_operation_by_name("while_2")
|
||||
self.assertEqual(while_1.type, "While")
|
||||
self.assertEqual(while_2.type, "While")
|
||||
self.assertEmpty(while_1.control_inputs)
|
||||
self.assertLen(while_2.control_inputs, 1)
|
||||
self.assertIs(while_2.control_inputs[0], while_1)
|
||||
|
||||
def testMultipleWhileLoopsWithVarsDeps(self):
|
||||
if compat.forward_compatible(2019, 8, 23):
|
||||
x1 = variables.Variable(2.)
|
||||
x2 = variables.Variable(3.)
|
||||
c = constant_op.constant(2.)
|
||||
|
||||
@def_function.function
|
||||
def Fn():
|
||||
ret1 = while_loop_v2(
|
||||
lambda v: v < 4.,
|
||||
lambda v: v * x1, [c],
|
||||
return_same_structure=False,
|
||||
name="while_1") # 2x
|
||||
ret2 = while_loop_v2(
|
||||
lambda v: v < 16.,
|
||||
lambda v: v * x1 * x1, [c],
|
||||
return_same_structure=False,
|
||||
name="while_2") # 4x
|
||||
ret3 = while_loop_v2(
|
||||
lambda v: v < 4.,
|
||||
lambda v: v * x2, [c],
|
||||
return_same_structure=False,
|
||||
name="while_3") # 3x
|
||||
ret4 = while_loop_v2(
|
||||
lambda v: v < 16.,
|
||||
lambda v: v * x2 * x2, [c],
|
||||
return_same_structure=False,
|
||||
name="while_4") # 9x
|
||||
ret5 = while_loop_v2(
|
||||
lambda v: v < 16.,
|
||||
lambda v: v * v, [c],
|
||||
return_same_structure=False,
|
||||
name="while_stateless") # x**2
|
||||
return ret1, ret2, ret3, ret4, ret5
|
||||
|
||||
concrete_fn = Fn.get_concrete_function()
|
||||
while_1 = concrete_fn.graph.get_operation_by_name("while_1")
|
||||
while_2 = concrete_fn.graph.get_operation_by_name("while_2")
|
||||
while_3 = concrete_fn.graph.get_operation_by_name("while_3")
|
||||
while_4 = concrete_fn.graph.get_operation_by_name("while_4")
|
||||
while_stateless = concrete_fn.graph.get_operation_by_name(
|
||||
"while_stateless")
|
||||
self.assertEqual(while_1.type, "While")
|
||||
self.assertEqual(while_2.type, "While")
|
||||
self.assertEqual(while_3.type, "While")
|
||||
self.assertEqual(while_4.type, "While")
|
||||
self.assertEqual(while_stateless.type, "StatelessWhile")
|
||||
self.assertEmpty(while_1.control_inputs)
|
||||
self.assertLen(while_2.control_inputs, 1)
|
||||
self.assertIs(while_2.control_inputs[0], while_1)
|
||||
self.assertEmpty(while_3.control_inputs)
|
||||
self.assertLen(while_4.control_inputs, 1)
|
||||
self.assertIs(while_4.control_inputs[0], while_3)
|
||||
self.assertEmpty(while_stateless.control_inputs)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testDoubleDerivative(self):
|
||||
x = constant_op.constant(2.)
|
||||
@ -360,7 +470,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
Cond, Body, [x, tensor_list], return_same_structure=False)
|
||||
|
||||
for op in ops.get_default_graph().get_operations():
|
||||
if op.type == "While":
|
||||
if op.type == "While" or op.type == "StatelessWhile":
|
||||
while_op = op
|
||||
|
||||
body_graph = while_v2._get_graph(while_op, "body")
|
||||
@ -443,7 +553,8 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
lambda i: i + 1, [constant_op.constant(0)],
|
||||
return_same_structure=False)
|
||||
while_op = output.op.inputs[0].op
|
||||
self.assertEqual(while_op.type, "While")
|
||||
if compat.forward_compatible(2019, 8, 23):
|
||||
self.assertEqual(while_op.type, "StatelessWhile")
|
||||
return while_op
|
||||
|
||||
def testDefaultName(self):
|
||||
@ -524,23 +635,24 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testForwardPassRewrite(self):
|
||||
x = constant_op.constant(1.0, name="x")
|
||||
output = while_v2.while_loop(lambda x: x < 10.0,
|
||||
lambda x: x * 2.0,
|
||||
[x])[0]
|
||||
while_op = output.op.inputs[0].op
|
||||
self.assertEqual(while_op.type, "While")
|
||||
# outputs = [loop_counter, max_iters, x]
|
||||
self.assertLen(while_op.outputs, 3)
|
||||
if compat.forward_compatible(2019, 8, 23):
|
||||
x = constant_op.constant(1.0, name="x")
|
||||
output = while_v2.while_loop(lambda x: x < 10.0,
|
||||
lambda x: x * 2.0,
|
||||
[x])[0]
|
||||
while_op = output.op.inputs[0].op
|
||||
self.assertEqual(while_op.type, "StatelessWhile")
|
||||
# outputs = [loop_counter, max_iters, x]
|
||||
self.assertLen(while_op.outputs, 3)
|
||||
|
||||
gradients_impl.gradients(output, x)
|
||||
# while_op should have been rewritten to output 2.0 intermediate.
|
||||
# outputs = [loop_counter, max_iters, x, 2.0_accumulator, x_accumulator]
|
||||
self.assertLen(while_op.outputs, 5)
|
||||
gradients_impl.gradients(output, x)
|
||||
# while_op should have been rewritten to output 2.0 intermediate.
|
||||
# outputs = [loop_counter, max_iters, x, 2.0_accumulator, x_accumulator]
|
||||
self.assertLen(while_op.outputs, 5)
|
||||
|
||||
gradients_impl.gradients(output, x)
|
||||
# Computing the gradient again shouldn't rewrite while_op again.
|
||||
self.assertLen(while_op.outputs, 5)
|
||||
gradients_impl.gradients(output, x)
|
||||
# Computing the gradient again shouldn't rewrite while_op again.
|
||||
self.assertLen(while_op.outputs, 5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRandomUniformShape(self):
|
||||
|
@ -257,7 +257,8 @@ def _VerifyGeneratedGradients(grads, op):
|
||||
"""
|
||||
# While ops have inputs added to them during the gradient computation, so we
|
||||
# skip the below check. See while_v2 for details.
|
||||
if op.type == "While": return
|
||||
if op.type == "While" or op.type == "StatelessWhile":
|
||||
return
|
||||
|
||||
if len(grads) != len(op.inputs):
|
||||
raise ValueError("Num gradients %d generated for op %s do not match num "
|
||||
|
@ -23,6 +23,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import func_graph as func_graph_module
|
||||
@ -236,7 +237,23 @@ def while_loop(cond,
|
||||
first_loop_var_index + num_flattened_outputs)
|
||||
output_shapes[orig_loop_vars_range] = nest.flatten(
|
||||
shape_invariants, expand_composites=True)[orig_loop_vars_range]
|
||||
outputs = gen_functional_ops._while(
|
||||
|
||||
cond_stateful_ops = [
|
||||
op for op in cond_graph.get_operations() if op._is_stateful
|
||||
]
|
||||
body_stateful_ops = [
|
||||
op for op in body_graph.get_operations() if op._is_stateful
|
||||
]
|
||||
# TODO(yanhuasun): Remove this after Aug 23, 2019. This is required to
|
||||
# abide by 3-week forward compat window of new TF python op generating
|
||||
# code with stale runtime binaries.
|
||||
if (cond_stateful_ops or body_stateful_ops or
|
||||
not compat.forward_compatible(2019, 8, 23)):
|
||||
op_fn = gen_functional_ops._while
|
||||
else:
|
||||
op_fn = gen_functional_ops.stateless_while
|
||||
|
||||
outputs = op_fn(
|
||||
flattened_loop_vars,
|
||||
util.create_new_tf_function(cond_graph),
|
||||
util.create_new_tf_function(body_graph),
|
||||
@ -270,6 +287,7 @@ def while_loop(cond,
|
||||
return outputs
|
||||
|
||||
|
||||
@ops.RegisterGradient("StatelessWhile")
|
||||
@ops.RegisterGradient("While")
|
||||
def _WhileGrad(op, *grads): # pylint: disable=invalid-name
|
||||
"""The gradient of a While op produced by while_loop."""
|
||||
|
@ -4062,7 +4062,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessWhile"
|
||||
argspec: "args=[\'input\', \'cond\', \'body\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'cond\', \'body\', \'output_shapes\', \'parallel_iterations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'10\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StaticRegexFullMatch"
|
||||
|
@ -4062,7 +4062,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessWhile"
|
||||
argspec: "args=[\'input\', \'cond\', \'body\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'cond\', \'body\', \'output_shapes\', \'parallel_iterations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'10\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StaticRegexFullMatch"
|
||||
|
Loading…
Reference in New Issue
Block a user