Generate StatelessCase op if all ops in all branches are stateless

This avoids adding unnecessary automatic control dependencies due to stateful ops.

PiperOrigin-RevId: 324105758
Change-Id: Icf7979cca19f283ce7fd4cc338b4182f5e2e3b51
parent a4725be4a7
commit b5a2876c65
Changed paths:
  RELEASE.md
  tensorflow/compiler
  tensorflow/core/api_def/base_api
  tensorflow/core/common_runtime
  tensorflow/core/graph
  tensorflow/core/grappler/optimizers
  tensorflow/core/kernels
  tensorflow/core/ops
  tensorflow/python
  tensorflow/tools/api/golden
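The user-visible effect: `tf.switch_case` (which lowers to `Case`/`StatelessCase` via `cond_v2.indexed_case`, see the diffs below) now emits the stateless variant when every branch is pure. A minimal sketch, assuming a TensorFlow build that includes this change and a runtime past the forward-compatibility window:

```
import tensorflow as tf

@tf.function
def select(i, x):
  # Both branches are pure (no variable writes, no printing), so the traced
  # graph should contain a StatelessCase op instead of a stateful Case op.
  return tf.switch_case(i, [lambda: x + 1.0, lambda: x * 2.0])

g = select.get_concrete_function(
    tf.TensorSpec([], tf.int32), tf.TensorSpec([], tf.float32)).graph
print([op.type for op in g.get_operations()])  # expect "StatelessCase" listed
```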
@@ -58,7 +58,8 @@
     behavior.
 *   Added `tf.SparseTensor.with_values`. This returns a new SparseTensor with
     the same sparsity pattern, but with new provided values. It is similar to
     the `with_values` function of `RaggedTensor`.
+*   Added `StatelessCase` op, and uses it if none of the case branches has stateful ops.
 *   `tf.data`:
     *   Added new `tf.data.experimental.service.register_dataset` and
         `tf.data.experimental.service.from_dataset_id` APIs to enable one process
@@ -2014,6 +2014,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
       "StatefulUniform",
       "StatefulUniformFullInt",
       "StatefulUniformInt",
+      "StatelessCase",
       "StatelessIf",
       "StatelessMultinomial",
       "StatelessRandomNormal",
@@ -140,7 +140,7 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel,
         GetFunctionBody(flib_runtime, node, "else_branch", &felse));
     return CondConstInputIndices({fthen, felse}, const_input_idxs,
                                  flib_runtime);
-  } else if (node.op() == "Case") {
+  } else if (node.op() == "Case" || node.op() == "StatelessCase") {
     std::vector<const FunctionBody*> branch_bodies;
     TF_RETURN_IF_ERROR(
         GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies));
@@ -371,5 +371,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {

 REGISTER_XLA_OP(Name("Case").AllowResourceTypes().AllowVariantTypes(),
                 XlaCaseOp);
+REGISTER_XLA_OP(Name("StatelessCase").AllowResourceTypes().AllowVariantTypes(),
+                XlaCaseOp);

 }  // namespace tensorflow
tensorflow/core/api_def/base_api/api_def_StatelessCase.pbtxt (new file, 46 lines)
@@ -0,0 +1,46 @@
op {
  graph_op_name: "StatelessCase"
  visibility: HIDDEN
  in_arg {
    name: "branch_index"
    description: "The branch selector, an int32 Tensor."
  }
  in_arg {
    name: "input"
    description: "A list of input tensors passed to the branch function."
  }
  out_arg {
    name: "output"
    description: "A list of return values."
  }
  attr { name: "Tin" description: "A list of input types." }
  attr { name: "Tout" description: "A list of output types." }
  attr {
    name: "branches"
    description: <<END
A list of functions each of which takes 'inputs' and returns a list of
tensors, whose types are the same as what every other branch returns.
END
  }
  summary: "An n-way switch statement which calls a single branch function."
  description: <<END
An n-way switch statement, implementing the following:
```
switch (branch_index) {
  case 0:
    output = branches[0](input);
    break;
  case 1:
    output = branches[1](input);
    break;
  ...
  case [[nbranches-1]]:
  default:
    output = branches[nbranches-1](input);
    break;
}
```

This should only be used when none of the branches has stateful ops.
END
}
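For reference, the documented selector semantics, including the fall-through to the last branch for an out-of-range index, can be emulated in plain Python (`indexed_case_py` is a hypothetical name used only for illustration):

```
def indexed_case_py(branch_index, branches, *inputs):
  # Out-of-range selectors run the last branch, matching the `default:` arm
  # in the pseudocode above.
  if not 0 <= branch_index < len(branches):
    branch_index = len(branches) - 1
  return branches[branch_index](*inputs)

indexed_case_py(0, [lambda x: x + 1, lambda x: x * 2], 3)  # -> 4
indexed_case_py(9, [lambda x: x + 1, lambda x: x * 2], 3)  # -> 6 (default)
```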
@@ -159,7 +159,7 @@ Status LowerFunctionalOpsPass::Run(
     if (n->IsIfNode() && lower_control_flow(n)) {
       TF_RETURN_IF_ERROR(RewriteIfNode(n, g, keep_lowered_nodes_fetchable));

-    } else if (n->type_string() == "Case" && lower_control_flow(n)) {
+    } else if (n->IsCaseNode() && lower_control_flow(n)) {
       TF_RETURN_IF_ERROR(RewriteCaseNode(n, g, keep_lowered_nodes_fetchable));

     } else if (n->IsWhileNode() && lower_control_flow(n)) {
@@ -82,6 +82,8 @@ Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
           {"StatelessIf", NC_IF},
           {"While", NC_WHILE},
           {"StatelessWhile", NC_WHILE},
+          {"Case", NC_CASE},
+          {"StatelessCase", NC_CASE},
           // Not using the constants defined in FunctionLibraryDefinition
           // for the
           // 4 ops below because android inference library does not link
@@ -189,6 +189,7 @@ class Node {

   bool IsIfNode() const { return class_ == NC_IF; }
   bool IsWhileNode() const { return class_ == NC_WHILE; }
+  bool IsCaseNode() const { return class_ == NC_CASE; }
   // Is this node a function input
   bool IsArg() const { return class_ == NC_ARG; }
   // Is this node a function output
@@ -282,6 +283,7 @@ class Node {
     NC_SYMBOLIC_GRADIENT,
     NC_IF,
     NC_WHILE,
+    NC_CASE,
     NC_ARG,
     NC_RETVAL,
     NC_OTHER  // Not a special kind of node
@@ -1252,7 +1252,7 @@ Status InlineFunctionCalls(const GrapplerItem& item,

     if (n->IsIfNode()) {
       TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), false));
-    } else if (n->type_string() == "Case") {
+    } else if (n->IsCaseNode()) {
      TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), false));
    } else if (n->IsWhileNode()) {
      TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), false));
@@ -41,6 +41,7 @@ namespace grappler {

 constexpr char kConstOp[] = "Const";
 constexpr char kCaseOp[] = "Case";
+constexpr char kStatelessCaseOp[] = "StatelessCase";
 constexpr char kDeviceIndexOp[] = "DeviceIndex";

 // TODO(b/157615690): clean up function implementation swap code.
@@ -353,7 +354,9 @@ Status ImplementationSelector::SelectDeviceIndex(GraphDef* graph) const {
     // case node.
     for (const auto& fanouts : node_view->GetRegularFanouts()) {
       for (const auto& fanout : fanouts) {
-        if (fanout.node_view()->GetOp() != kCaseOp) continue;
+        if (fanout.node_view()->GetOp() != kCaseOp &&
+            fanout.node_view()->GetOp() != kStatelessCaseOp)
+          continue;
         int index;
         // If any error is thrown out during device parsing, we simply skip
         // and do not modify the DeviceIndexNode.
@@ -83,6 +83,30 @@ TEST_F(ImplementationSelectorTest, SelectDeviceIndex) {
   }
 }

+TEST_F(ImplementationSelectorTest, SelectDeviceIndexStatelessCase) {
+  using test::function::NDef;
+  ImplementationSelector optimizer;
+  GraphDef output;
+  GrapplerItem item;
+  AttrValue device_names;
+  device_names.mutable_list()->add_s("CPU");
+  device_names.mutable_list()->add_s("GPU");
+  item.graph = test::function::GDef(
+      {NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
+            CpuDevice),
+       NDef("case", "StatelessCase", {"x"}, {{"T", DT_FLOAT}}, GpuDevice)});
+
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "x") {
+      // Rewrite DeviceIndex op to a Const op with value of GPU index 1.
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
+    }
+  }
+}
+
 TEST_F(ImplementationSelectorTest, SelectDeviceIndexMultiOps) {
   using test::function::NDef;
   ImplementationSelector optimizer;
@@ -354,6 +354,10 @@ REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp);
 REGISTER_KERNEL_BUILDER(Name("Case").Device(DEVICE_CPU), CaseOp);
 REGISTER_KERNEL_BUILDER(
     Name("Case").Device(DEVICE_GPU).HostMemory("branch_index"), CaseOp);
+REGISTER_KERNEL_BUILDER(Name("StatelessCase").Device(DEVICE_CPU), CaseOp);
+REGISTER_KERNEL_BUILDER(
+    Name("StatelessCase").Device(DEVICE_GPU).HostMemory("branch_index"),
+    CaseOp);

 REGISTER_KERNEL_BUILDER(Name("StatelessIf").Device(DEVICE_CPU), IfOp);
 REGISTER_KERNEL_BUILDER(
@@ -135,6 +135,36 @@ REGISTER_OP("If")
     .SetIsStateful()
     .SetShapeFn(IfShapeInferenceFn);

+Status CaseShapeInferenceFn(shape_inference::InferenceContext* c) {
+  std::vector<PartialTensorShape> output_shapes;
+  TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+  // If `output_shapes` attr is set use that as the shapes of the outputs
+  // else return unknown shapes.
+  if (output_shapes.empty()) return shape_inference::UnknownShape(c);
+  if (output_shapes.size() != c->num_outputs()) {
+    return errors::InvalidArgument(
+        "`output_shapes` must be the same length as num outputs (",
+        output_shapes.size(), " vs. ", c->num_outputs());
+  }
+  for (size_t i = 0; i < output_shapes.size(); ++i) {
+    shape_inference::ShapeHandle output_shape_handle;
+    TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+        output_shapes[i], &output_shape_handle));
+    c->set_output(static_cast<int>(i), output_shape_handle);
+  }
+  return Status::OK();
+}
+
+REGISTER_OP("StatelessCase")
+    .Input("branch_index: int32")
+    .Input("input: Tin")
+    .Output("output: Tout")
+    .Attr("Tin: list(type) >= 0")
+    .Attr("Tout: list(type) >= 0")
+    .Attr("branches: list(func) >= 1")
+    .Attr("output_shapes: list(shape) = []")
+    .SetShapeFn(CaseShapeInferenceFn);
+
 REGISTER_OP("Case")
     .Input("branch_index: int32")
     .Input("input: Tin")
@@ -144,25 +174,7 @@ REGISTER_OP("Case")
     .Attr("branches: list(func) >= 1")
     .Attr("output_shapes: list(shape) = []")
     .SetIsStateful()
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
-      std::vector<PartialTensorShape> output_shapes;
-      TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
-      // If `output_shapes` attr is set use that as the shapes of the outputs
-      // else return unknown shapes.
-      if (output_shapes.empty()) return shape_inference::UnknownShape(c);
-      if (output_shapes.size() != c->num_outputs()) {
-        return errors::InvalidArgument(
-            "`output_shapes` must be the same length as num outputs (",
-            output_shapes.size(), " vs. ", c->num_outputs());
-      }
-      for (size_t i = 0; i < output_shapes.size(); ++i) {
-        shape_inference::ShapeHandle output_shape_handle;
-        TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
-            output_shapes[i], &output_shape_handle));
-        c->set_output(static_cast<int>(i), output_shape_handle);
-      }
-      return Status::OK();
-    });
+    .SetShapeFn(CaseShapeInferenceFn);

 // TODO(drpng): remove this.
 REGISTER_OP("_While")
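The refactor above extracts the `Case` shape function into `CaseShapeInferenceFn` so `StatelessCase` can share it unchanged. Its behavior, restated as a hedged Python sketch (hypothetical helper, not a TF API):

```
def case_shape_inference(output_shapes, num_outputs):
  # Mirrors CaseShapeInferenceFn: with no `output_shapes` attr the outputs
  # are unknown; otherwise the attr length must equal the output count.
  if not output_shapes:
    return [None] * num_outputs  # unknown shapes
  if len(output_shapes) != num_outputs:
    raise ValueError("`output_shapes` must be the same length as num outputs "
                     "(%d vs. %d)" % (len(output_shapes), num_outputs))
  return list(output_shapes)
```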
@@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {

 absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
     const tensorflow::string &op_name) {
-  static std::array<OpIndexInfo, 348> a = {{
+  static std::array<OpIndexInfo, 349> a = {{
       {"Acosh"},
       {"AllToAll", 1, {0}},
       {"ApproximateEqual"},
@@ -324,6 +324,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
       {"StackClose"},
       {"StackPop"},
       {"StackPush"},
+      {"StatelessCase"},
       {"StatelessMultinomial"},
       {"StatelessParameterizedTruncatedNormal", 1, {1}},
       {"StatelessRandomBinomial"},
@@ -20,6 +20,7 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.compat.compat import forward_compatibility_horizon
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -34,6 +35,7 @@ from tensorflow.python.ops import cond_v2
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variables
@@ -1519,7 +1521,43 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
       run_metadata = config_pb2.RunMetadata()
       sess.run(out_cond_2, options=run_options, run_metadata=run_metadata)

-      self.assertTrue(len(run_metadata.partition_graphs) >= 2)
+      self.assertGreaterEqual(len(run_metadata.partition_graphs), 2)
+
+
+class CaseTest(test.TestCase):
+
+  def testCase(self):
+
+    def branch1(x):
+      logging_ops.print_v2("1")
+      return x
+
+    def branch2(x):
+      return x + 1
+
+    with ops.Graph().as_default():
+      x = array_ops.constant(1)
+      output = cond_v2.indexed_case(
+          array_ops.constant(0), [lambda: branch1(x), lambda: branch2(x)])
+      cond_op = output.op.inputs[0].op
+      self.assertEqual(cond_op.type, "Case")
+      self.assertEqual(1., self.evaluate(output))
+
+  def testStatelessCase(self):
+
+    def branch1(x):
+      return x + 1
+
+    def branch2(x):
+      return x + 2
+
+    with ops.Graph().as_default():
+      x = array_ops.constant(1)
+      output = cond_v2.indexed_case(
+          array_ops.constant(0), [lambda: branch1(x), lambda: branch2(x)])
+      cond_op = output.op.inputs[0].op
+      self.assertEqual(cond_op.type, "StatelessCase")
+      self.assertEqual(2., self.evaluate(output))


 def _cond(pred, true_fn, false_fn, name):
@@ -1544,4 +1582,5 @@ def _has_node_with_op(run_metadata, op_type):


 if __name__ == "__main__":
-  test.main()
+  with forward_compatibility_horizon(2020, 8, 21):
+    test.main()
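Note how the two tests split on statefulness: `testCase` keeps a `print_v2` (a stateful op) in one branch and still gets a `Case` op, while `testStatelessCase` uses pure branches and gets `StatelessCase`. The `forward_compatibility_horizon(2020, 8, 21)` wrapper makes `compat.forward_compatible(2020, 8, 20)` report the window as already elapsed, so the new op path is exercised before the real three-week window passes. A minimal sketch of that mechanism:

```
from tensorflow.python.compat import compat
from tensorflow.python.compat.compat import forward_compatibility_horizon

with forward_compatibility_horizon(2020, 8, 21):
  assert compat.forward_compatible(2020, 8, 20)  # window treated as elapsed
```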
@@ -25,7 +25,9 @@ from __future__ import print_function

 import collections

+from tensorflow.python.compat import compat
 from tensorflow.python.eager import backprop_util
+from tensorflow.python.framework import auto_control_deps
 from tensorflow.python.framework import auto_control_deps_utils as acd
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -330,7 +332,7 @@ def get_func_graphs(op):
                op.get_attr("then_branch"), "_true_graph"),
            _get_func_graph_for_branch(
                op.get_attr("else_branch"), "_false_graph"))
-  elif op.type == "Case":
+  elif op.type in ["Case", "StatelessCase"]:
     # TODO(b/141114088): investigate whether to cache graphs in forward pass
     return [_get_func_graph_for_branch(branch_fn)
             for branch_fn in op.get_attr("branches")]
@@ -985,6 +987,7 @@ def indexed_case(branch_index,


 @ops.RegisterGradient("Case")
+@ops.RegisterGradient("StatelessCase")
 def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
   """The gradient of a Case op produced by tf.switch_case."""
   # Get the Case operator (this logic handles the case where op is a MockOp)
@@ -1111,10 +1114,24 @@ def _build_case(branch_index,
   # graphs in `branch_graphs`.
   case_inputs = _make_inputs_match(branch_graphs, branch_inputs)

+  stateful_ops = []
+  for bg in branch_graphs:
+    stateful_ops.extend([
+        op for op in bg.get_operations() if auto_control_deps.op_is_stateful(op)
+    ])
+
+  # TODO(b/161915509): Remove this after 08/20/2020. This is required to abide
+  # by 3-week forward compat window of new TF python op generating code with
+  # stale runtime binaries.
+  if (stateful_ops or not compat.forward_compatible(2020, 8, 20)):
+    op_fn = gen_functional_ops.case
+  else:
+    op_fn = gen_functional_ops.stateless_case
+
   # Create the Case op.
   with ops.control_dependencies(
       sum((list(bg.control_captures) for bg in branch_graphs), [])):
-    tensors = gen_functional_ops.case(
+    tensors = op_fn(
         branch_index,
         case_inputs, [t.dtype for t in branch_graphs[0].outputs],
         [util.create_new_tf_function(g) for g in branch_graphs],
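The op selection added to `_build_case` reduces to a simple rule; a sketch under the same assumptions as the diff above (hypothetical `pick_case_op` name, illustration only):

```
def pick_case_op(branch_graphs, op_is_stateful, forward_compatible):
  # Any stateful op in any branch graph, or a runtime binary older than the
  # forward-compat window, forces the stateful Case op; otherwise the
  # stateless variant is safe and avoids automatic control dependencies.
  any_stateful = any(
      op_is_stateful(op)
      for bg in branch_graphs
      for op in bg.get_operations())
  return "Case" if (any_stateful or not forward_compatible) else "StatelessCase"
```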
@@ -4508,6 +4508,10 @@ tf_module {
     name: "StatefulUniformInt"
     argspec: "args=[\'resource\', \'algorithm\', \'shape\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "StatelessCase"
+    argspec: "args=[\'branch_index\', \'input\', \'Tout\', \'branches\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "
+  }
   member_method {
     name: "StatelessIf"
     argspec: "args=[\'cond\', \'input\', \'Tout\', \'then_branch\', \'else_branch\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "
@@ -4508,6 +4508,10 @@ tf_module {
     name: "StatefulUniformInt"
     argspec: "args=[\'resource\', \'algorithm\', \'shape\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "StatelessCase"
+    argspec: "args=[\'branch_index\', \'input\', \'Tout\', \'branches\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "
+  }
   member_method {
     name: "StatelessIf"
     argspec: "args=[\'cond\', \'input\', \'Tout\', \'then_branch\', \'else_branch\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "