generate stateless_case op if all ops in all branches are stateless

This avoids unnecessary automatic control dependencies caused by stateful ops.

PiperOrigin-RevId: 324105758
Change-Id: Icf7979cca19f283ce7fd4cc338b4182f5e2e3b51
This commit is contained in:
Yanhua Sun 2020-07-30 16:20:34 -07:00 committed by TensorFlower Gardener
parent a4725be4a7
commit b5a2876c65
18 changed files with 191 additions and 29 deletions

View File

@ -58,7 +58,8 @@
behavior.
* Added `tf.SparseTensor.with_values`. This returns a new SparseTensor with
the same sparsity pattern, but with new provided values. It is similar to
the `with_values` function of `RaggedTensor`.
the `with_values` function of `RaggedTensor`.
* Added `StatelessCase` op, which is used if none of the case branches has stateful ops.
* `tf.data`:
* Added new `tf.data.experimental.service.register_dataset` and
`tf.data.experimental.service.from_dataset_id` APIs to enable one process

View File

@ -2014,6 +2014,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"StatefulUniform",
"StatefulUniformFullInt",
"StatefulUniformInt",
"StatelessCase",
"StatelessIf",
"StatelessMultinomial",
"StatelessRandomNormal",

View File

@ -140,7 +140,7 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel,
GetFunctionBody(flib_runtime, node, "else_branch", &felse));
return CondConstInputIndices({fthen, felse}, const_input_idxs,
flib_runtime);
} else if (node.op() == "Case") {
} else if (node.op() == "Case" || node.op() == "StatelessCase") {
std::vector<const FunctionBody*> branch_bodies;
TF_RETURN_IF_ERROR(
GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies));

View File

@ -371,5 +371,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
REGISTER_XLA_OP(Name("Case").AllowResourceTypes().AllowVariantTypes(),
XlaCaseOp);
// "StatelessCase" is registered with the same XlaCaseOp kernel and the same
// resource/variant type allowances as the "Case" registration above.
REGISTER_XLA_OP(Name("StatelessCase").AllowResourceTypes().AllowVariantTypes(),
XlaCaseOp);
} // namespace tensorflow

View File

@ -0,0 +1,46 @@
op {
graph_op_name: "StatelessCase"
visibility: HIDDEN
in_arg {
name: "branch_index"
description: "The branch selector, an int32 Tensor."
}
in_arg {
name: "input"
description: "A list of input tensors passed to the branch function."
}
out_arg {
name: "output"
description: "A list of return values."
}
attr { name: "Tin" description: "A list of input types." }
attr { name: "Tout" description: "A list of output types." }
attr {
name: "branches"
description: <<END
A list of functions each of which takes 'inputs' and returns a list of
tensors, whose types are the same as what every other branch returns.
END
}
summary: "An n-way switch statement which calls a single branch function."
description: <<END
An n-way switch statement, implementing the following:
```
switch (branch_index) {
case 0:
output = branches[0](input);
break;
case 1:
output = branches[1](input);
break;
...
case [[nbranches-1]]:
default:
output = branches[nbranches-1](input);
break;
}
```
This should only be used when none of the branches has stateful ops.
END
}

View File

@ -159,7 +159,7 @@ Status LowerFunctionalOpsPass::Run(
if (n->IsIfNode() && lower_control_flow(n)) {
TF_RETURN_IF_ERROR(RewriteIfNode(n, g, keep_lowered_nodes_fetchable));
} else if (n->type_string() == "Case" && lower_control_flow(n)) {
} else if (n->IsCaseNode() && lower_control_flow(n)) {
TF_RETURN_IF_ERROR(RewriteCaseNode(n, g, keep_lowered_nodes_fetchable));
} else if (n->IsWhileNode() && lower_control_flow(n)) {

View File

@ -82,6 +82,8 @@ Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
{"StatelessIf", NC_IF},
{"While", NC_WHILE},
{"StatelessWhile", NC_WHILE},
{"Case", NC_CASE},
{"StatelessCase", NC_CASE},
// Not using the constants defined in FunctionLibraryDefinition
// for the
// 4 ops below because android inference library does not link

View File

@ -189,6 +189,7 @@ class Node {
bool IsIfNode() const { return class_ == NC_IF; }
bool IsWhileNode() const { return class_ == NC_WHILE; }
// Is this node classified as NC_CASE (a "Case" or "StatelessCase" op)
bool IsCaseNode() const { return class_ == NC_CASE; }
// Is this node a function input
bool IsArg() const { return class_ == NC_ARG; }
// Is this node a function output
@ -282,6 +283,7 @@ class Node {
NC_SYMBOLIC_GRADIENT,
NC_IF,
NC_WHILE,
NC_CASE,
NC_ARG,
NC_RETVAL,
NC_OTHER // Not a special kind of node

View File

@ -1252,7 +1252,7 @@ Status InlineFunctionCalls(const GrapplerItem& item,
if (n->IsIfNode()) {
TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), false));
} else if (n->type_string() == "Case") {
} else if (n->IsCaseNode()) {
TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), false));
} else if (n->IsWhileNode()) {
TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), false));

