Adds a lowering from Case to _SwitchN+Merge.
Introduces a new n-way _SwitchN op+kernel. I audited usages of the grappler and graph variants of IsSwitch and IsMerge, and believe the corrections in this CL are correct. PiperOrigin-RevId: 250803634
This commit is contained in:
parent
586ff7b1fa
commit
10ed2f7bb5
@ -846,24 +846,42 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
|
||||
const Edge* pred_edge;
|
||||
TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));
|
||||
|
||||
Predicate* true_switch;
|
||||
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
|
||||
pred_edge->src(), pred_edge->src_output(),
|
||||
/*must_be_true=*/true, &true_switch));
|
||||
if (n->num_outputs() == 2) {
|
||||
Predicate* true_switch;
|
||||
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
|
||||
pred_edge->src(), pred_edge->src_output(),
|
||||
/*must_be_true=*/true, &true_switch));
|
||||
|
||||
Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch);
|
||||
Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch);
|
||||
|
||||
// Output 0 is alive iff all inputs are alive and the condition is false.
|
||||
input_preds.push_back(false_switch);
|
||||
SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
|
||||
should_revisit);
|
||||
input_preds.pop_back();
|
||||
// Output 0 is alive iff all inputs are alive and the condition is false.
|
||||
input_preds.push_back(false_switch);
|
||||
SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
|
||||
should_revisit);
|
||||
input_preds.pop_back();
|
||||
|
||||
// Output 1 is alive iff all inputs are alive and the condition is true.
|
||||
input_preds.push_back(true_switch);
|
||||
SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
|
||||
should_revisit);
|
||||
input_preds.pop_back();
|
||||
// Output 1 is alive iff all inputs are alive and the condition is true.
|
||||
input_preds.push_back(true_switch);
|
||||
SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
|
||||
should_revisit);
|
||||
input_preds.pop_back();
|
||||
} else { // N-way switch case. Exactly one of N branches is alive.
|
||||
Predicate* branch_pred;
|
||||
for (int i = 0; i < n->num_outputs() - 1; i++) {
|
||||
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
|
||||
n, i, /*must_be_true=*/false, &branch_pred));
|
||||
input_preds.push_back(branch_pred);
|
||||
SetPredicate(n, i, predicate_factory_.MakeAndPredicate(input_preds),
|
||||
should_revisit);
|
||||
input_preds.pop_back();
|
||||
input_preds.push_back(predicate_factory_.MakeNotPredicate(branch_pred));
|
||||
}
|
||||
// The default (last) branch does not need its own symbol, is simply the
|
||||
// nor of all other branches.
|
||||
SetPredicate(n, n->num_outputs() - 1,
|
||||
predicate_factory_.MakeAndPredicate(input_preds),
|
||||
should_revisit);
|
||||
}
|
||||
|
||||
// Control is alive iff all inputs are alive.
|
||||
SetPredicate(n, Graph::kControlSlot,
|
||||
|
@ -3171,6 +3171,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
|
||||
"common_runtime/local_device.h",
|
||||
"common_runtime/lower_function_call_op.h",
|
||||
"common_runtime/lower_if_op.h",
|
||||
"common_runtime/lower_case_op.h",
|
||||
"common_runtime/lower_functional_ops.h",
|
||||
"common_runtime/lower_while_op.h",
|
||||
"common_runtime/memory_types.h",
|
||||
@ -3232,6 +3233,7 @@ tf_cuda_library(
|
||||
"common_runtime/inspecting_placer.h",
|
||||
"common_runtime/isolate_placer_inspection_required_ops_pass.cc",
|
||||
"common_runtime/local_device.cc",
|
||||
"common_runtime/lower_case_op.cc",
|
||||
"common_runtime/lower_function_call_op.cc",
|
||||
"common_runtime/lower_functional_ops.cc",
|
||||
"common_runtime/lower_if_op.cc",
|
||||
@ -5326,6 +5328,30 @@ tf_cc_tests(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_tests(
|
||||
name = "common_runtime_lower_case_op_test",
|
||||
size = "small",
|
||||
srcs = ["common_runtime/lower_case_op_test.cc"],
|
||||
deps = [
|
||||
":all_kernels",
|
||||
":core_cpu",
|
||||
":core_cpu_internal",
|
||||
":direct_session",
|
||||
":framework",
|
||||
":framework_internal",
|
||||
":lib",
|
||||
":test",
|
||||
":test_main",
|
||||
":testlib",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:cc_ops_internal",
|
||||
"//tensorflow/cc:client_session",
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:resource_variable_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_tests(
|
||||
name = "common_runtime_lower_while_op_test",
|
||||
size = "small",
|
||||
|
300
tensorflow/core/common_runtime/lower_case_op.cc
Normal file
300
tensorflow/core/common_runtime/lower_case_op.cc
Normal file
@ -0,0 +1,300 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/lower_case_op.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
using NodeOut = NodeBuilder::NodeOut;
|
||||
|
||||
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
|
||||
LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
|
||||
|
||||
// Convenience builder to make it easy to construct a case with a single
|
||||
// function call in each branch. This first converts the Case node
|
||||
// into switches (for inputs) and merges (for outputs) around a function call
|
||||
// per branch.
|
||||
class CaseBuilder {
|
||||
public:
|
||||
// Create a CaseBuilder to create the lowered form of `case` with branch
|
||||
// functions identified by `branch_fn_names` in the `graph`. The functions
|
||||
// should be available in `flib`.
|
||||
CaseBuilder(Node* case_op, const std::vector<string>& branch_fn_names,
|
||||
const FunctionLibraryDefinition& flib, bool keep_node_fetchable,
|
||||
Graph* graph);
|
||||
|
||||
// Constructs the basic conditional control flow using switch and merge nodes.
|
||||
Status CreatePivotNodes();
|
||||
|
||||
// Adds the inputs from the if node to the merge nodes of the lowered if.
|
||||
Status AddInputs();
|
||||
|
||||
// Adds the outputs from the if node to the merge nodes of the lowered if.
|
||||
// Note: no inputs can be added once outputs are added as the then and else
|
||||
// nodes are finalized while adding outputs.
|
||||
Status AddOutputs();
|
||||
|
||||
// Builds an identity node with the same outputs as Case.
|
||||
Status BuildLoweredCaseOutput();
|
||||
|
||||
private:
|
||||
// Returns unique name containing the name of the Case op being rewritten
|
||||
// (name_), infix and a suffix to ensure it is unique within the graph.
|
||||
string NewName(const string& infix);
|
||||
|
||||
// Adds input to both the then and else nodes from src:src_output.
|
||||
Status AddInput(Node* src, int src_output);
|
||||
|
||||
// The merged outputs of the then and else nodes.
|
||||
std::vector<NodeOut> outputs_;
|
||||
|
||||
// The node that dominates all execution of the then and else body nodes.
|
||||
Node* control_predecessor_;
|
||||
// The original Case op.
|
||||
Node* case_op_;
|
||||
// The node with the same name as the original Case op:
|
||||
// (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true'
|
||||
// and if the original Case op had non-zero data outputs.
|
||||
// (b) NoOp node with control edge from 'branch_executed_node_' otherwise.
|
||||
Node* lowered_case_output_;
|
||||
// The branch selector of the case.
|
||||
OutputTensor branch_index_;
|
||||
int num_branches_;
|
||||
// Nodes corresponding to pivot branch of branch_index _SwitchN, which is
|
||||
// the pivot node that dominates all nodes in the i'th branch.
|
||||
std::vector<Node*> pivots_;
|
||||
std::vector<Node*> call_nodes_;
|
||||
// Merge node that has inputs from each of pivots_ and control edges from
|
||||
// [^call_node for call_node in call_nodes_]. This node will guarantee that
|
||||
// even when branch functions do not have outputs, they still will be executed
|
||||
// for the side effects.
|
||||
Node* branch_executed_node_;
|
||||
Graph* graph_;
|
||||
const FunctionLibraryDefinition& flib_;
|
||||
string name_;
|
||||
bool keep_node_fetchable_;
|
||||
|
||||
NodeDebugInfo debug_info_;
|
||||
std::vector<NodeBuilder> branch_call_builders_;
|
||||
};
|
||||
|
||||
CaseBuilder::CaseBuilder(Node* case_op,
|
||||
const std::vector<string>& branch_fn_names,
|
||||
const FunctionLibraryDefinition& flib,
|
||||
bool keep_node_fetchable, Graph* graph)
|
||||
: case_op_(case_op),
|
||||
num_branches_(branch_fn_names.size()),
|
||||
graph_(graph),
|
||||
flib_(flib),
|
||||
name_(case_op->name()),
|
||||
keep_node_fetchable_(keep_node_fetchable),
|
||||
debug_info_(*case_op_) {
|
||||
branch_call_builders_.reserve(num_branches_);
|
||||
for (int b = 0; b < num_branches_; b++) {
|
||||
branch_call_builders_.emplace_back(NewName(strings::StrCat("branch", b)),
|
||||
branch_fn_names[b], graph->op_registry(),
|
||||
&debug_info_);
|
||||
branch_call_builders_[b].Device(case_op_->requested_device());
|
||||
branch_call_builders_[b].Attr(kLowerAsMultiDeviceFunctionAttr, true);
|
||||
}
|
||||
TF_CHECK_OK(case_op_->input_tensor(0, &branch_index_));
|
||||
}
|
||||
|
||||
Status CaseBuilder::CreatePivotNodes() {
|
||||
// Construct the basic case body (consisting of feeding in the val to
|
||||
// create pivot nodes).
|
||||
Node* branch_index;
|
||||
TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_index"), "_SwitchN",
|
||||
graph_->op_registry(), &debug_info_)
|
||||
.Input(NodeOut(branch_index_))
|
||||
.Input(NodeOut(branch_index_))
|
||||
.Attr("num_outs", num_branches_)
|
||||
.Device(case_op_->requested_device())
|
||||
.Finalize(graph_, &branch_index));
|
||||
control_predecessor_ = branch_index;
|
||||
pivots_.resize(num_branches_, nullptr);
|
||||
for (int b = 0; b < num_branches_; b++) {
|
||||
TF_RETURN_IF_ERROR(NodeBuilder(NewName(strings::StrCat("pivot_", b)),
|
||||
"Identity", graph_->op_registry(),
|
||||
&debug_info_)
|
||||
.Input(branch_index, b)
|
||||
.Device(case_op_->requested_device())
|
||||
.Finalize(graph_, &pivots_[b]));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string CaseBuilder::NewName(const string& infix) {
|
||||
return graph_->NewName(strings::StrCat(name_, "/", infix));
|
||||
}
|
||||
|
||||
Status CaseBuilder::AddInput(Node* src, int src_output) {
|
||||
Node* input;
|
||||
NodeDebugInfo debug_info(*src);
|
||||
// Colocate the Switch node with the `src` node.
|
||||
//
|
||||
// This is to avoid unnecessary Host<->Device copies between src and the
|
||||
// _SwitchN node. This aligns with the implementation of legacy tf.cond in
|
||||
// control_flow_ops.py. The legacy impl colocates the Switch with the
|
||||
// input tensor which resets the device stack and forces the Switch to have
|
||||
// the same device as the input node (if set) and sets the colocation _class
|
||||
// attr. It also ignores the existing colocation constraints on the input node
|
||||
// using colocate_with(ignore_existing=True).
|
||||
TF_RETURN_IF_ERROR(NodeBuilder(NewName(src->name()), "_SwitchN",
|
||||
graph_->op_registry(), &debug_info)
|
||||
.Input(src, src_output)
|
||||
.Input(branch_index_)
|
||||
.Device(src->requested_device())
|
||||
.Attr("_class", {src->name()})
|
||||
.Attr("num_outs", num_branches_)
|
||||
.Finalize(graph_, &input));
|
||||
for (int b = 0; b < num_branches_; b++) {
|
||||
branch_call_builders_[b].Input(input, b);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CaseBuilder::AddInputs() {
|
||||
// Add input data edges.
|
||||
std::vector<const Edge*> edges;
|
||||
TF_RETURN_IF_ERROR(case_op_->input_edges(&edges));
|
||||
// Start at index 1 as the first input is the branch index.
|
||||
for (int i = 1; i < edges.size(); ++i) {
|
||||
const Edge* e = edges[i];
|
||||
TF_RETURN_IF_ERROR(AddInput(e->src(), e->src_output()));
|
||||
}
|
||||
// Add input control edges.
|
||||
for (const Edge* e : case_op_->in_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
graph_->AddControlEdge(e->src(), control_predecessor_);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CaseBuilder::AddOutputs() {
|
||||
// Construct the call nodes for each branch.
|
||||
call_nodes_.resize(num_branches_, nullptr);
|
||||
for (int b = 0; b < num_branches_; b++) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
branch_call_builders_[b].Finalize(graph_, &call_nodes_[b]));
|
||||
graph_->AddControlEdge(pivots_[b], call_nodes_[b]);
|
||||
}
|
||||
|
||||
// Merge the outputs from the N branches (all branches have matching outputs).
|
||||
const int num_outputs = call_nodes_[0]->num_outputs();
|
||||
std::vector<Node*> merges(num_outputs);
|
||||
outputs_.resize(merges.size());
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
std::vector<NodeOut> merge_input;
|
||||
merge_input.reserve(num_branches_);
|
||||
for (int j = 0; j < num_branches_; j++) {
|
||||
merge_input.emplace_back(call_nodes_[j], i);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(NodeBuilder(NewName("merge"), "Merge",
|
||||
graph_->op_registry(), &debug_info_)
|
||||
.Input(merge_input)
|
||||
.Device(case_op_->requested_device())
|
||||
.Finalize(graph_, &merges[i]));
|
||||
outputs_[i] = NodeOut(merges[i], 0);
|
||||
}
|
||||
|
||||
// Add a Merge node that will be used as a control dependency source for the
|
||||
// lowered output node. This Merge node will guarantee that lowered else/then
|
||||
// function calls will be executed even if they do not have data outputs.
|
||||
//
|
||||
// Furthermore it will guarantee that all function side effects will be
|
||||
// executed, if the function will be inlined into the graph. Having data
|
||||
// outputs is not enough, because they might become unused after inlining.
|
||||
//
|
||||
// We will use this node to rewrite outgoing control edges from lowered 'Case'
|
||||
// node. All data edges will read tensors directly from Merge nodes.
|
||||
std::vector<NodeOut> pivots(num_branches_);
|
||||
for (int j = 0; j < num_branches_; j++) {
|
||||
pivots[j] = NodeOut(pivots_[j]);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_executed"), "Merge",
|
||||
graph_->op_registry(), &debug_info_)
|
||||
.Input(pivots)
|
||||
.ControlInputs(call_nodes_)
|
||||
.Device(case_op_->requested_device())
|
||||
.Finalize(graph_, &branch_executed_node_));
|
||||
|
||||
TF_RETURN_IF_ERROR(BuildLoweredCaseOutput());
|
||||
|
||||
// Add outputs.
|
||||
for (const Edge* e : case_op_->out_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
graph_->AddControlEdge(branch_executed_node_, e->dst());
|
||||
} else {
|
||||
// Feed the outputs directly from the merge nodes so that downstream ops
|
||||
// can start before all the outputs have been computed.
|
||||
graph_->AddEdge(merges[e->src_output()], 0, e->dst(), e->dst_input());
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CaseBuilder::BuildLoweredCaseOutput() {
|
||||
// If outputs are empty, it means that we might have only output control
|
||||
// edges (already connected to the `branch_executed_node`). Furthermore it's
|
||||
// illegal to have an IdentityN with empty inputs.
|
||||
//
|
||||
// We still must keep lowered Case node as a valid source of control edges,
|
||||
// because it might be a part of function control output set.
|
||||
NodeBuilder builder = keep_node_fetchable_ && !outputs_.empty()
|
||||
? NodeBuilder(name_, "IdentityN").Input(outputs_)
|
||||
: NodeBuilder(name_, "NoOp");
|
||||
return builder.Device(case_op_->requested_device())
|
||||
.ControlInput(branch_executed_node_)
|
||||
.Finalize(graph_, &lowered_case_output_);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status RewriteCaseNode(Node* n, Graph* g, const FunctionLibraryDefinition& flib,
|
||||
bool keep_node_fetchable) {
|
||||
VLOG(2) << "Lower Case node (keep_node_fetchable=" << keep_node_fetchable
|
||||
<< "): " << SummarizeNode(*n);
|
||||
const AttrValue* branches_attr = n->attrs().Find("branches");
|
||||
if (branches_attr == nullptr) {
|
||||
return errors::InvalidArgument("branch functions missing");
|
||||
}
|
||||
|
||||
int num_branches = branches_attr->list().func_size();
|
||||
std::vector<string> branch_fn_names;
|
||||
branch_fn_names.reserve(num_branches);
|
||||
for (int b = 0; b < num_branches; b++) {
|
||||
branch_fn_names.emplace_back(branches_attr->list().func(b).name());
|
||||
}
|
||||
CaseBuilder cb(n, branch_fn_names, flib, keep_node_fetchable, g);
|
||||
TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
|
||||
TF_RETURN_IF_ERROR(cb.AddInputs());
|
||||
TF_RETURN_IF_ERROR(cb.AddOutputs());
|
||||
g->RemoveNode(n);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
30
tensorflow/core/common_runtime/lower_case_op.h
Normal file
30
tensorflow/core/common_runtime/lower_case_op.h
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
|
||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Replaces Case node `n` with a lowered form that uses _SwitchN/Merge nodes.
|
||||
Status RewriteCaseNode(Node* n, Graph* g, const FunctionLibraryDefinition& flib,
|
||||
bool keep_node_fetchable);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
|
441
tensorflow/core/common_runtime/lower_case_op_test.cc
Normal file
441
tensorflow/core/common_runtime/lower_case_op_test.cc
Normal file
@ -0,0 +1,441 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
||||
|
||||
#include "tensorflow/cc/client/client_session.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/resource_variable_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/common_runtime/graph_runner.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
AttrValue FuncListAttr(const absl::Span<const char* const> names) {
|
||||
AttrValue attr;
|
||||
for (const char* name : names) {
|
||||
attr.mutable_list()->add_func()->set_name(name);
|
||||
}
|
||||
return attr;
|
||||
}
|
||||
|
||||
SessionOptions SessionOptionsWithInlining() {
|
||||
SessionOptions session_options;
|
||||
session_options.config.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
->set_do_function_inlining(true);
|
||||
return session_options;
|
||||
}
|
||||
|
||||
Status Rewrite(std::unique_ptr<Graph>* graph) {
|
||||
FunctionLibraryDefinition flib_def((*graph)->flib_def());
|
||||
GraphOptimizationPassOptions opt_options;
|
||||
SessionOptions session_options = SessionOptionsWithInlining();
|
||||
opt_options.session_options = &session_options;
|
||||
opt_options.graph = graph;
|
||||
opt_options.flib_def = &flib_def;
|
||||
LowerFunctionalOpsPass pass;
|
||||
return pass.Run(opt_options);
|
||||
}
|
||||
|
||||
TEST(LowerCaseOpTest, Simple) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
|
||||
// Add test functions for then and else branch.
|
||||
FunctionDefLibrary f_lib_proto;
|
||||
*(f_lib_proto.add_function()) = test::function::XTimesTwo();
|
||||
*(f_lib_proto.add_function()) = test::function::XTimesFour();
|
||||
*(f_lib_proto.add_function()) = test::function::XTimes16();
|
||||
|
||||
// Construct simple conditional that switches on `pred` and operates only on
|
||||
// single input `A`.
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
|
||||
auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
|
||||
auto branch_index =
|
||||
ops::Placeholder(root.WithOpName("branch_index"), DT_INT32);
|
||||
Node* written_if;
|
||||
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
|
||||
TF_ASSERT_OK(
|
||||
NodeBuilder("case", "Case", &root.graph()->flib_def())
|
||||
.Input(branch_index.node())
|
||||
.Input(inputs)
|
||||
.Attr("branches",
|
||||
FuncListAttr({"XTimesTwo", "XTimesFour", "XTimes16"}))
|
||||
.Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
|
||||
.Attr("Tout", {DT_INT32})
|
||||
.Finalize(root.graph(), &written_if));
|
||||
TF_ASSERT_OK(root.DoShapeInference(written_if));
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
// The input graph has no switch or merge nodes.
|
||||
int node_called_case_count = 0;
|
||||
for (const auto* op : graph->op_nodes()) {
|
||||
ASSERT_FALSE(op->IsSwitch());
|
||||
ASSERT_FALSE(op->IsMerge());
|
||||
if (op->name() == "case") {
|
||||
++node_called_case_count;
|
||||
}
|
||||
}
|
||||
ASSERT_EQ(node_called_case_count, 1);
|
||||
|
||||
TF_ASSERT_OK(Rewrite(&graph));
|
||||
|
||||
// Verify the resultant graph has switch and merge nodes, and a node called
|
||||
// `if` (but not If nodes).
|
||||
int switch_count = 0;
|
||||
int merge_count = 0;
|
||||
node_called_case_count = 0;
|
||||
for (const auto* op : graph->op_nodes()) {
|
||||
if (op->IsSwitch()) {
|
||||
++switch_count;
|
||||
}
|
||||
if (op->IsMerge()) {
|
||||
++merge_count;
|
||||
}
|
||||
ASSERT_NE(op->type_string(), "Case");
|
||||
if (op->name() == "case") {
|
||||
++node_called_case_count;
|
||||
}
|
||||
}
|
||||
// One switch for predicate and one for input (A).
|
||||
ASSERT_EQ(switch_count, 2);
|
||||
// One merge for the single output value of then and else, and one more merge
|
||||
// to enforce then and else function call execution (`branch_executed` node).
|
||||
ASSERT_EQ(merge_count, 2);
|
||||
ASSERT_EQ(node_called_case_count, 1);
|
||||
|
||||
// Verify execution.
|
||||
ClientSession session(root, SessionOptionsWithInlining());
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(branch_index.node()), Input::Initializer(-1));
|
||||
feeds.emplace(Output(a.node()), Input::Initializer(10));
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
|
||||
}
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(branch_index.node()), Input::Initializer(0));
|
||||
feeds.emplace(Output(a.node()), Input::Initializer(10));
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
|
||||
}
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(branch_index.node()), Input::Initializer(1));
|
||||
feeds.emplace(Output(a.node()), Input::Initializer(10));
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 40);
|
||||
}
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(branch_index.node()), Input::Initializer(2));
|
||||
feeds.emplace(Output(a.node()), Input::Initializer(10));
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
|
||||
}
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(branch_index.node()), Input::Initializer(20));
|
||||
feeds.emplace(Output(a.node()), Input::Initializer(10));
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(LowerCaseOpTest, BranchFunctionsWithoutOutputs) {
|
||||
using ::tensorflow::test::function::GDef;
|
||||
using ::tensorflow::test::function::NDef;
|
||||
using FDH = ::tensorflow::FunctionDefHelper;
|
||||
|
||||
// Wrap AssignAddVariable + Const into a function.
|
||||
const auto assign_add = [](const string& fn_name, int v) {
|
||||
const Tensor tensor = test::AsScalar<int32>(v);
|
||||
return FDH::Create(
|
||||
fn_name, {"v: resource"}, {}, {},
|
||||
{
|
||||
{{"c"}, "Const", {}, {{"value", tensor}, {"dtype", DT_INT32}}},
|
||||
{{"upd"},
|
||||
"AssignAddVariableOp",
|
||||
{"v", "c:output"},
|
||||
{{"dtype", DT_INT32}}},
|
||||
},
|
||||
/*ret_def=*/{},
|
||||
/*control_ret_def=*/{{"side_effects", "upd"}});
|
||||
};
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
// Add test functions for then and else branch.
|
||||
FunctionDefLibrary f_lib_proto;
|
||||
*(f_lib_proto.add_function()) = assign_add("AddOne", 1);
|
||||
*(f_lib_proto.add_function()) = assign_add("AddTwo", 2);
|
||||
*(f_lib_proto.add_function()) = assign_add("AddTen", 10);
|
||||
|
||||
// Construct a graph to represent following program:
|
||||
//
|
||||
// (branch_index: int32, initial_val: int32) -> (int32) {
|
||||
// var = Variable(initial_value)
|
||||
// switch (branch_index) {
|
||||
// case 0:
|
||||
// var += 1; break; # AddOne function call
|
||||
// case 1:
|
||||
// var += 2; break; # AddTwo function call
|
||||
// default:
|
||||
// var += 10; break; # AddTen function call
|
||||
// }
|
||||
// return var
|
||||
// }
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
|
||||
|
||||
auto branch_index =
|
||||
ops::Placeholder(root.WithOpName("branch_index"), DT_INT32);
|
||||
auto initial_val = ops::Placeholder(root.WithOpName("initial_val"), DT_INT32);
|
||||
|
||||
auto var = ops::VarHandleOp(root.WithOpName("var"), DT_INT32, {});
|
||||
auto init = ops::AssignVariableOp(root.WithOpName("init"), var, initial_val);
|
||||
|
||||
Node* case_node;
|
||||
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(var.node())});
|
||||
TF_ASSERT_OK(
|
||||
NodeBuilder("case", "Case", &root.graph()->flib_def())
|
||||
.Input(branch_index.node())
|
||||
.Input(inputs)
|
||||
.ControlInput(init.operation.node())
|
||||
.Attr("branches", FuncListAttr({"AddOne", "AddTwo", "AddTen"}))
|
||||
.Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
|
||||
.Attr("Tout", DataTypeSlice{})
|
||||
.Finalize(root.graph(), &case_node));
|
||||
|
||||
auto read = ops::ReadVariableOp(
|
||||
root.WithOpName("read").WithControlDependencies(Output(case_node)), var,
|
||||
DT_INT32);
|
||||
|
||||
TF_ASSERT_OK(root.DoShapeInference(case_node));
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
TF_ASSERT_OK(Rewrite(&graph));
|
||||
|
||||
// Verify the resultant graph has switch, merge and function call nodes.
|
||||
int switch_count = 0;
|
||||
int merge_count = 0;
|
||||
int node_called_case_count = 0;
|
||||
for (const auto* op : graph->op_nodes()) {
|
||||
if (op->IsSwitch()) ++switch_count;
|
||||
if (op->IsMerge()) ++merge_count;
|
||||
if (op->name() == "case") ++node_called_case_count;
|
||||
ASSERT_NE(op->type_string(), "Case");
|
||||
}
|
||||
// One switch for predicate and one for input (A).
|
||||
ASSERT_EQ(switch_count, 2);
|
||||
// One merge for the else/then branch (`branch_executed` node).
|
||||
ASSERT_EQ(merge_count, 1);
|
||||
// We keep a NoOp with the same name as original If node.
|
||||
ASSERT_EQ(node_called_case_count, 1);
|
||||
|
||||
// Verify execution.
|
||||
ClientSession session(root, SessionOptionsWithInlining());
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(branch_index.node()), Input::Initializer(-5));
|
||||
feeds.emplace(Output(initial_val.node()), Input::Initializer(10));
|
||||
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
|
||||
}
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(branch_index.node()), Input::Initializer(0));
|
||||
feeds.emplace(Output(initial_val.node()), Input::Initializer(10));
|
||||
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 11);
|
||||
}
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(branch_index.node()), Input::Initializer(1));
|
||||
feeds.emplace(Output(initial_val.node()), Input::Initializer(10));
|
||||
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 12);
|
||||
}
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(branch_index.node()), Input::Initializer(2));
|
||||
feeds.emplace(Output(initial_val.node()), Input::Initializer(10));
|
||||
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
|
||||
}
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(branch_index.node()), Input::Initializer(31));
|
||||
feeds.emplace(Output(initial_val.node()), Input::Initializer(10));
|
||||
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(LowerCaseOpTest, DoNotInlineLoweredFunction) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
|
||||
FunctionDef x_times_two = test::function::XTimesTwo();
|
||||
FunctionDef x_times_four = test::function::XTimesFour();
|
||||
FunctionDef x_times_16 = test::function::XTimes16();
|
||||
|
||||
// Case branches can't be inlined.
|
||||
(*x_times_two.mutable_attr())["_noinline"].set_b(true);
|
||||
(*x_times_four.mutable_attr())["_noinline"].set_b(true);
|
||||
(*x_times_16.mutable_attr())["_noinline"].set_b(true);
|
||||
|
||||
// Add test functions for the branches.
|
||||
FunctionDefLibrary f_lib_proto;
|
||||
*(f_lib_proto.add_function()) = x_times_two;
|
||||
*(f_lib_proto.add_function()) = x_times_four;
|
||||
*(f_lib_proto.add_function()) = x_times_16;
|
||||
|
||||
// Construct simple conditional that switches on `pred` and operates only on
|
||||
// single input `A`.
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
|
||||
auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
|
||||
auto branch_index =
|
||||
ops::Placeholder(root.WithOpName("branch_index"), DT_INT32);
|
||||
Node* written_case;
|
||||
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
|
||||
TF_ASSERT_OK(
|
||||
NodeBuilder("case", "Case", &root.graph()->flib_def())
|
||||
.Input(branch_index.node())
|
||||
.Input(inputs)
|
||||
.Attr("branches",
|
||||
FuncListAttr({"XTimesTwo", "XTimesFour", "XTimes16"}))
|
||||
.Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
|
||||
.Attr("Tout", {DT_INT32})
|
||||
.Finalize(root.graph(), &written_case));
|
||||
TF_ASSERT_OK(root.DoShapeInference(written_case));
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
TF_ASSERT_OK(Rewrite(&graph));
|
||||
|
||||
// Verify that Case node was lowered but branch functions were not inlined.
|
||||
int x_times_two_count = 0;
|
||||
int x_times_four_count = 0;
|
||||
int x_times_16_count = 0;
|
||||
|
||||
for (const auto* op : graph->op_nodes()) {
|
||||
if (op->type_string() == x_times_two.signature().name()) {
|
||||
x_times_two_count++;
|
||||
}
|
||||
if (op->type_string() == x_times_four.signature().name()) {
|
||||
x_times_four_count++;
|
||||
}
|
||||
if (op->type_string() == x_times_16.signature().name()) {
|
||||
x_times_16_count++;
|
||||
}
|
||||
ASSERT_NE(op->type_string(), "Case");
|
||||
}
|
||||
|
||||
// One function for each branch.
|
||||
ASSERT_EQ(x_times_two_count, 1);
|
||||
ASSERT_EQ(x_times_four_count, 1);
|
||||
ASSERT_EQ(x_times_16_count, 1);
|
||||
|
||||
// Verify execution.
|
||||
ClientSession session(root, SessionOptionsWithInlining());
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(branch_index.node()), Input::Initializer(-2));
|
||||
feeds.emplace(Output(a.node()), Input::Initializer(10));
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
|
||||
}
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(branch_index.node()), Input::Initializer(0));
|
||||
feeds.emplace(Output(a.node()), Input::Initializer(10));
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
|
||||
}
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(branch_index.node()), Input::Initializer(1));
|
||||
feeds.emplace(Output(a.node()), Input::Initializer(10));
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 40);
|
||||
}
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(branch_index.node()), Input::Initializer(2));
|
||||
feeds.emplace(Output(a.node()), Input::Initializer(10));
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
|
||||
}
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(branch_index.node()), Input::Initializer(31));
|
||||
feeds.emplace(Output(a.node()), Input::Initializer(10));
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/lower_case_op.h"
|
||||
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
|
||||
#include "tensorflow/core/common_runtime/lower_if_op.h"
|
||||
#include "tensorflow/core/common_runtime/lower_while_op.h"
|
||||
@ -114,13 +115,13 @@ Status LowerFunctionalOpsPass::Run(
|
||||
? *keep_lowered_nodes_fetchable_
|
||||
: !HasArgsOrRetvals(*g);
|
||||
|
||||
// Lower all If and While ops that have the `kLowerUsingSwitchMergeAttr` attr
|
||||
// set and inlines all function calls into the graph.
|
||||
// Lower all If, Case, While ops that have the `kLowerUsingSwitchMergeAttr`
|
||||
// attr set and inline all function calls into the graph.
|
||||
// We start at `i` = 2 to skip the source and sink nodes.
|
||||
// Note that `g->num_node_ids()` may change in the for body if a matching If
|
||||
// or While node is lowered. Since new graph nodes are always added to the
|
||||
// end of the list of nodes it is ensured that nested If/While nodes will be
|
||||
// lowered as well.
|
||||
// Note that `g->num_node_ids()` may change in the for body if a matching If,
|
||||
// Case, While node is lowered. Since new graph nodes are always added to the
|
||||
// end of the list of nodes it is ensured that nested If/Case/While nodes will
|
||||
// be lowered as well.
|
||||
for (int i = 2; i < g->num_node_ids(); ++i) {
|
||||
Node* n = g->FindNodeId(i);
|
||||
if (n == nullptr) continue; // deleted node
|
||||
@ -139,6 +140,9 @@ Status LowerFunctionalOpsPass::Run(
|
||||
if (n->type_string() == "If") {
|
||||
TF_RETURN_IF_ERROR(
|
||||
RewriteIfNode(n, g, *flib_def, keep_lowered_nodes_fetchable));
|
||||
} else if (n->type_string() == "Case") {
|
||||
TF_RETURN_IF_ERROR(
|
||||
RewriteCaseNode(n, g, *flib_def, keep_lowered_nodes_fetchable));
|
||||
} else if (n->type_string() == "While") {
|
||||
TF_RETURN_IF_ERROR(
|
||||
RewriteWhileNode(n, g, *flib_def, keep_lowered_nodes_fetchable));
|
||||
|
@ -14,9 +14,9 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/lower_if_op.h"
|
||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
@ -89,9 +89,9 @@ class CondBuilder {
|
||||
Node* then_call_node_;
|
||||
Node* else_call_node_;
|
||||
// Merge node that has inputs from [pivot_t, pivot_f] and control edges from
|
||||
// [^then_call_node_, ^else_call_node_]. This node will guarantee that if
|
||||
// then/else branch functions do not have outputs, they still will be executed
|
||||
// for the side effects.
|
||||
// [^then_call_node_, ^else_call_node_]. This node will guarantee that even
|
||||
// when then/else branch functions do not have outputs, they still will be
|
||||
// executed for the side effects.
|
||||
Node* branch_executed_node_;
|
||||
Graph* graph_;
|
||||
const FunctionLibraryDefinition& flib_;
|
||||
|
@ -58,6 +58,7 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
|
||||
*new std::unordered_map<string, Node::NodeClass>({
|
||||
// Keep in same order as NodeClass values
|
||||
REF_CLASS("Switch", NC_SWITCH),
|
||||
REF_CLASS("_SwitchN", NC_SWITCH),
|
||||
REF_CLASS("Merge", NC_MERGE),
|
||||
REF_CLASS("Enter", NC_ENTER),
|
||||
REF_CLASS("Exit", NC_EXIT),
|
||||
|
@ -150,7 +150,8 @@ bool IsControlFlow(const NodeDef& node) {
|
||||
node.op() == "LoopCond" ||
|
||||
node.op() == "Merge" ||
|
||||
node.op() == "NextIteration" ||
|
||||
node.op() == "Switch";
|
||||
node.op() == "Switch" ||
|
||||
node.op() == "_SwitchN";
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@ -523,7 +524,7 @@ bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }
|
||||
|
||||
bool IsSwitch(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "Switch" || op == "RefSwitch";
|
||||
return op == "_SwitchN" || op == "Switch" || op == "RefSwitch";
|
||||
}
|
||||
|
||||
bool IsSymbolicGradient(const NodeDef& node) {
|
||||
|
@ -1281,7 +1281,7 @@ Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) {
|
||||
node_map_->AddOutput(const_index->name(), output->name());
|
||||
} else {
|
||||
// This is a control dependency (or an invalid edge since the
|
||||
// merge node has only 2 inputs): preserve them.
|
||||
// merge node has only 2 outputs): preserve them.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/lower_case_op.h"
|
||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
||||
#include "tensorflow/core/common_runtime/lower_if_op.h"
|
||||
#include "tensorflow/core/common_runtime/lower_while_op.h"
|
||||
@ -1157,6 +1158,8 @@ Status InlineFunctionCalls(const GrapplerItem& item,
|
||||
|
||||
if (n->type_string() == "If") {
|
||||
TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), flib_def, false));
|
||||
} else if (n->type_string() == "Case") {
|
||||
TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), flib_def, false));
|
||||
} else if (n->type_string() == "While") {
|
||||
TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), flib_def, false));
|
||||
}
|
||||
|
@ -195,6 +195,7 @@ std::set<string> GetOpsFormatAgnostic() {
|
||||
"StridedSlice",
|
||||
"StridedSliceGrad",
|
||||
"Switch",
|
||||
"_SwitchN",
|
||||
"Tile",
|
||||
"TruncateDiv",
|
||||
"TruncateMod",
|
||||
@ -1943,7 +1944,15 @@ class SwitchProcessor : public AgnosticNodeProcessor {
|
||||
: AgnosticNodeProcessor(opt_cxt) {}
|
||||
|
||||
protected:
|
||||
std::set<int> GetOutputPos() const override { return {0, 1}; }
|
||||
std::set<int> GetOutputPos() const override {
|
||||
std::set<int> output_pos;
|
||||
const int num_outs =
|
||||
node_->attr().count("num_outs") ? node_->attr().at("num_outs").i() : 2;
|
||||
for (int i = 0; i < num_outs; i++) {
|
||||
output_pos.insert(i);
|
||||
}
|
||||
return output_pos;
|
||||
}
|
||||
};
|
||||
|
||||
class TileProcessor : public AgnosticNodeProcessor {
|
||||
|
@ -460,7 +460,7 @@ std::vector<int> GetStackPushNodesToConvert(
|
||||
|
||||
const std::unordered_set<string> op_types_to_traverse(
|
||||
{"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch",
|
||||
"Identity", "RefIdentity"});
|
||||
"_SwitchN", "Identity", "RefIdentity"});
|
||||
const auto is_op_to_traverse = [&](const NodeDef* node) -> bool {
|
||||
return op_types_to_traverse.find(node->op()) != op_types_to_traverse.end();
|
||||
};
|
||||
@ -753,6 +753,9 @@ Status LoopOptimizer::RemoveDeadBranches(
|
||||
if (!IsSwitch(node)) {
|
||||
continue;
|
||||
}
|
||||
if (node.op() == "_SwitchN") { // _SwitchN not used in loop control flow.
|
||||
continue;
|
||||
}
|
||||
if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
|
||||
continue;
|
||||
}
|
||||
@ -798,8 +801,8 @@ Status LoopOptimizer::RemoveDeadBranches(
|
||||
if (IsMerge(*dead.node)) {
|
||||
const int num_data_inputs = dead.node->attr().at("N").i();
|
||||
if (num_data_inputs > 2) {
|
||||
// This never happens in practice, so we'll just skip these to
|
||||
// simplify the code for now.
|
||||
// This can happen with _SwitchN/Merge (Case lowering). We skip these
|
||||
// to simplify the code for now.
|
||||
found_node_to_preserve = true;
|
||||
break;
|
||||
}
|
||||
@ -877,11 +880,15 @@ Status LoopOptimizer::RemoveDeadBranches(
|
||||
}
|
||||
// Remove dead data input.
|
||||
const std::set<int>& dead_inputs = itr.second;
|
||||
CHECK_LE(dead_inputs.size(), 1);
|
||||
// (This loop would delete >1 items possibly in the wrong order.)
|
||||
for (int index : dead_inputs) {
|
||||
dead_node->mutable_input()->DeleteSubrange(index, 1);
|
||||
}
|
||||
// Turn Merge into Identity only if we deleted data inputs.
|
||||
// Turn Merge into Identity only if we deleted the other data input.
|
||||
if (!dead_inputs.empty()) {
|
||||
const int num_data_inputs = dead_node->attr().at("N").i();
|
||||
CHECK_EQ(num_data_inputs, dead_inputs.size() + 1);
|
||||
dead_node->set_op("Identity");
|
||||
dead_node->mutable_attr()->erase("N");
|
||||
}
|
||||
|
@ -39,12 +39,30 @@ void SwitchOp::Compute(OpKernelContext* context) {
|
||||
}
|
||||
}
|
||||
|
||||
void SwitchNOp::Compute(OpKernelContext* context) {
|
||||
const Tensor& output_index_t = context->input(1);
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_index_t.shape()),
|
||||
errors::InvalidArgument("The second input must be a scalar, "
|
||||
"but it has shape ",
|
||||
output_index_t.shape().DebugString()));
|
||||
int output_index = output_index_t.scalar<int>()();
|
||||
if (output_index < 0 || output_index >= num_outputs()) {
|
||||
output_index = num_outputs() - 1;
|
||||
}
|
||||
context->set_output(output_index, context->input(0));
|
||||
}
|
||||
|
||||
#define REGISTER_CPU_SWITCH(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Switch") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.HostMemory("pred") \
|
||||
.TypeConstraint<type>("T"), \
|
||||
SwitchOp)
|
||||
SwitchOp) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_SwitchN") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.HostMemory("output_index") \
|
||||
.TypeConstraint<type>("T"), \
|
||||
SwitchNOp)
|
||||
|
||||
#define REGISTER_CPU_REF_SWITCH(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
|
||||
@ -58,7 +76,12 @@ void SwitchOp::Compute(OpKernelContext* context) {
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("pred") \
|
||||
.TypeConstraint<type>("T"), \
|
||||
SwitchOp)
|
||||
SwitchOp) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_SwitchN") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("output_index") \
|
||||
.TypeConstraint<type>("T"), \
|
||||
SwitchNOp)
|
||||
|
||||
#define REGISTER_GPU_REF_SWITCH(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
|
||||
@ -96,7 +119,14 @@ TF_CALL_variant(REGISTER_GPU_SWITCH);
|
||||
.HostMemory("output_false") \
|
||||
.HostMemory("output_true") \
|
||||
.TypeConstraint<type>("T"), \
|
||||
SwitchOp)
|
||||
SwitchOp) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_SwitchN") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("data") \
|
||||
.HostMemory("output_index") \
|
||||
.HostMemory("outputs") \
|
||||
.TypeConstraint<type>("T"), \
|
||||
SwitchNOp)
|
||||
|
||||
#define REGISTER_GPU_HOST_REF_KERNEL(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
|
||||
|
@ -46,6 +46,21 @@ class SwitchOp : public OpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(SwitchOp);
|
||||
};
|
||||
|
||||
// An n-way switch op has two inputs and N outputs. It forwards the value of
|
||||
// Input:0 to the output specified by Input:1. Input:1 is an integer tensor.
|
||||
// Input:0 is forwarded to output:0 if Input:1 is 0, to output:1 if 1, and so
|
||||
// forth. If Input:1 is <0 or >=num_outputs(), Input:0 is forwarded to
|
||||
// output:num_outputs()-1.
|
||||
class SwitchNOp : public OpKernel {
|
||||
public:
|
||||
explicit SwitchNOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
void Compute(OpKernelContext* context) override;
|
||||
bool IsExpensive() override { return false; }
|
||||
~SwitchNOp() override {}
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(SwitchNOp);
|
||||
};
|
||||
|
||||
// A merge op has n inputs and two outputs. It forwards the value of the
|
||||
// first input that becomes available to its first output, and the
|
||||
// index of the first input to its second output.
|
||||
|
@ -24,6 +24,7 @@ using shape_inference::ShapeHandle;
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
namespace {
|
||||
|
||||
Status SwitchShape(InferenceContext* c) {
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
@ -39,6 +40,27 @@ Status SwitchShape(InferenceContext* c) {
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SwitchNShape(InferenceContext* c) {
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
ShapeHandle out = c->input(0);
|
||||
int num_outs;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("num_outs", &num_outs));
|
||||
for (int i = 0; i < num_outs; i++) {
|
||||
c->set_output(i, out);
|
||||
}
|
||||
|
||||
// Handle resource shape / dtype.
|
||||
auto* handle_data = c->input_handle_shapes_and_types(0);
|
||||
if (handle_data != nullptr) {
|
||||
for (int i = 0; i < num_outs; i++) {
|
||||
c->set_output_handle_shapes_and_types(i, *handle_data);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
REGISTER_OP("Switch")
|
||||
@ -58,6 +80,14 @@ REGISTER_OP("RefSwitch")
|
||||
.SetAllowsUninitializedInput()
|
||||
.SetShapeFn(SwitchShape);
|
||||
|
||||
REGISTER_OP("_SwitchN")
|
||||
.Input("data: T")
|
||||
.Input("output_index: int32")
|
||||
.Output("outputs: num_outs * T")
|
||||
.Attr("num_outs: int >= 1")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn(SwitchNShape);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("RefSelect")
|
||||
.Input("index: int32")
|
||||
|
@ -49,6 +49,27 @@ TEST(ControlFlowOpsTest, Merge_ShapeFn) {
|
||||
INFER_OK(op, "[2,1];[2,1];[2,1]", "in0;[]");
|
||||
}
|
||||
|
||||
TEST(ControlFlowOpsTest, SwitchN_ShapeFn) {
|
||||
ShapeInferenceTestOp op("_SwitchN");
|
||||
|
||||
int n = 5;
|
||||
TF_ASSERT_OK(NodeDefBuilder("test", "_SwitchN")
|
||||
.Input({"d", 0, DT_FLOAT})
|
||||
.Input({"bi", 0, DT_INT32})
|
||||
.Attr("num_outs", n)
|
||||
.Finalize(&op.node_def));
|
||||
|
||||
// Non-scalar output_index.
|
||||
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[2]");
|
||||
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1]");
|
||||
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[?]");
|
||||
// The second input should always be scalar. Outputs are 5x the first input.
|
||||
INFER_OK(op, "?;?", "in0;in0;in0;in0;in0");
|
||||
INFER_OK(op, "[2,?];?", "in0;in0;in0;in0;in0");
|
||||
INFER_OK(op, "[2,?];[]", "in0;in0;in0;in0;in0");
|
||||
INFER_OK(op, "[2,3];[]", "in0;in0;in0;in0;in0");
|
||||
}
|
||||
|
||||
TEST(ControlFlowOpsTest, RefSelect_ShapeFn) {
|
||||
ShapeInferenceTestOp op("RefSelect");
|
||||
|
||||
|
@ -1177,6 +1177,84 @@ class FunctionalOpsCaseTest(test.TestCase):
|
||||
self.assertAllEqual(np.float32(4), self.evaluate(f(-1, one))) # <0 default
|
||||
self.assertAllEqual(np.float32(4), self.evaluate(f(6, one))) # >=N default
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("Don't lower for XLA")
|
||||
def testSkipEagerCaseLoweringPreservesNameForFetch(self):
|
||||
for use_gpu in (True, False):
|
||||
def Run(branch, x, fetch_by_name, use_gpu=use_gpu):
|
||||
with ops.Graph().as_default() as g:
|
||||
@function.Defun(dtypes.float32)
|
||||
def two(x):
|
||||
return -1, x * 2
|
||||
|
||||
@function.Defun(dtypes.float32)
|
||||
def three(x):
|
||||
return 0, x * 3
|
||||
|
||||
@function.Defun(dtypes.float32)
|
||||
def four(x):
|
||||
return 1, x * 4
|
||||
|
||||
outputs = gen_functional_ops.case(branch, input=[x],
|
||||
Tout=[dtypes.int32, dtypes.float32],
|
||||
branches=[two, three, four],
|
||||
name="my_case")
|
||||
|
||||
# `outputs` is the list of output tensors of the Case op. We
|
||||
# arbitrarily choose the 0th tensor to get the Case op and set the
|
||||
# lowering attribute on it.
|
||||
outputs[0].op._set_attr("_lower_using_switch_merge",
|
||||
attr_value_pb2.AttrValue(b=True))
|
||||
outputs = array_ops.identity_n(outputs)
|
||||
with self.session(graph=g, use_gpu=use_gpu) as sess:
|
||||
return sess.run("my_case:1" if fetch_by_name else outputs[1])
|
||||
|
||||
self.assertAllEqual(2 * 1., Run(0, 1., False))
|
||||
self.assertAllEqual(2 * 1., Run(0, 1., True))
|
||||
self.assertAllEqual(3 * 7., Run(1, 7., False))
|
||||
self.assertAllEqual(3 * 7., Run(1, 7., True))
|
||||
self.assertAllEqual(4 * -3., Run(2, -3., False))
|
||||
self.assertAllEqual(4 * -3., Run(2, -3., True))
|
||||
self.assertAllEqual(4 * -4., Run(7, -4., False)) # >= N default
|
||||
self.assertAllEqual(4 * -4., Run(7, -4., True)) # >= N default
|
||||
self.assertAllEqual(4 * -5., Run(-1, -5., False)) # <0 default
|
||||
self.assertAllEqual(4 * -5., Run(-1, -5., True)) # <0 default
|
||||
|
||||
@test_util.disable_xla("Don't lower for XLA")
|
||||
def testCaseLowering(self):
|
||||
for use_gpu in (True, False):
|
||||
@eager_function.defun
|
||||
def Run(branch, x):
|
||||
@function.Defun(dtypes.float32)
|
||||
def two(x):
|
||||
return -1, x * 2
|
||||
|
||||
@function.Defun(dtypes.float32)
|
||||
def three(x):
|
||||
return 0, x * 3
|
||||
|
||||
@function.Defun(dtypes.float32)
|
||||
def four(x):
|
||||
return 1, x * 4
|
||||
|
||||
outputs = gen_functional_ops.case(branch, input=[x],
|
||||
Tout=[dtypes.int32, dtypes.float32],
|
||||
branches=[two, three, four])
|
||||
|
||||
# `outputs` is the list of output tensors of the Case op. We
|
||||
# arbitrarily choose the 0th tensor to get the Case op and set the
|
||||
# lowering attribute on it.
|
||||
outputs[0].op._set_attr("_lower_using_switch_merge",
|
||||
attr_value_pb2.AttrValue(b=True))
|
||||
outputs = array_ops.identity_n(outputs)
|
||||
return outputs[1]
|
||||
|
||||
with ops.device(test.gpu_device_name() if use_gpu else "CPU:0"):
|
||||
self.assertAllEqual(2 * 1., self.evaluate(Run(0, 1.)))
|
||||
self.assertAllEqual(3 * 7., self.evaluate(Run(1, 7.)))
|
||||
self.assertAllEqual(4 * -3., self.evaluate(Run(2, -3.)))
|
||||
self.assertAllEqual(4 * -4., self.evaluate(Run(7, -4.))) # >=N default
|
||||
self.assertAllEqual(4 * -5., self.evaluate(Run(-1, -5.))) # <0 default
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -510,7 +510,7 @@ def _make_output_composite_tensors_match(op_type, branch_graphs):
raise TypeError(
"Cannot reconcile {op_name} {output_idx}-th outputs:\n"
" outputs from all branches: {outputs}".format(
op_name="tf.cond" if op_type == _COND else "tf.case",
op_name="tf.cond" if op_type == _COND else "tf.switch_case",
output_idx=output_idx,
outputs=branch_outs))

@ -534,9 +534,9 @@ def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
for output_idx, branch_outs in enumerate(
zip(*branch_outputs_flat_with_composites)):
if len(set(isinstance(out, ops.IndexedSlices) for out in branch_outs)) != 1:
raise TypeError("Cannot reconcile {op_name} {output_idx}-th outputs:\n"
raise TypeError("Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n"
" branches returned: {outputs}".format(
op_name="tf.cond" if op_type == _COND else "tf.case",
op_name="cond" if op_type == _COND else "switch_case",
output_idx=output_idx,
outputs=branch_outs))
if isinstance(branch_outs[0], ops.IndexedSlices):
@ -632,7 +632,7 @@ def _check_same_outputs(op_type, graphs):
b0_name="true_fn" if op_type == _COND else "branches[0]",
bn_name=("false_fn" if op_type == _COND else
"branches[{}]".format(branch_idx)),
op_name="tf.cond" if op_type == _COND else "tf.case",
op_name="tf.cond" if op_type == _COND else "tf.switch_case",
b0_out=graphs[0].structured_outputs,
bn_out=graphs[branch_idx].structured_outputs,
detail=error_detail))
@ -817,8 +817,8 @@ def indexed_case(branch_index, branch_fns, name="indexed_case"):

@ops.RegisterGradient("Case")
def _CaseGrad(op, *grads): # pylint: disable=invalid-name
"""The gradient of a Case op produced (w/ branch_index) by tf.case."""
# Get the if operator (this logic handles the case where op is a MockOp)
"""The gradient of a Case op produced by tf.switch_case."""
# Get the Case operator (this logic handles the case where op is a MockOp)
case_op = op.outputs[0].op
branch_graphs = _get_func_graphs(case_op)
assert branch_graphs
@ -892,7 +892,7 @@ def _CaseGrad(op, *grads): # pylint: disable=invalid-name
_make_output_composite_tensors_match(_CASE, branch_grad_graphs)

outputs = _build_case(case_op.inputs[0], branch_grad_graphs,
branches_grad_inputs)
branches_grad_inputs, name="gradient")

# The predicate has no gradient.
return [None] + outputs
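Because _CaseGrad builds gradient graphs only for the branch functions, the gradient of a switch_case result flows through the taken branch alone; the other branches contribute nothing. A hedged eager-mode sketch of that behaviour (assumes tf.switch_case and tf.GradientTape from a TF 2.x build):

import tensorflow as tf

x = tf.constant(3.)
with tf.GradientTape() as tape:
  tape.watch(x)
  out = tf.switch_case(tf.constant(1),
                       branch_fns=[lambda: 2. * x, lambda: x * x])
# Only the taken branch (x * x) contributes: d(out)/dx = 2 * x = 6.
grad = tape.gradient(out, x)
print(float(grad))  # 6.0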
@ -935,14 +935,13 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
name=name)

# TODO(b/110167197) this requires Case to have at least 1 output
# TODO(b/110167197): this requires Case to have at least 1 output
case_op = tensors[0].op
# TODO(b/131304144): Enable lowering Case to SwitchN/Merge for graph mode.
# util.maybe_set_lowering_attr(case_op)
util.maybe_set_lowering_attr(case_op)
util.maybe_propagate_compile_time_consts_in_xla(case_op)

# Return identities for each output of the Case op, rather than the output of
# the Case op directly. This makes pruning work if the output of select_case()
# the Case op directly. This makes pruning work if the output of switch_case()
# is fetched: the lowering pass converts the Case outputs into IdentityN
# outputs, which if fetched will cause all ops in the taken branch to be run
# (since it takes all merge ops as input). After lowering, each output
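The comment above motivates handing callers identity-wrapped tensors, so the lowering pass can rewrite the producer without invalidating what was fetched. A small illustrative sketch of that wrapping (tf.identity_n is the public counterpart of array_ops.identity_n; the constants merely stand in for Case outputs):

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  # Stand-ins for the tensors produced by the Case op.
  case_like_outputs = [tf.constant(1.), tf.constant(2.)]
  # Callers hold Identity-wrapped handles; a later rewrite of the producer
  # (Case -> _SwitchN/Merge during lowering) re-wires these Identity nodes
  # rather than invalidating the fetched tensors.
  stable_outputs = tf.identity_n(case_like_outputs)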
@ -19,11 +19,14 @@ from __future__ import division
from __future__ import print_function

import collections
from absl.testing import parameterized
import numpy as np

from tensorflow.python import tf2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@ -39,7 +42,9 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
@ -933,7 +938,10 @@ class DataTypesTest(test_util.TensorFlowTestCase):


@test_util.run_all_in_graph_and_eager_modes
class IndexedCaseTest(test_util.TensorFlowTestCase):
class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase):

def make_name(self):
return self.id().split(".")[-1].replace("(", "_").replace(")", "")

def disabled_testCase_ticklesGpuVsHostMemoryIssueWithInt32(self):
nbranches = 5
@ -947,53 +955,55 @@ class IndexedCaseTest(test_util.TensorFlowTestCase):
case_out = control_flow_ops.switch_case(branch_index, branches)
self.assertEqual(bi * 10, self.evaluate(case_out))

def testCase(self):
@parameterized.parameters((0,), (2,), (3,))
def testCase(self, bi):
nbranches = 5

def make_func(bi):
return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))

branches = [(i, make_func(i)) for i in range(nbranches)]
for bi in 0, 2, 3:
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(branch_index, branches)
self.assertEqual(bi * 10., self.evaluate(case_out))
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(
branch_index, branches, name=self.make_name())
self.assertEqual(bi * 10., self.evaluate(case_out))

def testCase_withDefault(self):
@parameterized.parameters((-1,), (2,), (4,), (5,), (6,))
def testCase_withDefault(self, bi):
nbranches = 5

def make_func(bi):
return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))

branches = [(i, make_func(i)) for i in range(nbranches)]
for bi in -1, 2, nbranches:
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(
branch_index, branches, default=make_func(6))
if bi < 0 or bi >= nbranches:
expected = 60.
else:
expected = bi * 10.
self.assertEqual(expected, self.evaluate(case_out))
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(
branch_index, branches, default=make_func(6), name=self.make_name())
if bi < 0 or bi >= nbranches:
expected = 60.
else:
expected = bi * 10.
self.assertEqual(expected, self.evaluate(case_out))

def testCase_dictWithDefault(self):
@parameterized.parameters((-1,), (0,), (3,), (5,))
def testCase_dictWithDefault(self, bi):
nbranches = 5

def make_func(bi):
return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))

branches = {i: make_func(i) for i in range(nbranches)}
for bi in -1, 0, 3, nbranches:
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(
branch_index, branches, default=make_func(6))
if bi < 0 or bi >= nbranches:
expected = 60.
else:
expected = bi * 10.
self.assertEqual(expected, self.evaluate(case_out))
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(
branch_index, branches, default=make_func(6), name=self.make_name())
if bi < 0 or bi >= nbranches:
expected = 60.
else:
expected = bi * 10.
self.assertEqual(expected, self.evaluate(case_out))

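The default-branch tests above fix the contract for an explicit default callable: any index outside [0, nbranches) routes to it. A standalone eager-mode sketch of that contract (assumes tf.switch_case; the branch values mirror the tests' bi * 10 pattern):

import tensorflow as tf

branches = {i: (lambda i=i: tf.constant(i * 10.)) for i in range(5)}
default = lambda: tf.constant(60.)

assert float(tf.switch_case(tf.constant(3), branches, default=default)) == 30.
# Negative or too-large indices route to the explicit default.
assert float(tf.switch_case(tf.constant(-1), branches, default=default)) == 60.
assert float(tf.switch_case(tf.constant(5), branches, default=default)) == 60.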
def testCase_gradient(self):
@parameterized.parameters((-1,), (1,), (4,), (5,))
def testCase_gradient(self, bi):
nbranches = 5
inputs = [
array_ops.constant(float(bi), name="br{}_in".format(bi))
@ -1005,22 +1015,22 @@ class IndexedCaseTest(test_util.TensorFlowTestCase):

branches = {bi: make_func(bi) for bi in range(nbranches)}

for bi in -1, 1, 4, nbranches:
branch_index = array_ops.placeholder_with_default(bi, [])
with backprop.GradientTape() as tape:
for x in inputs:
tape.watch(x)
case_out = control_flow_ops.switch_case(branch_index, branches)
out_grad = 3.
actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)
expected_grads = [None if context.executing_eagerly() else 0.] * nbranches
used_branch_idx = nbranches - 1 if bi < 0 or bi >= nbranches - 1 else bi
expected_grads[used_branch_idx] = out_grad * 2. * used_branch_idx
self.assertEqual(len(expected_grads), len(actual_grads))
for expected, actual in zip(expected_grads, actual_grads):
self.assertEqual(expected, self.evaluate(actual))
branch_index = array_ops.placeholder_with_default(bi, [])
with backprop.GradientTape() as tape:
for x in inputs:
tape.watch(x)
case_out = control_flow_ops.switch_case(branch_index, branches)
out_grad = 3.
actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)
expected_grads = [None if context.executing_eagerly() else 0.] * nbranches
used_branch_idx = nbranches - 1 if bi < 0 or bi >= nbranches - 1 else bi
expected_grads[used_branch_idx] = out_grad * 2. * used_branch_idx
self.assertEqual(len(expected_grads), len(actual_grads))
for expected, actual in zip(expected_grads, actual_grads):
self.assertEqual(expected, self.evaluate(actual))

def testCase_gradient_diffShapedIntermediates(self):
@parameterized.parameters((-2,), (2,), (5,))
def testCase_gradient_diffShapedIntermediates(self, bi):
nbranches = 5
inputs = [
array_ops.constant(
@ -1038,34 +1048,105 @@ class IndexedCaseTest(test_util.TensorFlowTestCase):

branches = {bi: make_func(bi) for bi in range(nbranches)}

for bi in -1, 2, nbranches:
branch_index = array_ops.placeholder_with_default(bi, [])
with backprop.GradientTape() as tape:
for x in inputs:
tape.watch(x)
case_out = control_flow_ops.switch_case(branch_index, branches)
out_grad = 3.
actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)
used_bi = (nbranches - 1) if (bi < 0 or bi >= nbranches - 1) else bi
expected_grads = []
for input_idx in range(nbranches):
if used_bi == input_idx:
with backprop.GradientTape() as tape:
tape.watch(inputs[used_bi])
y = make_func(used_bi)()
expected_grads.append(
self.evaluate(
tape.gradient(y, inputs[used_bi], output_gradients=out_grad)))
else:
expected_grads.append(None if context.executing_eagerly() else [0.] *
(input_idx + 1))
branch_index = array_ops.placeholder_with_default(bi, [])
with backprop.GradientTape() as tape:
for x in inputs:
tape.watch(x)
case_out = control_flow_ops.switch_case(
branch_index, branches, name=self.make_name())
out_grad = 3.
actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)
used_bi = (nbranches - 1) if (bi < 0 or bi >= nbranches - 1) else bi
expected_grads = []
for input_idx in range(nbranches):
if used_bi == input_idx:
with backprop.GradientTape() as tape:
tape.watch(inputs[used_bi])
y = make_func(used_bi)()
expected_grads.append(
self.evaluate(
tape.gradient(y, inputs[used_bi], output_gradients=out_grad)))
else:
expected_grads.append(None if context.executing_eagerly() else [0.] *
(input_idx + 1))

self.assertEqual(len(expected_grads), len(actual_grads))
for expected, actual in zip(expected_grads, actual_grads):
if expected is None:
self.assertIsNone(actual)
else:
self.assertAllEqual(expected, self.evaluate(actual))
self.assertEqual(len(expected_grads), len(actual_grads))
for expected, actual in zip(expected_grads, actual_grads):
if expected is None:
self.assertIsNone(actual)
else:
self.assertAllEqual(expected, self.evaluate(actual))

@test_util.run_gpu_only
@test_util.disable_xla("Wants RunMetadata")
def testParallelExecution(self):
"""Verify disjoint branches across while iterations are run in parallel."""
with ops.Graph().as_default() as g:
nbranches = 7
matrices = array_ops.unstack( # Ensure all are ready before while.
array_ops.matrix_diag(
random_ops.random_uniform([nbranches, 8, 512]) + 1e-3))

def make_branch(i, mat, name):
def branch_fn():
next_i = i + 1
with ops.device("gpu:0"):
return next_i, math_ops.reduce_sum(
linalg_ops.cholesky(mat, name=name + "_Cholesky"))
return branch_fn

def make_branches(i):
return [make_branch(i, matrices[bi], "br{}".format(bi))
for bi in range(nbranches)]

def cond(i, _):
return i < nbranches

def body(i, result):
with ops.device("cpu:0"):
next_i, branch_out = control_flow_ops.switch_case(i, make_branches(i))
return next_i, result + branch_out

_, result = control_flow_ops.while_loop(cond, body, [0, 0.])

run_metadata = config_pb2.RunMetadata()
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
config = config_pb2.ConfigProto(
allow_soft_placement=False, log_device_placement=True)

with session.Session(config=config, graph=g) as sess:
_ = sess.run(result, options=run_options, run_metadata=run_metadata)
chol_node_stats = []
for dev_stats in run_metadata.step_stats.dev_stats:
for node_stats in dev_stats.node_stats:
if (node_stats.node_name.endswith("Cholesky") and
node_stats.all_start_nanos > 0):
chol_node_stats.append(node_stats)

self.assertLen(chol_node_stats, nbranches)

chol_node_stats = sorted(chol_node_stats, key=lambda stats: stats.node_name)
op_start_nanos = [
stats.all_start_nanos for stats in chol_node_stats
]
op_end_nanos = [
stats.all_start_nanos + stats.op_end_rel_nanos
for stats in chol_node_stats
]

def overlap(range1, range2):
s1, e1 = range1
s2, e2 = range2
if s1 < s2:
return 0 if s2 > e1 else e1 - s2
return 0 if s1 > e2 else e2 - s1

timespans = list(zip(op_start_nanos, op_end_nanos))
overlaps_chol0 = [overlap(timespans[0], r2) for r2 in timespans[1:]]
# There are nbranches-1 overlaps, sometimes all nonzero, but we
# conservatively check for at least one here, to avoid test flakiness.
self.assertGreater(np.count_nonzero(overlaps_chol0), 0)
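The overlap helper inside testParallelExecution is plain Python; a quick standalone check of how it measures shared nanosecond spans (the values below are made up for illustration):

def overlap(range1, range2):
  s1, e1 = range1
  s2, e2 = range2
  if s1 < s2:
    return 0 if s2 > e1 else e1 - s2
  return 0 if s1 > e2 else e2 - s1

assert overlap((0, 10), (5, 20)) == 5   # intervals share the span 5..10
assert overlap((0, 10), (12, 20)) == 0  # disjoint intervals
assert overlap((8, 20), (0, 10)) == 2   # argument order does not matter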
def testCase_validateIndicesContiguous(self):