Adds a lowering from Case to _SwitchN+Merge.

Introduces a new n-way _SwitchN op+kernel.

I audited usages of the grappler and graph variants of IsSwitch and IsMerge, and believe the corrections in this CL are correct.

PiperOrigin-RevId: 250803634
This commit is contained in:
Brian Patton 2019-05-30 18:23:06 -07:00 committed by TensorFlower Gardener
parent 586ff7b1fa
commit 10ed2f7bb5
20 changed files with 1211 additions and 117 deletions

View File

@ -846,24 +846,42 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
const Edge* pred_edge;
TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));
Predicate* true_switch;
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
pred_edge->src(), pred_edge->src_output(),
/*must_be_true=*/true, &true_switch));
if (n->num_outputs() == 2) {
Predicate* true_switch;
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
pred_edge->src(), pred_edge->src_output(),
/*must_be_true=*/true, &true_switch));
Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch);
Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch);
// Output 0 is alive iff all inputs are alive and the condition is false.
input_preds.push_back(false_switch);
SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
should_revisit);
input_preds.pop_back();
// Output 0 is alive iff all inputs are alive and the condition is false.
input_preds.push_back(false_switch);
SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
should_revisit);
input_preds.pop_back();
// Output 1 is alive iff all inputs are alive and the condition is true.
input_preds.push_back(true_switch);
SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
should_revisit);
input_preds.pop_back();
// Output 1 is alive iff all inputs are alive and the condition is true.
input_preds.push_back(true_switch);
SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
should_revisit);
input_preds.pop_back();
} else { // N-way switch case. Exactly one of N branches is alive.
Predicate* branch_pred;
for (int i = 0; i < n->num_outputs() - 1; i++) {
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
n, i, /*must_be_true=*/false, &branch_pred));
input_preds.push_back(branch_pred);
SetPredicate(n, i, predicate_factory_.MakeAndPredicate(input_preds),
should_revisit);
input_preds.pop_back();
input_preds.push_back(predicate_factory_.MakeNotPredicate(branch_pred));
}
// The default (last) branch does not need its own symbol, is simply the
// nor of all other branches.
SetPredicate(n, n->num_outputs() - 1,
predicate_factory_.MakeAndPredicate(input_preds),
should_revisit);
}
// Control is alive iff all inputs are alive.
SetPredicate(n, Graph::kControlSlot,

View File

@ -3171,6 +3171,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/local_device.h",
"common_runtime/lower_function_call_op.h",
"common_runtime/lower_if_op.h",
"common_runtime/lower_case_op.h",
"common_runtime/lower_functional_ops.h",
"common_runtime/lower_while_op.h",
"common_runtime/memory_types.h",
@ -3232,6 +3233,7 @@ tf_cuda_library(
"common_runtime/inspecting_placer.h",
"common_runtime/isolate_placer_inspection_required_ops_pass.cc",
"common_runtime/local_device.cc",
"common_runtime/lower_case_op.cc",
"common_runtime/lower_function_call_op.cc",
"common_runtime/lower_functional_ops.cc",
"common_runtime/lower_if_op.cc",
@ -5326,6 +5328,30 @@ tf_cc_tests(
],
)
tf_cc_tests(
name = "common_runtime_lower_case_op_test",
size = "small",
srcs = ["common_runtime/lower_case_op_test.cc"],
deps = [
":all_kernels",
":core_cpu",
":core_cpu_internal",
":direct_session",
":framework",
":framework_internal",
":lib",
":test",
":test_main",
":testlib",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:client_session",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
],
)
tf_cc_tests(
name = "common_runtime_lower_while_op_test",
size = "small",

View File

@ -0,0 +1,300 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/lower_case_op.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace {
using NodeOut = NodeBuilder::NodeOut;
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
// Convenience builder to make it easy to construct a case with a single
// function call in each branch. This first converts the Case node
// into switches (for inputs) and merges (for outputs) around a function call
// per branch.
class CaseBuilder {
public:
// Create a CaseBuilder to create the lowered form of `case` with branch
// functions identified by `branch_fn_names` in the `graph`. The functions
// should be available in `flib`.
CaseBuilder(Node* case_op, const std::vector<string>& branch_fn_names,
const FunctionLibraryDefinition& flib, bool keep_node_fetchable,
Graph* graph);
// Constructs the basic conditional control flow using switch and merge nodes.
Status CreatePivotNodes();
// Adds the inputs from the if node to the merge nodes of the lowered if.
Status AddInputs();
// Adds the outputs from the if node to the merge nodes of the lowered if.
// Note: no inputs can be added once outputs are added as the then and else
// nodes are finalized while adding outputs.
Status AddOutputs();
// Builds an identity node with the same outputs as Case.
Status BuildLoweredCaseOutput();
private:
// Returns unique name containing the name of the Case op being rewritten
// (name_), infix and a suffix to ensure it is unique within the graph.
string NewName(const string& infix);
// Adds input to both the then and else nodes from src:src_output.
Status AddInput(Node* src, int src_output);
// The merged outputs of the then and else nodes.
std::vector<NodeOut> outputs_;
// The node that dominates all execution of the then and else body nodes.
Node* control_predecessor_;
// The original Case op.
Node* case_op_;
// The node with the same name as the original Case op:
// (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true'
// and if the original Case op had non-zero data outputs.
// (b) NoOp node with control edge from 'branch_executed_node_' otherwise.
Node* lowered_case_output_;
// The branch selector of the case.
OutputTensor branch_index_;
int num_branches_;
// Nodes corresponding to pivot branch of branch_index _SwitchN, which is
// the pivot node that dominates all nodes in the i'th branch.
std::vector<Node*> pivots_;
std::vector<Node*> call_nodes_;
// Merge node that has inputs from each of pivots_ and control edges from
// [^call_node for call_node in call_nodes_]. This node will guarantee that
// even when branch functions do not have outputs, they still will be executed
// for the side effects.
Node* branch_executed_node_;
Graph* graph_;
const FunctionLibraryDefinition& flib_;
string name_;
bool keep_node_fetchable_;
NodeDebugInfo debug_info_;
std::vector<NodeBuilder> branch_call_builders_;
};
CaseBuilder::CaseBuilder(Node* case_op,
const std::vector<string>& branch_fn_names,
const FunctionLibraryDefinition& flib,
bool keep_node_fetchable, Graph* graph)
: case_op_(case_op),
num_branches_(branch_fn_names.size()),
graph_(graph),
flib_(flib),
name_(case_op->name()),
keep_node_fetchable_(keep_node_fetchable),
debug_info_(*case_op_) {
branch_call_builders_.reserve(num_branches_);
for (int b = 0; b < num_branches_; b++) {
branch_call_builders_.emplace_back(NewName(strings::StrCat("branch", b)),
branch_fn_names[b], graph->op_registry(),
&debug_info_);
branch_call_builders_[b].Device(case_op_->requested_device());
branch_call_builders_[b].Attr(kLowerAsMultiDeviceFunctionAttr, true);
}
TF_CHECK_OK(case_op_->input_tensor(0, &branch_index_));
}
Status CaseBuilder::CreatePivotNodes() {
// Construct the basic case body (consisting of feeding in the val to
// create pivot nodes).
Node* branch_index;
TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_index"), "_SwitchN",
graph_->op_registry(), &debug_info_)
.Input(NodeOut(branch_index_))
.Input(NodeOut(branch_index_))
.Attr("num_outs", num_branches_)
.Device(case_op_->requested_device())
.Finalize(graph_, &branch_index));
control_predecessor_ = branch_index;
pivots_.resize(num_branches_, nullptr);
for (int b = 0; b < num_branches_; b++) {
TF_RETURN_IF_ERROR(NodeBuilder(NewName(strings::StrCat("pivot_", b)),
"Identity", graph_->op_registry(),
&debug_info_)
.Input(branch_index, b)
.Device(case_op_->requested_device())
.Finalize(graph_, &pivots_[b]));
}
return Status::OK();
}
string CaseBuilder::NewName(const string& infix) {
return graph_->NewName(strings::StrCat(name_, "/", infix));
}
Status CaseBuilder::AddInput(Node* src, int src_output) {
Node* input;
NodeDebugInfo debug_info(*src);
// Colocate the Switch node with the `src` node.
//
// This is to avoid unnecessary Host<->Device copies between src and the
// _SwitchN node. This aligns with the implementation of legacy tf.cond in
// control_flow_ops.py. The legacy impl colocates the Switch with the
// input tensor which resets the device stack and forces the Switch to have
// the same device as the input node (if set) and sets the colocation _class
// attr. It also ignores the existing colocation constraints on the input node
// using colocate_with(ignore_existing=True).
TF_RETURN_IF_ERROR(NodeBuilder(NewName(src->name()), "_SwitchN",
graph_->op_registry(), &debug_info)
.Input(src, src_output)
.Input(branch_index_)
.Device(src->requested_device())
.Attr("_class", {src->name()})
.Attr("num_outs", num_branches_)
.Finalize(graph_, &input));
for (int b = 0; b < num_branches_; b++) {
branch_call_builders_[b].Input(input, b);
}
return Status::OK();
}
Status CaseBuilder::AddInputs() {
// Add input data edges.
std::vector<const Edge*> edges;
TF_RETURN_IF_ERROR(case_op_->input_edges(&edges));
// Start at index 1 as the first input is the branch index.
for (int i = 1; i < edges.size(); ++i) {
const Edge* e = edges[i];
TF_RETURN_IF_ERROR(AddInput(e->src(), e->src_output()));
}
// Add input control edges.
for (const Edge* e : case_op_->in_edges()) {
if (e->IsControlEdge()) {
graph_->AddControlEdge(e->src(), control_predecessor_);
}
}
return Status::OK();
}
Status CaseBuilder::AddOutputs() {
// Construct the call nodes for each branch.
call_nodes_.resize(num_branches_, nullptr);
for (int b = 0; b < num_branches_; b++) {
TF_RETURN_IF_ERROR(
branch_call_builders_[b].Finalize(graph_, &call_nodes_[b]));
graph_->AddControlEdge(pivots_[b], call_nodes_[b]);
}
// Merge the outputs from the N branches (all branches have matching outputs).
const int num_outputs = call_nodes_[0]->num_outputs();
std::vector<Node*> merges(num_outputs);
outputs_.resize(merges.size());
for (int i = 0; i < num_outputs; ++i) {
std::vector<NodeOut> merge_input;
merge_input.reserve(num_branches_);
for (int j = 0; j < num_branches_; j++) {
merge_input.emplace_back(call_nodes_[j], i);
}
TF_RETURN_IF_ERROR(NodeBuilder(NewName("merge"), "Merge",
graph_->op_registry(), &debug_info_)
.Input(merge_input)
.Device(case_op_->requested_device())
.Finalize(graph_, &merges[i]));
outputs_[i] = NodeOut(merges[i], 0);
}
// Add a Merge node that will be used as a control dependency source for the
// lowered output node. This Merge node will guarantee that lowered else/then
// function calls will be executed even if they do not have data outputs.
//
// Furthermore it will guarantee that all function side effects will be
// executed, if the function will be inlined into the graph. Having data
// outputs is not enough, because they might become unused after inlining.
//
// We will use this node to rewrite outgoing control edges from lowered 'Case'
// node. All data edges will read tensors directly from Merge nodes.
std::vector<NodeOut> pivots(num_branches_);
for (int j = 0; j < num_branches_; j++) {
pivots[j] = NodeOut(pivots_[j]);
}
TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_executed"), "Merge",
graph_->op_registry(), &debug_info_)
.Input(pivots)
.ControlInputs(call_nodes_)
.Device(case_op_->requested_device())
.Finalize(graph_, &branch_executed_node_));
TF_RETURN_IF_ERROR(BuildLoweredCaseOutput());
// Add outputs.
for (const Edge* e : case_op_->out_edges()) {
if (e->IsControlEdge()) {
graph_->AddControlEdge(branch_executed_node_, e->dst());
} else {
// Feed the outputs directly from the merge nodes so that downstream ops
// can start before all the outputs have been computed.
graph_->AddEdge(merges[e->src_output()], 0, e->dst(), e->dst_input());
}
}
return Status::OK();
}
Status CaseBuilder::BuildLoweredCaseOutput() {
// If outputs are empty, it means that we might have only output control
// edges (already connected to the `branch_executed_node`). Furthermore it's
// illegal to have an IdentityN with empty inputs.
//
// We still must keep lowered Case node as a valid source of control edges,
// because it might be a part of function control output set.
NodeBuilder builder = keep_node_fetchable_ && !outputs_.empty()
? NodeBuilder(name_, "IdentityN").Input(outputs_)
: NodeBuilder(name_, "NoOp");
return builder.Device(case_op_->requested_device())
.ControlInput(branch_executed_node_)
.Finalize(graph_, &lowered_case_output_);
}
} // namespace
Status RewriteCaseNode(Node* n, Graph* g, const FunctionLibraryDefinition& flib,
bool keep_node_fetchable) {
VLOG(2) << "Lower Case node (keep_node_fetchable=" << keep_node_fetchable
<< "): " << SummarizeNode(*n);
const AttrValue* branches_attr = n->attrs().Find("branches");
if (branches_attr == nullptr) {
return errors::InvalidArgument("branch functions missing");
}
int num_branches = branches_attr->list().func_size();
std::vector<string> branch_fn_names;
branch_fn_names.reserve(num_branches);
for (int b = 0; b < num_branches; b++) {
branch_fn_names.emplace_back(branches_attr->list().func(b).name());
}
CaseBuilder cb(n, branch_fn_names, flib, keep_node_fetchable, g);
TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
TF_RETURN_IF_ERROR(cb.AddInputs());
TF_RETURN_IF_ERROR(cb.AddOutputs());
g->RemoveNode(n);
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,30 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Replaces Case node `n` with a lowered form that uses _SwitchN/Merge nodes.
Status RewriteCaseNode(Node* n, Graph* g, const FunctionLibraryDefinition& flib,
bool keep_node_fetchable);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_

View File

@ -0,0 +1,441 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
AttrValue FuncListAttr(const absl::Span<const char* const> names) {
AttrValue attr;
for (const char* name : names) {
attr.mutable_list()->add_func()->set_name(name);
}
return attr;
}
SessionOptions SessionOptionsWithInlining() {
SessionOptions session_options;
session_options.config.mutable_graph_options()
->mutable_optimizer_options()
->set_do_function_inlining(true);
return session_options;
}
Status Rewrite(std::unique_ptr<Graph>* graph) {
FunctionLibraryDefinition flib_def((*graph)->flib_def());
GraphOptimizationPassOptions opt_options;
SessionOptions session_options = SessionOptionsWithInlining();
opt_options.session_options = &session_options;
opt_options.graph = graph;
opt_options.flib_def = &flib_def;
LowerFunctionalOpsPass pass;
return pass.Run(opt_options);
}
TEST(LowerCaseOpTest, Simple) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
// Add test functions for then and else branch.
FunctionDefLibrary f_lib_proto;
*(f_lib_proto.add_function()) = test::function::XTimesTwo();
*(f_lib_proto.add_function()) = test::function::XTimesFour();
*(f_lib_proto.add_function()) = test::function::XTimes16();
// Construct simple conditional that switches on `pred` and operates only on
// single input `A`.
Scope root = Scope::NewRootScope().ExitOnError();
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
auto branch_index =
ops::Placeholder(root.WithOpName("branch_index"), DT_INT32);
Node* written_if;
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
TF_ASSERT_OK(
NodeBuilder("case", "Case", &root.graph()->flib_def())
.Input(branch_index.node())
.Input(inputs)
.Attr("branches",
FuncListAttr({"XTimesTwo", "XTimesFour", "XTimes16"}))
.Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
.Attr("Tout", {DT_INT32})
.Finalize(root.graph(), &written_if));
TF_ASSERT_OK(root.DoShapeInference(written_if));
TF_ASSERT_OK(root.ToGraph(graph.get()));
// The input graph has no switch or merge nodes.
int node_called_case_count = 0;
for (const auto* op : graph->op_nodes()) {
ASSERT_FALSE(op->IsSwitch());
ASSERT_FALSE(op->IsMerge());
if (op->name() == "case") {
++node_called_case_count;
}
}
ASSERT_EQ(node_called_case_count, 1);
TF_ASSERT_OK(Rewrite(&graph));
// Verify the resultant graph has switch and merge nodes, and a node called
// `if` (but not If nodes).
int switch_count = 0;
int merge_count = 0;
node_called_case_count = 0;
for (const auto* op : graph->op_nodes()) {
if (op->IsSwitch()) {
++switch_count;
}
if (op->IsMerge()) {
++merge_count;
}
ASSERT_NE(op->type_string(), "Case");
if (op->name() == "case") {
++node_called_case_count;
}
}
// One switch for predicate and one for input (A).
ASSERT_EQ(switch_count, 2);
// One merge for the single output value of then and else, and one more merge
// to enforce then and else function call execution (`branch_executed` node).
ASSERT_EQ(merge_count, 2);
ASSERT_EQ(node_called_case_count, 1);
// Verify execution.
ClientSession session(root, SessionOptionsWithInlining());
{
ClientSession::FeedType feeds;
feeds.emplace(Output(branch_index.node()), Input::Initializer(-1));
feeds.emplace(Output(a.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
}
{
ClientSession::FeedType feeds;
feeds.emplace(Output(branch_index.node()), Input::Initializer(0));
feeds.emplace(Output(a.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
}
{
ClientSession::FeedType feeds;
feeds.emplace(Output(branch_index.node()), Input::Initializer(1));
feeds.emplace(Output(a.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 40);
}
{
ClientSession::FeedType feeds;
feeds.emplace(Output(branch_index.node()), Input::Initializer(2));
feeds.emplace(Output(a.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
}
{
ClientSession::FeedType feeds;
feeds.emplace(Output(branch_index.node()), Input::Initializer(20));
feeds.emplace(Output(a.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
}
}
TEST(LowerCaseOpTest, BranchFunctionsWithoutOutputs) {
using ::tensorflow::test::function::GDef;
using ::tensorflow::test::function::NDef;
using FDH = ::tensorflow::FunctionDefHelper;
// Wrap AssignAddVariable + Const into a function.
const auto assign_add = [](const string& fn_name, int v) {
const Tensor tensor = test::AsScalar<int32>(v);
return FDH::Create(
fn_name, {"v: resource"}, {}, {},
{
{{"c"}, "Const", {}, {{"value", tensor}, {"dtype", DT_INT32}}},
{{"upd"},
"AssignAddVariableOp",
{"v", "c:output"},
{{"dtype", DT_INT32}}},
},
/*ret_def=*/{},
/*control_ret_def=*/{{"side_effects", "upd"}});
};
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
// Add test functions for then and else branch.
FunctionDefLibrary f_lib_proto;
*(f_lib_proto.add_function()) = assign_add("AddOne", 1);
*(f_lib_proto.add_function()) = assign_add("AddTwo", 2);
*(f_lib_proto.add_function()) = assign_add("AddTen", 10);
// Construct a graph to represent following program:
//
// (branch_index: int32, initial_val: int32) -> (int32) {
// var = Variable(initial_value)
// switch (branch_index) {
// case 0:
// var += 1; break; # AddOne function call
// case 1:
// var += 2; break; # AddTwo function call
// default:
// var += 10; break; # AddTen function call
// }
// return var
// }
Scope root = Scope::NewRootScope().ExitOnError();
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
auto branch_index =
ops::Placeholder(root.WithOpName("branch_index"), DT_INT32);
auto initial_val = ops::Placeholder(root.WithOpName("initial_val"), DT_INT32);
auto var = ops::VarHandleOp(root.WithOpName("var"), DT_INT32, {});
auto init = ops::AssignVariableOp(root.WithOpName("init"), var, initial_val);
Node* case_node;
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(var.node())});
TF_ASSERT_OK(
NodeBuilder("case", "Case", &root.graph()->flib_def())
.Input(branch_index.node())
.Input(inputs)
.ControlInput(init.operation.node())
.Attr("branches", FuncListAttr({"AddOne", "AddTwo", "AddTen"}))
.Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
.Attr("Tout", DataTypeSlice{})
.Finalize(root.graph(), &case_node));
auto read = ops::ReadVariableOp(
root.WithOpName("read").WithControlDependencies(Output(case_node)), var,
DT_INT32);
TF_ASSERT_OK(root.DoShapeInference(case_node));
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(Rewrite(&graph));
// Verify the resultant graph has switch, merge and function call nodes.
int switch_count = 0;
int merge_count = 0;
int node_called_case_count = 0;
for (const auto* op : graph->op_nodes()) {
if (op->IsSwitch()) ++switch_count;
if (op->IsMerge()) ++merge_count;
if (op->name() == "case") ++node_called_case_count;
ASSERT_NE(op->type_string(), "Case");
}
// One switch for predicate and one for input (A).
ASSERT_EQ(switch_count, 2);
// One merge for the else/then branch (`branch_executed` node).
ASSERT_EQ(merge_count, 1);
// We keep a NoOp with the same name as original If node.
ASSERT_EQ(node_called_case_count, 1);
// Verify execution.
ClientSession session(root, SessionOptionsWithInlining());
{
ClientSession::FeedType feeds;
feeds.emplace(Output(branch_index.node()), Input::Initializer(-5));
feeds.emplace(Output(initial_val.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
}
{
ClientSession::FeedType feeds;
feeds.emplace(Output(branch_index.node()), Input::Initializer(0));
feeds.emplace(Output(initial_val.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 11);
}
{
ClientSession::FeedType feeds;
feeds.emplace(Output(branch_index.node()), Input::Initializer(1));
feeds.emplace(Output(initial_val.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 12);
}
{
ClientSession::FeedType feeds;
feeds.emplace(Output(branch_index.node()), Input::Initializer(2));
feeds.emplace(Output(initial_val.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
}
{
ClientSession::FeedType feeds;
feeds.emplace(Output(branch_index.node()), Input::Initializer(31));
feeds.emplace(Output(initial_val.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
}
}
TEST(LowerCaseOpTest, DoNotInlineLoweredFunction) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
FunctionDef x_times_two = test::function::XTimesTwo();
FunctionDef x_times_four = test::function::XTimesFour();
FunctionDef x_times_16 = test::function::XTimes16();
// Case branches can't be inlined.
(*x_times_two.mutable_attr())["_noinline"].set_b(true);
(*x_times_four.mutable_attr())["_noinline"].set_b(true);
(*x_times_16.mutable_attr())["_noinline"].set_b(true);
// Add test functions for the branches.
FunctionDefLibrary f_lib_proto;
*(f_lib_proto.add_function()) = x_times_two;
*(f_lib_proto.add_function()) = x_times_four;
*(f_lib_proto.add_function()) = x_times_16;
// Construct simple conditional that switches on `pred` and operates only on
// single input `A`.
Scope root = Scope::NewRootScope().ExitOnError();
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
auto branch_index =
ops::Placeholder(root.WithOpName("branch_index"), DT_INT32);
Node* written_case;
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
TF_ASSERT_OK(
NodeBuilder("case", "Case", &root.graph()->flib_def())
.Input(branch_index.node())
.Input(inputs)
.Attr("branches",
FuncListAttr({"XTimesTwo", "XTimesFour", "XTimes16"}))
.Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
.Attr("Tout", {DT_INT32})
.Finalize(root.graph(), &written_case));
TF_ASSERT_OK(root.DoShapeInference(written_case));
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(Rewrite(&graph));
// Verify that Case node was lowered but branch functions were not inlined.
int x_times_two_count = 0;
int x_times_four_count = 0;
int x_times_16_count = 0;
for (const auto* op : graph->op_nodes()) {
if (op->type_string() == x_times_two.signature().name()) {
x_times_two_count++;
}
if (op->type_string() == x_times_four.signature().name()) {
x_times_four_count++;
}
if (op->type_string() == x_times_16.signature().name()) {
x_times_16_count++;
}
ASSERT_NE(op->type_string(), "Case");
}
// One function for each branch.
ASSERT_EQ(x_times_two_count, 1);
ASSERT_EQ(x_times_four_count, 1);
ASSERT_EQ(x_times_16_count, 1);
// Verify execution.
ClientSession session(root, SessionOptionsWithInlining());
{
ClientSession::FeedType feeds;
feeds.emplace(Output(branch_index.node()), Input::Initializer(-2));
feeds.emplace(Output(a.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
}
{
ClientSession::FeedType feeds;
feeds.emplace(Output(branch_index.node()), Input::Initializer(0));
feeds.emplace(Output(a.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
}
{
ClientSession::FeedType feeds;
feeds.emplace(Output(branch_index.node()), Input::Initializer(1));
feeds.emplace(Output(a.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 40);
}
{
ClientSession::FeedType feeds;
feeds.emplace(Output(branch_index.node()), Input::Initializer(2));
feeds.emplace(Output(a.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
}
{
ClientSession::FeedType feeds;
feeds.emplace(Output(branch_index.node()), Input::Initializer(31));
feeds.emplace(Output(a.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
}
}
} // namespace
} // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/lower_case_op.h"
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
#include "tensorflow/core/common_runtime/lower_if_op.h"
#include "tensorflow/core/common_runtime/lower_while_op.h"
@ -114,13 +115,13 @@ Status LowerFunctionalOpsPass::Run(
? *keep_lowered_nodes_fetchable_
: !HasArgsOrRetvals(*g);
// Lower all If and While ops that have the `kLowerUsingSwitchMergeAttr` attr
// set and inlines all function calls into the graph.
// Lower all If, Case, While ops that have the `kLowerUsingSwitchMergeAttr`
// attr set and inline all function calls into the graph.
// We start at `i` = 2 to skip the source and sink nodes.
// Note that `g->num_node_ids()` may change in the for body if a matching If
// or While node is lowered. Since new graph nodes are always added to the
// end of the list of nodes it is ensured that nested If/While nodes will be
// lowered as well.
// Note that `g->num_node_ids()` may change in the for body if a matching If,
// Case, While node is lowered. Since new graph nodes are always added to the
// end of the list of nodes it is ensured that nested If/Case/While nodes will
// be lowered as well.
for (int i = 2; i < g->num_node_ids(); ++i) {
Node* n = g->FindNodeId(i);
if (n == nullptr) continue; // deleted node
@ -139,6 +140,9 @@ Status LowerFunctionalOpsPass::Run(
if (n->type_string() == "If") {
TF_RETURN_IF_ERROR(
RewriteIfNode(n, g, *flib_def, keep_lowered_nodes_fetchable));
} else if (n->type_string() == "Case") {
TF_RETURN_IF_ERROR(
RewriteCaseNode(n, g, *flib_def, keep_lowered_nodes_fetchable));
} else if (n->type_string() == "While") {
TF_RETURN_IF_ERROR(
RewriteWhileNode(n, g, *flib_def, keep_lowered_nodes_fetchable));

View File

@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/lower_if_op.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
@ -89,9 +89,9 @@ class CondBuilder {
Node* then_call_node_;
Node* else_call_node_;
// Merge node that has inputs from [pivot_t, pivot_f] and control edges from
// [^then_call_node_, ^else_call_node_]. This node will guarantee that if
// then/else branch functions do not have outputs, they still will be executed
// for the side effects.
// [^then_call_node_, ^else_call_node_]. This node will guarantee that even
// when then/else branch functions do not have outputs, they still will be
// executed for the side effects.
Node* branch_executed_node_;
Graph* graph_;
const FunctionLibraryDefinition& flib_;

View File

@ -58,6 +58,7 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
*new std::unordered_map<string, Node::NodeClass>({
// Keep in same order as NodeClass values
REF_CLASS("Switch", NC_SWITCH),
REF_CLASS("_SwitchN", NC_SWITCH),
REF_CLASS("Merge", NC_MERGE),
REF_CLASS("Enter", NC_ENTER),
REF_CLASS("Exit", NC_EXIT),

View File

@ -150,7 +150,8 @@ bool IsControlFlow(const NodeDef& node) {
node.op() == "LoopCond" ||
node.op() == "Merge" ||
node.op() == "NextIteration" ||
node.op() == "Switch";
node.op() == "Switch" ||
node.op() == "_SwitchN";
// clang-format on
}
@ -523,7 +524,7 @@ bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }
bool IsSwitch(const NodeDef& node) {
const auto& op = node.op();
return op == "Switch" || op == "RefSwitch";
return op == "_SwitchN" || op == "Switch" || op == "RefSwitch";
}
bool IsSymbolicGradient(const NodeDef& node) {

View File

@ -1281,7 +1281,7 @@ Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) {
node_map_->AddOutput(const_index->name(), output->name());
} else {
// This is a control dependency (or an invalid edge since the
// merge node has only 2 inputs): preserve them.
// merge node has only 2 outputs): preserve them.
}
}
}

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/lower_case_op.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/lower_if_op.h"
#include "tensorflow/core/common_runtime/lower_while_op.h"
@ -1157,6 +1158,8 @@ Status InlineFunctionCalls(const GrapplerItem& item,
if (n->type_string() == "If") {
TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), flib_def, false));
} else if (n->type_string() == "Case") {
TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), flib_def, false));
} else if (n->type_string() == "While") {
TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), flib_def, false));
}

View File

@ -195,6 +195,7 @@ std::set<string> GetOpsFormatAgnostic() {
"StridedSlice",
"StridedSliceGrad",
"Switch",
"_SwitchN",
"Tile",
"TruncateDiv",
"TruncateMod",
@ -1943,7 +1944,15 @@ class SwitchProcessor : public AgnosticNodeProcessor {
: AgnosticNodeProcessor(opt_cxt) {}
protected:
std::set<int> GetOutputPos() const override { return {0, 1}; }
std::set<int> GetOutputPos() const override {
std::set<int> output_pos;
const int num_outs =
node_->attr().count("num_outs") ? node_->attr().at("num_outs").i() : 2;
for (int i = 0; i < num_outs; i++) {
output_pos.insert(i);
}
return output_pos;
}
};
class TileProcessor : public AgnosticNodeProcessor {

View File

@ -460,7 +460,7 @@ std::vector<int> GetStackPushNodesToConvert(
const std::unordered_set<string> op_types_to_traverse(
{"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch",
"Identity", "RefIdentity"});
"_SwitchN", "Identity", "RefIdentity"});
const auto is_op_to_traverse = [&](const NodeDef* node) -> bool {
return op_types_to_traverse.find(node->op()) != op_types_to_traverse.end();
};
@ -753,6 +753,9 @@ Status LoopOptimizer::RemoveDeadBranches(
if (!IsSwitch(node)) {
continue;
}
if (node.op() == "_SwitchN") { // _SwitchN not used in loop control flow.
continue;
}
if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
continue;
}
@ -798,8 +801,8 @@ Status LoopOptimizer::RemoveDeadBranches(
if (IsMerge(*dead.node)) {
const int num_data_inputs = dead.node->attr().at("N").i();
if (num_data_inputs > 2) {
// This never happens in practice, so we'll just skip these to
// simplify the code for now.
// This can happen with _SwitchN/Merge (Case lowering). We skip these
// to simplify the code for now.
found_node_to_preserve = true;
break;
}
@ -877,11 +880,15 @@ Status LoopOptimizer::RemoveDeadBranches(
}
// Remove dead data input.
const std::set<int>& dead_inputs = itr.second;
CHECK_LE(dead_inputs.size(), 1);
// (This loop would delete >1 items possibly in the wrong order.)
for (int index : dead_inputs) {
dead_node->mutable_input()->DeleteSubrange(index, 1);
}
// Turn Merge into Identity only if we deleted data inputs.
// Turn Merge into Identity only if we deleted the other data input.
if (!dead_inputs.empty()) {
const int num_data_inputs = dead_node->attr().at("N").i();
CHECK_EQ(num_data_inputs, dead_inputs.size() + 1);
dead_node->set_op("Identity");
dead_node->mutable_attr()->erase("N");
}

View File

@ -39,12 +39,30 @@ void SwitchOp::Compute(OpKernelContext* context) {
}
}
void SwitchNOp::Compute(OpKernelContext* context) {
const Tensor& output_index_t = context->input(1);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_index_t.shape()),
errors::InvalidArgument("The second input must be a scalar, "
"but it has shape ",
output_index_t.shape().DebugString()));
int output_index = output_index_t.scalar<int>()();
if (output_index < 0 || output_index >= num_outputs()) {
output_index = num_outputs() - 1;
}
context->set_output(output_index, context->input(0));
}
#define REGISTER_CPU_SWITCH(type) \
REGISTER_KERNEL_BUILDER(Name("Switch") \
.Device(DEVICE_CPU) \
.HostMemory("pred") \
.TypeConstraint<type>("T"), \
SwitchOp)
SwitchOp) \
REGISTER_KERNEL_BUILDER(Name("_SwitchN") \
.Device(DEVICE_CPU) \
.HostMemory("output_index") \
.TypeConstraint<type>("T"), \
SwitchNOp)
#define REGISTER_CPU_REF_SWITCH(type) \
REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
@ -58,7 +76,12 @@ void SwitchOp::Compute(OpKernelContext* context) {
.Device(DEVICE_GPU) \
.HostMemory("pred") \
.TypeConstraint<type>("T"), \
SwitchOp)
SwitchOp) \
REGISTER_KERNEL_BUILDER(Name("_SwitchN") \
.Device(DEVICE_GPU) \
.HostMemory("output_index") \
.TypeConstraint<type>("T"), \
SwitchNOp)
#define REGISTER_GPU_REF_SWITCH(type) \
REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
@ -96,7 +119,14 @@ TF_CALL_variant(REGISTER_GPU_SWITCH);
.HostMemory("output_false") \
.HostMemory("output_true") \
.TypeConstraint<type>("T"), \
SwitchOp)
SwitchOp) \
REGISTER_KERNEL_BUILDER(Name("_SwitchN") \
.Device(DEVICE_GPU) \
.HostMemory("data") \
.HostMemory("output_index") \
.HostMemory("outputs") \
.TypeConstraint<type>("T"), \
SwitchNOp)
#define REGISTER_GPU_HOST_REF_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("RefSwitch") \

View File

@ -46,6 +46,21 @@ class SwitchOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(SwitchOp);
};
// An n-way switch op has two inputs and N outputs. It forwards the value of
// Input:0 to the output specified by Input:1. Input:1 is an integer tensor.
// Input:0 is forwarded to output:0 if Input:1 is 0, to output:1 if 1, and so
// forth. If Input:1 is <0 or >=num_outputs(), Input:0 is forwarded to
// output:num_outputs()-1.
class SwitchNOp : public OpKernel {
public:
explicit SwitchNOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override;
bool IsExpensive() override { return false; }
~SwitchNOp() override {}
TF_DISALLOW_COPY_AND_ASSIGN(SwitchNOp);
};
// A merge op has n inputs and two outputs. It forwards the value of the
// first input that becomes available to its first output, and the
// index of the first input to its second output.

View File

@ -24,6 +24,7 @@ using shape_inference::ShapeHandle;
// --------------------------------------------------------------------------
namespace {
Status SwitchShape(InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
@ -39,6 +40,27 @@ Status SwitchShape(InferenceContext* c) {
}
return Status::OK();
}
Status SwitchNShape(InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
ShapeHandle out = c->input(0);
int num_outs;
TF_RETURN_IF_ERROR(c->GetAttr("num_outs", &num_outs));
for (int i = 0; i < num_outs; i++) {
c->set_output(i, out);
}
// Handle resource shape / dtype.
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data != nullptr) {
for (int i = 0; i < num_outs; i++) {
c->set_output_handle_shapes_and_types(i, *handle_data);
}
}
return Status::OK();
}
} // namespace
REGISTER_OP("Switch")
@ -58,6 +80,14 @@ REGISTER_OP("RefSwitch")
.SetAllowsUninitializedInput()
.SetShapeFn(SwitchShape);
REGISTER_OP("_SwitchN")
.Input("data: T")
.Input("output_index: int32")
.Output("outputs: num_outs * T")
.Attr("num_outs: int >= 1")
.Attr("T: type")
.SetShapeFn(SwitchNShape);
// --------------------------------------------------------------------------
REGISTER_OP("RefSelect")
.Input("index: int32")

View File

@ -49,6 +49,27 @@ TEST(ControlFlowOpsTest, Merge_ShapeFn) {
INFER_OK(op, "[2,1];[2,1];[2,1]", "in0;[]");
}
TEST(ControlFlowOpsTest, SwitchN_ShapeFn) {
ShapeInferenceTestOp op("_SwitchN");
int n = 5;
TF_ASSERT_OK(NodeDefBuilder("test", "_SwitchN")
.Input({"d", 0, DT_FLOAT})
.Input({"bi", 0, DT_INT32})
.Attr("num_outs", n)
.Finalize(&op.node_def));
// Non-scalar output_index.
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[2]");
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1]");
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[?]");
// The second input should always be scalar. Outputs are 5x the first input.
INFER_OK(op, "?;?", "in0;in0;in0;in0;in0");
INFER_OK(op, "[2,?];?", "in0;in0;in0;in0;in0");
INFER_OK(op, "[2,?];[]", "in0;in0;in0;in0;in0");
INFER_OK(op, "[2,3];[]", "in0;in0;in0;in0;in0");
}
TEST(ControlFlowOpsTest, RefSelect_ShapeFn) {
ShapeInferenceTestOp op("RefSelect");

View File

@ -1177,6 +1177,84 @@ class FunctionalOpsCaseTest(test.TestCase):
self.assertAllEqual(np.float32(4), self.evaluate(f(-1, one))) # <0 default
self.assertAllEqual(np.float32(4), self.evaluate(f(6, one))) # >=N default
@test_util.run_deprecated_v1
@test_util.disable_xla("Don't lower for XLA")
def testSkipEagerCaseLoweringPreservesNameForFetch(self):
for use_gpu in (True, False):
def Run(branch, x, fetch_by_name, use_gpu=use_gpu):
with ops.Graph().as_default() as g:
@function.Defun(dtypes.float32)
def two(x):
return -1, x * 2
@function.Defun(dtypes.float32)
def three(x):
return 0, x * 3
@function.Defun(dtypes.float32)
def four(x):
return 1, x * 4
outputs = gen_functional_ops.case(branch, input=[x],
Tout=[dtypes.int32, dtypes.float32],
branches=[two, three, four],
name="my_case")
# `outputs` is the list of output tensors of the Case op. We
# arbitrarily choose the 0th tensor to get the Case op and set the
# lowering attribute on it.
outputs[0].op._set_attr("_lower_using_switch_merge",
attr_value_pb2.AttrValue(b=True))
outputs = array_ops.identity_n(outputs)
with self.session(graph=g, use_gpu=use_gpu) as sess:
return sess.run("my_case:1" if fetch_by_name else outputs[1])
self.assertAllEqual(2 * 1., Run(0, 1., False))
self.assertAllEqual(2 * 1., Run(0, 1., True))
self.assertAllEqual(3 * 7., Run(1, 7., False))
self.assertAllEqual(3 * 7., Run(1, 7., True))
self.assertAllEqual(4 * -3., Run(2, -3., False))
self.assertAllEqual(4 * -3., Run(2, -3., True))
self.assertAllEqual(4 * -4., Run(7, -4., False)) # >= N default
self.assertAllEqual(4 * -4., Run(7, -4., True)) # >= N default
self.assertAllEqual(4 * -5., Run(-1, -5., False)) # <0 default
self.assertAllEqual(4 * -5., Run(-1, -5., True)) # <0 default
@test_util.disable_xla("Don't lower for XLA")
def testCaseLowering(self):
for use_gpu in (True, False):
@eager_function.defun
def Run(branch, x):
@function.Defun(dtypes.float32)
def two(x):
return -1, x * 2
@function.Defun(dtypes.float32)
def three(x):
return 0, x * 3
@function.Defun(dtypes.float32)
def four(x):
return 1, x * 4
outputs = gen_functional_ops.case(branch, input=[x],
Tout=[dtypes.int32, dtypes.float32],
branches=[two, three, four])
# `outputs` is the list of output tensors of the Case op. We
# arbitrarily choose the 0th tensor to get the Case op and set the
# lowering attribute on it.
outputs[0].op._set_attr("_lower_using_switch_merge",
attr_value_pb2.AttrValue(b=True))
outputs = array_ops.identity_n(outputs)
return outputs[1]
with ops.device(test.gpu_device_name() if use_gpu else "CPU:0"):
self.assertAllEqual(2 * 1., self.evaluate(Run(0, 1.)))
self.assertAllEqual(3 * 7., self.evaluate(Run(1, 7.)))
self.assertAllEqual(4 * -3., self.evaluate(Run(2, -3.)))
self.assertAllEqual(4 * -4., self.evaluate(Run(7, -4.))) # >=N default
self.assertAllEqual(4 * -5., self.evaluate(Run(-1, -5.))) # <0 default
if __name__ == "__main__":
test.main()

View File

@ -510,7 +510,7 @@ def _make_output_composite_tensors_match(op_type, branch_graphs):
raise TypeError(
"Cannot reconcile {op_name} {output_idx}-th outputs:\n"
" outputs from all branches: {outputs}".format(
op_name="tf.cond" if op_type == _COND else "tf.case",
op_name="tf.cond" if op_type == _COND else "tf.switch_case",
output_idx=output_idx,
outputs=branch_outs))
@ -534,9 +534,9 @@ def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
for output_idx, branch_outs in enumerate(
zip(*branch_outputs_flat_with_composites)):
if len(set(isinstance(out, ops.IndexedSlices) for out in branch_outs)) != 1:
raise TypeError("Cannot reconcile {op_name} {output_idx}-th outputs:\n"
raise TypeError("Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n"
" branches returned: {outputs}".format(
op_name="tf.cond" if op_type == _COND else "tf.case",
op_name="cond" if op_type == _COND else "switch_case",
output_idx=output_idx,
outputs=branch_outs))
if isinstance(branch_outs[0], ops.IndexedSlices):
@ -632,7 +632,7 @@ def _check_same_outputs(op_type, graphs):
b0_name="true_fn" if op_type == _COND else "branches[0]",
bn_name=("false_fn" if op_type == _COND else
"branches[{}]".format(branch_idx)),
op_name="tf.cond" if op_type == _COND else "tf.case",
op_name="tf.cond" if op_type == _COND else "tf.switch_case",
b0_out=graphs[0].structured_outputs,
bn_out=graphs[branch_idx].structured_outputs,
detail=error_detail))
@ -817,8 +817,8 @@ def indexed_case(branch_index, branch_fns, name="indexed_case"):
@ops.RegisterGradient("Case")
def _CaseGrad(op, *grads): # pylint: disable=invalid-name
"""The gradient of a Case op produced (w/ branch_index) by tf.case."""
# Get the if operator (this logic handles the case where op is a MockOp)
"""The gradient of a Case op produced by tf.switch_case."""
# Get the Case operator (this logic handles the case where op is a MockOp)
case_op = op.outputs[0].op
branch_graphs = _get_func_graphs(case_op)
assert branch_graphs
@ -892,7 +892,7 @@ def _CaseGrad(op, *grads): # pylint: disable=invalid-name
_make_output_composite_tensors_match(_CASE, branch_grad_graphs)
outputs = _build_case(case_op.inputs[0], branch_grad_graphs,
branches_grad_inputs)
branches_grad_inputs, name="gradient")
# The predicate has no gradient.
return [None] + outputs
@ -935,14 +935,13 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
name=name)
# TODO(b/110167197) this requires Case to have at least 1 output
# TODO(b/110167197): this requires Case to have at least 1 output
case_op = tensors[0].op
# TODO(b/131304144): Enable lowering Case to SwitchN/Merge for graph mode.
# util.maybe_set_lowering_attr(case_op)
util.maybe_set_lowering_attr(case_op)
util.maybe_propagate_compile_time_consts_in_xla(case_op)
# Return identities for each output of the Case op, rather than the output of
# the Case op directly. This makes pruning work if the output of select_case()
# the Case op directly. This makes pruning work if the output of switch_case()
# is fetched: the lowering pass converts the Case outputs into IdentityN
# outputs, which if fetched will cause all ops in the taken branch to be run
# (since it takes all merge ops as input). After lowering, each output

View File

@ -19,11 +19,14 @@ from __future__ import division
from __future__ import print_function
import collections
from absl.testing import parameterized
import numpy as np
from tensorflow.python import tf2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@ -39,7 +42,9 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
@ -933,7 +938,10 @@ class DataTypesTest(test_util.TensorFlowTestCase):
@test_util.run_all_in_graph_and_eager_modes
class IndexedCaseTest(test_util.TensorFlowTestCase):
class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def make_name(self):
return self.id().split(".")[-1].replace("(", "_").replace(")", "")
def disabled_testCase_ticklesGpuVsHostMemoryIssueWithInt32(self):
nbranches = 5
@ -947,53 +955,55 @@ class IndexedCaseTest(test_util.TensorFlowTestCase):
case_out = control_flow_ops.switch_case(branch_index, branches)
self.assertEqual(bi * 10, self.evaluate(case_out))
def testCase(self):
@parameterized.parameters((0,), (2,), (3,))
def testCase(self, bi):
nbranches = 5
def make_func(bi):
return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
branches = [(i, make_func(i)) for i in range(nbranches)]
for bi in 0, 2, 3:
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(branch_index, branches)
self.assertEqual(bi * 10., self.evaluate(case_out))
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(
branch_index, branches, name=self.make_name())
self.assertEqual(bi * 10., self.evaluate(case_out))
def testCase_withDefault(self):
@parameterized.parameters((-1,), (2,), (4,), (5,), (6,))
def testCase_withDefault(self, bi):
nbranches = 5
def make_func(bi):
return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
branches = [(i, make_func(i)) for i in range(nbranches)]
for bi in -1, 2, nbranches:
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(
branch_index, branches, default=make_func(6))
if bi < 0 or bi >= nbranches:
expected = 60.
else:
expected = bi * 10.
self.assertEqual(expected, self.evaluate(case_out))
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(
branch_index, branches, default=make_func(6), name=self.make_name())
if bi < 0 or bi >= nbranches:
expected = 60.
else:
expected = bi * 10.
self.assertEqual(expected, self.evaluate(case_out))
def testCase_dictWithDefault(self):
@parameterized.parameters((-1,), (0,), (3,), (5,))
def testCase_dictWithDefault(self, bi):
nbranches = 5
def make_func(bi):
return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
branches = {i: make_func(i) for i in range(nbranches)}
for bi in -1, 0, 3, nbranches:
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(
branch_index, branches, default=make_func(6))
if bi < 0 or bi >= nbranches:
expected = 60.
else:
expected = bi * 10.
self.assertEqual(expected, self.evaluate(case_out))
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(
branch_index, branches, default=make_func(6), name=self.make_name())
if bi < 0 or bi >= nbranches:
expected = 60.
else:
expected = bi * 10.
self.assertEqual(expected, self.evaluate(case_out))
def testCase_gradient(self):
@parameterized.parameters((-1,), (1,), (4,), (5,))
def testCase_gradient(self, bi):
nbranches = 5
inputs = [
array_ops.constant(float(bi), name="br{}_in".format(bi))
@ -1005,22 +1015,22 @@ class IndexedCaseTest(test_util.TensorFlowTestCase):
branches = {bi: make_func(bi) for bi in range(nbranches)}
for bi in -1, 1, 4, nbranches:
branch_index = array_ops.placeholder_with_default(bi, [])
with backprop.GradientTape() as tape:
for x in inputs:
tape.watch(x)
case_out = control_flow_ops.switch_case(branch_index, branches)
out_grad = 3.
actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)
expected_grads = [None if context.executing_eagerly() else 0.] * nbranches
used_branch_idx = nbranches - 1 if bi < 0 or bi >= nbranches - 1 else bi
expected_grads[used_branch_idx] = out_grad * 2. * used_branch_idx
self.assertEqual(len(expected_grads), len(actual_grads))
for expected, actual in zip(expected_grads, actual_grads):
self.assertEqual(expected, self.evaluate(actual))
branch_index = array_ops.placeholder_with_default(bi, [])
with backprop.GradientTape() as tape:
for x in inputs:
tape.watch(x)
case_out = control_flow_ops.switch_case(branch_index, branches)
out_grad = 3.
actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)
expected_grads = [None if context.executing_eagerly() else 0.] * nbranches
used_branch_idx = nbranches - 1 if bi < 0 or bi >= nbranches - 1 else bi
expected_grads[used_branch_idx] = out_grad * 2. * used_branch_idx
self.assertEqual(len(expected_grads), len(actual_grads))
for expected, actual in zip(expected_grads, actual_grads):
self.assertEqual(expected, self.evaluate(actual))
def testCase_gradient_diffShapedIntermediates(self):
@parameterized.parameters((-2,), (2,), (5,))
def testCase_gradient_diffShapedIntermediates(self, bi):
nbranches = 5
inputs = [
array_ops.constant(
@ -1038,34 +1048,105 @@ class IndexedCaseTest(test_util.TensorFlowTestCase):
branches = {bi: make_func(bi) for bi in range(nbranches)}
for bi in -1, 2, nbranches:
branch_index = array_ops.placeholder_with_default(bi, [])
with backprop.GradientTape() as tape:
for x in inputs:
tape.watch(x)
case_out = control_flow_ops.switch_case(branch_index, branches)
out_grad = 3.
actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)
used_bi = (nbranches - 1) if (bi < 0 or bi >= nbranches - 1) else bi
expected_grads = []
for input_idx in range(nbranches):
if used_bi == input_idx:
with backprop.GradientTape() as tape:
tape.watch(inputs[used_bi])
y = make_func(used_bi)()
expected_grads.append(
self.evaluate(
tape.gradient(y, inputs[used_bi], output_gradients=out_grad)))
else:
expected_grads.append(None if context.executing_eagerly() else [0.] *
(input_idx + 1))
branch_index = array_ops.placeholder_with_default(bi, [])
with backprop.GradientTape() as tape:
for x in inputs:
tape.watch(x)
case_out = control_flow_ops.switch_case(
branch_index, branches, name=self.make_name())
out_grad = 3.
actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)
used_bi = (nbranches - 1) if (bi < 0 or bi >= nbranches - 1) else bi
expected_grads = []
for input_idx in range(nbranches):
if used_bi == input_idx:
with backprop.GradientTape() as tape:
tape.watch(inputs[used_bi])
y = make_func(used_bi)()
expected_grads.append(
self.evaluate(
tape.gradient(y, inputs[used_bi], output_gradients=out_grad)))
else:
expected_grads.append(None if context.executing_eagerly() else [0.] *
(input_idx + 1))
self.assertEqual(len(expected_grads), len(actual_grads))
for expected, actual in zip(expected_grads, actual_grads):
if expected is None:
self.assertIsNone(actual)
else:
self.assertAllEqual(expected, self.evaluate(actual))
self.assertEqual(len(expected_grads), len(actual_grads))
for expected, actual in zip(expected_grads, actual_grads):
if expected is None:
self.assertIsNone(actual)
else:
self.assertAllEqual(expected, self.evaluate(actual))
@test_util.run_gpu_only
@test_util.disable_xla("Wants RunMetadata")
def testParallelExecution(self):
"""Verify disjoint branches across while iterations are run in parallel."""
with ops.Graph().as_default() as g:
nbranches = 7
matrices = array_ops.unstack( # Ensure all are ready before while.
array_ops.matrix_diag(
random_ops.random_uniform([nbranches, 8, 512]) + 1e-3))
def make_branch(i, mat, name):
def branch_fn():
next_i = i + 1
with ops.device("gpu:0"):
return next_i, math_ops.reduce_sum(
linalg_ops.cholesky(mat, name=name + "_Cholesky"))
return branch_fn
def make_branches(i):
return [make_branch(i, matrices[bi], "br{}".format(bi))
for bi in range(nbranches)]
def cond(i, _):
return i < nbranches
def body(i, result):
with ops.device("cpu:0"):
next_i, branch_out = control_flow_ops.switch_case(i, make_branches(i))
return next_i, result + branch_out
_, result = control_flow_ops.while_loop(cond, body, [0, 0.])
run_metadata = config_pb2.RunMetadata()
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
config = config_pb2.ConfigProto(
allow_soft_placement=False, log_device_placement=True)
with session.Session(config=config, graph=g) as sess:
_ = sess.run(result, options=run_options, run_metadata=run_metadata)
chol_node_stats = []
for dev_stats in run_metadata.step_stats.dev_stats:
for node_stats in dev_stats.node_stats:
if (node_stats.node_name.endswith("Cholesky") and
node_stats.all_start_nanos > 0):
chol_node_stats.append(node_stats)
self.assertLen(chol_node_stats, nbranches)
chol_node_stats = sorted(chol_node_stats, key=lambda stats: stats.node_name)
op_start_nanos = [
stats.all_start_nanos for stats in chol_node_stats
]
op_end_nanos = [
stats.all_start_nanos + stats.op_end_rel_nanos
for stats in chol_node_stats
]
def overlap(range1, range2):
s1, e1 = range1
s2, e2 = range2
if s1 < s2:
return 0 if s2 > e1 else e1 - s2
return 0 if s1 > e2 else e2 - s1
timespans = list(zip(op_start_nanos, op_end_nanos))
overlaps_chol0 = [overlap(timespans[0], r2) for r2 in timespans[1:]]
# There are nbranches-1 overlaps, sometimes all nonzero, but we
# conservatively check for at least one here, to avoid test flakiness.
self.assertGreater(np.count_nonzero(overlaps_chol0), 0)
def testCase_validateIndicesContiguous(self):