View File

@ -41,6 +41,7 @@ namespace grappler {
constexpr char kConstOp[] = "Const";
constexpr char kCaseOp[] = "Case";
constexpr char kStatelessCaseOp[] = "StatelessCase";
constexpr char kDeviceIndexOp[] = "DeviceIndex";
// TODO(b/157615690): clean up function implementation swap code.
@ -353,7 +354,9 @@ Status ImplementationSelector::SelectDeviceIndex(GraphDef* graph) const {
// case node.
for (const auto& fanouts : node_view->GetRegularFanouts()) {
for (const auto& fanout : fanouts) {
if (fanout.node_view()->GetOp() != kCaseOp) continue;
if (fanout.node_view()->GetOp() != kCaseOp &&
fanout.node_view()->GetOp() != kStatelessCaseOp)
continue;
int index;
// If any error is thrown out during device parsing, we simply skip
// and do not modify the DeviceIndexNode.

View File

@ -83,6 +83,30 @@ TEST_F(ImplementationSelectorTest, SelectDeviceIndex) {
}
}
// Verifies that a DeviceIndex node whose consumer is a "StatelessCase" node
// placed on GpuDevice is rewritten into a Const node holding index 1, i.e. the
// position of "GPU" in the `device_names` attr.
TEST_F(ImplementationSelectorTest, SelectDeviceIndexStatelessCase) {
  using test::function::NDef;

  // Candidate devices, in index order: 0 -> CPU, 1 -> GPU.
  AttrValue device_list;
  device_list.mutable_list()->add_s("CPU");
  device_list.mutable_list()->add_s("GPU");

  GrapplerItem grappler_item;
  grappler_item.graph = test::function::GDef(
      {NDef("x", "DeviceIndex", {}, {{"device_names", device_list}},
            CpuDevice),
       NDef("case", "StatelessCase", {"x"}, {{"T", DT_FLOAT}}, GpuDevice)});

  ImplementationSelector selector;
  GraphDef optimized;
  TF_EXPECT_OK(selector.Optimize(nullptr, grappler_item, &optimized));

  for (const NodeDef& node : optimized.node()) {
    if (node.name() != "x") continue;
    // Rewrite DeviceIndex op to a Const op with value of GPU index 1.
    EXPECT_EQ("Const", node.op());
    EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
  }
}
TEST_F(ImplementationSelectorTest, SelectDeviceIndexMultiOps) {
using test::function::NDef;
ImplementationSelector optimizer;

View File

@ -354,6 +354,10 @@ REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp);
REGISTER_KERNEL_BUILDER(Name("Case").Device(DEVICE_CPU), CaseOp);
REGISTER_KERNEL_BUILDER(
Name("Case").Device(DEVICE_GPU).HostMemory("branch_index"), CaseOp);
// "StatelessCase" is served by the same CaseOp kernel as "Case"; the GPU
// registration places the `branch_index` input in host memory.
REGISTER_KERNEL_BUILDER(Name("StatelessCase").Device(DEVICE_CPU), CaseOp);
REGISTER_KERNEL_BUILDER(
Name("StatelessCase").Device(DEVICE_GPU).HostMemory("branch_index"),
CaseOp);
REGISTER_KERNEL_BUILDER(Name("StatelessIf").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(

View File

@ -135,6 +135,36 @@ REGISTER_OP("If")
.SetIsStateful()
.SetShapeFn(IfShapeInferenceFn);
// Shape function used by the "Case" and "StatelessCase" op registrations.
//
// Reads the `output_shapes` attr. When the attr is empty, the result of
// shape_inference::UnknownShape(c) is returned. Otherwise the attr must hold
// exactly one entry per op output, and entry i becomes the shape of output i.
//
// Returns InvalidArgument when the number of `output_shapes` entries differs
// from the number of outputs, and propagates any error from reading the attr
// or from converting a PartialTensorShape into a ShapeHandle.
Status CaseShapeInferenceFn(shape_inference::InferenceContext* c) {
  std::vector<PartialTensorShape> output_shapes;
  TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
  // If `output_shapes` attr is set use that as the shapes of the outputs
  // else return unknown shapes.
  if (output_shapes.empty()) return shape_inference::UnknownShape(c);
  // Convert once so the comparison and the loop below stay in signed `int`,
  // avoiding a mixed signed/unsigned comparison against num_outputs().
  const int num_shapes = static_cast<int>(output_shapes.size());
  if (num_shapes != c->num_outputs()) {
    // The closing ")" balances the parenthesis opened in the message.
    return errors::InvalidArgument(
        "`output_shapes` must be the same length as num outputs (",
        num_shapes, " vs. ", c->num_outputs(), ")");
  }
  for (int i = 0; i < num_shapes; ++i) {
    shape_inference::ShapeHandle output_shape_handle;
    TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
        output_shapes[i], &output_shape_handle));
    c->set_output(i, output_shape_handle);
  }
  return Status::OK();
}
// Registers the "StatelessCase" op. Unlike the "Case" registration that
// follows, this chain does not call SetIsStateful(). Shape inference is
// delegated to CaseShapeInferenceFn.
REGISTER_OP("StatelessCase")
.Input("branch_index: int32")
.Input("input: Tin")
.Output("output: Tout")
.Attr("Tin: list(type) >= 0")
.Attr("Tout: list(type) >= 0")
.Attr("branches: list(func) >= 1")
.Attr("output_shapes: list(shape) = []")
.SetShapeFn(CaseShapeInferenceFn);
REGISTER_OP("Case")
.Input("branch_index: int32")
.Input("input: Tin")
@ -144,25 +174,7 @@ REGISTER_OP("Case")
.Attr("branches: list(func) >= 1")
.Attr("output_shapes: list(shape) = []")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
std::vector<PartialTensorShape> output_shapes;
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
// If `output_shapes` attr is set use that as the shapes of the outputs
// else return unknown shapes.
if (output_shapes.empty()) return shape_inference::UnknownShape(c);
if (output_shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
"`output_shapes` must be the same length as num outputs (",
output_shapes.size(), " vs. ", c->num_outputs());
}
for (size_t i = 0; i < output_shapes.size(); ++i) {
shape_inference::ShapeHandle output_shape_handle;
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
output_shapes[i], &output_shape_handle));
c->set_output(static_cast<int>(i), output_shape_handle);
}
return Status::OK();
});
.SetShapeFn(CaseShapeInferenceFn);
// TODO(drpng): remove this.
REGISTER_OP("_While")

View File

@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 348> a = {{
static std::array<OpIndexInfo, 349> a = {{
{"Acosh"},
{"AllToAll", 1, {0}},
{"ApproximateEqual"},
@ -324,6 +324,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
{"StackClose"},
{"StackPop"},
{"StackPush"},
{"StatelessCase"},
{"StatelessMultinomial"},
{"StatelessParameterizedTruncatedNormal", 1, {1}},
{"StatelessRandomBinomial"},

View File

@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat.compat import forward_compatibility_horizon
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@ -34,6 +35,7 @@ from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
@ -1519,7 +1521,43 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
run_metadata = config_pb2.RunMetadata()
sess.run(out_cond_2, options=run_options, run_metadata=run_metadata)
self.assertTrue(len(run_metadata.partition_graphs) >= 2)
self.assertGreaterEqual(len(run_metadata.partition_graphs), 2)
class CaseTest(test.TestCase):
  """Checks which op type `cond_v2.indexed_case` creates for its branches."""

  def testCase(self):
    """With a branch that calls `print_v2`, a `Case` op is expected."""

    def printing_branch(value):
      logging_ops.print_v2("1")
      return value

    def incrementing_branch(value):
      return value + 1

    with ops.Graph().as_default():
      x = array_ops.constant(1)
      branch_fns = [lambda: printing_branch(x), lambda: incrementing_branch(x)]
      output = cond_v2.indexed_case(array_ops.constant(0), branch_fns)
      # The op feeding the returned tensor's op is the one under test.
      case_op = output.op.inputs[0].op
      self.assertEqual(case_op.type, "Case")
      self.assertEqual(1., self.evaluate(output))

  def testStatelessCase(self):
    """With branches that only do arithmetic, `StatelessCase` is expected."""

    def plus_one(value):
      return value + 1

    def plus_two(value):
      return value + 2

    with ops.Graph().as_default():
      x = array_ops.constant(1)
      branch_fns = [lambda: plus_one(x), lambda: plus_two(x)]
      output = cond_v2.indexed_case(array_ops.constant(0), branch_fns)
      # The op feeding the returned tensor's op is the one under test.
      case_op = output.op.inputs[0].op
      self.assertEqual(case_op.type, "StatelessCase")
      self.assertEqual(2., self.evaluate(output))
def _cond(pred, true_fn, false_fn, name):
@ -1544,4 +1582,5 @@ def _has_node_with_op(run_metadata, op_type):
if __name__ == "__main__":
test.main()
with forward_compatibility_horizon(2020, 8, 21):
test.main()

View File

@ -25,7 +25,9 @@ from __future__ import print_function
import collections
from tensorflow.python.compat import compat
from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -330,7 +332,7 @@ def get_func_graphs(op):
op.get_attr("then_branch"), "_true_graph"),
_get_func_graph_for_branch(
op.get_attr("else_branch"), "_false_graph"))
elif op.type == "Case":
elif op.type in ["Case", "StatelessCase"]:
# TODO(b/141114088): investigate whether to cache graphs in forward pass
return [_get_func_graph_for_branch(branch_fn)
for branch_fn in op.get_attr("branches")]
@ -985,6 +987,7 @@ def indexed_case(branch_index,
@ops.RegisterGradient("Case")
@ops.RegisterGradient("StatelessCase")
def _CaseGrad(op, *grads): # pylint: disable=invalid-name
"""The gradient of a Case op produced by tf.switch_case."""
# Get the Case operator (this logic handles the case where op is a MockOp)
@ -1111,10 +1114,24 @@ def _build_case(branch_index,
# graphs in `branch_graphs`.
case_inputs = _make_inputs_match(branch_graphs, branch_inputs)
stateful_ops = []
for bg in branch_graphs:
stateful_ops.extend([
op for op in bg.get_operations() if auto_control_deps.op_is_stateful(op)
])
# TODO(b/161915509): Remove this after 08/20/2020. This is required to abide
# by 3-week forward compat window of new TF python op generating code with
# stale runtime binaries.
if (stateful_ops or not compat.forward_compatible(2020, 8, 20)):
op_fn = gen_functional_ops.case
else:
op_fn = gen_functional_ops.stateless_case
# Create the Case op.
with ops.control_dependencies(
sum((list(bg.control_captures) for bg in branch_graphs), [])):
tensors = gen_functional_ops.case(
tensors = op_fn(
branch_index,
case_inputs, [t.dtype for t in branch_graphs[0].outputs],
[util.create_new_tf_function(g) for g in branch_graphs],

View File

@ -4508,6 +4508,10 @@ tf_module {
name: "StatefulUniformInt"
argspec: "args=[\'resource\', \'algorithm\', \'shape\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessCase"
argspec: "args=[\'branch_index\', \'input\', \'Tout\', \'branches\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "
}
member_method {
name: "StatelessIf"
argspec: "args=[\'cond\', \'input\', \'Tout\', \'then_branch\', \'else_branch\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "

View File

@ -4508,6 +4508,10 @@ tf_module {
name: "StatefulUniformInt"
argspec: "args=[\'resource\', \'algorithm\', \'shape\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessCase"
argspec: "args=[\'branch_index\', \'input\', \'Tout\', \'branches\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "
}
member_method {
name: "StatelessIf"
argspec: "args=[\'cond\', \'input\', \'Tout\', \'then_branch\', \'else_branch\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "