Adds a lowering from Case to _SwitchN+Merge.

Introduces a new n-way _SwitchN op+kernel. I audited usages of the grappler and graph variants of IsSwitch and IsMerge, and believe the corrections in this CL are correct.

PiperOrigin-RevId: 250803634

Commit 10ed2f7bb5 (parent 586ff7b1fa)
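For orientation, the lowered form looks roughly like this (my own sketch of the rewrite implemented below, not a figure from the CL): the Case node's branch selector drives one _SwitchN per data input plus a pivot _SwitchN, each branch body becomes a single function call node gated by its pivot, and the branch outputs are joined by Merge nodes, with one extra Merge ("branch_executed") that only records which branch ran.

    Case(branch_index, x; branches={f0, f1, f2})
      becomes
    _SwitchN(x, branch_index)          _SwitchN(branch_index, branch_index)
      |         |         |              |            |            |
    f0(x:0)   f1(x:1)   f2(x:2)       pivot_0      pivot_1      pivot_2
      \         |         /              \            |            /
       Merge (one per data output)         Merge("branch_executed")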
@@ -846,24 +846,42 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
   const Edge* pred_edge;
   TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));

+  if (n->num_outputs() == 2) {
     Predicate* true_switch;
     TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
         pred_edge->src(), pred_edge->src_output(),
         /*must_be_true=*/true, &true_switch));

     Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch);

     // Output 0 is alive iff all inputs are alive and the condition is false.
     input_preds.push_back(false_switch);
     SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
                  should_revisit);
     input_preds.pop_back();

     // Output 1 is alive iff all inputs are alive and the condition is true.
     input_preds.push_back(true_switch);
     SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
                  should_revisit);
     input_preds.pop_back();
+  } else {  // N-way switch case. Exactly one of N branches is alive.
+    Predicate* branch_pred;
+    for (int i = 0; i < n->num_outputs() - 1; i++) {
+      TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
+          n, i, /*must_be_true=*/false, &branch_pred));
+      input_preds.push_back(branch_pred);
+      SetPredicate(n, i, predicate_factory_.MakeAndPredicate(input_preds),
+                   should_revisit);
+      input_preds.pop_back();
+      input_preds.push_back(predicate_factory_.MakeNotPredicate(branch_pred));
+    }
+    // The default (last) branch does not need its own symbol; it is simply
+    // the nor of all other branches.
+    SetPredicate(n, n->num_outputs() - 1,
+                 predicate_factory_.MakeAndPredicate(input_preds),
+                 should_revisit);
+  }

   // Control is alive iff all inputs are alive.
   SetPredicate(n, Graph::kControlSlot,
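Stated compactly (my own summary of the hunk above, not text from the CL): if I is the conjunction of the input predicates and b_0 ... b_{n-2} are the per-branch symbols introduced for a _SwitchN with n outputs, the code assigns

    alive(output_i)       = I & b_i                            for i = 0 .. n-2
    alive(output_{n-1})   = I & !b_0 & !b_1 & ... & !b_{n-2}   (default branch)
    alive(control output) = I

so at most one branch output can be alive for any given branch index.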
@@ -3171,6 +3171,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
     "common_runtime/local_device.h",
     "common_runtime/lower_function_call_op.h",
     "common_runtime/lower_if_op.h",
+    "common_runtime/lower_case_op.h",
     "common_runtime/lower_functional_ops.h",
     "common_runtime/lower_while_op.h",
     "common_runtime/memory_types.h",
@@ -3232,6 +3233,7 @@ tf_cuda_library(
        "common_runtime/inspecting_placer.h",
        "common_runtime/isolate_placer_inspection_required_ops_pass.cc",
        "common_runtime/local_device.cc",
+       "common_runtime/lower_case_op.cc",
        "common_runtime/lower_function_call_op.cc",
        "common_runtime/lower_functional_ops.cc",
        "common_runtime/lower_if_op.cc",
@@ -5326,6 +5328,30 @@ tf_cc_tests(
    ],
)

+tf_cc_tests(
+    name = "common_runtime_lower_case_op_test",
+    size = "small",
+    srcs = ["common_runtime/lower_case_op_test.cc"],
+    deps = [
+        ":all_kernels",
+        ":core_cpu",
+        ":core_cpu_internal",
+        ":direct_session",
+        ":framework",
+        ":framework_internal",
+        ":lib",
+        ":test",
+        ":test_main",
+        ":testlib",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:cc_ops_internal",
+        "//tensorflow/cc:client_session",
+        "//tensorflow/cc:function_ops",
+        "//tensorflow/cc:ops",
+        "//tensorflow/cc:resource_variable_ops",
+    ],
+)
+
tf_cc_tests(
    name = "common_runtime_lower_while_op_test",
    size = "small",
tensorflow/core/common_runtime/lower_case_op.cc (new file, 300 lines)
@@ -0,0 +1,300 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/lower_case_op.h"

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

namespace {

using NodeOut = NodeBuilder::NodeOut;

constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
    LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;

// Convenience builder to make it easy to construct a Case with a single
// function call in each branch. This first converts the Case node
// into switches (for inputs) and merges (for outputs) around a function call
// per branch.
class CaseBuilder {
 public:
  // Create a CaseBuilder to create the lowered form of `case_op` with branch
  // functions identified by `branch_fn_names` in the `graph`. The functions
  // should be available in `flib`.
  CaseBuilder(Node* case_op, const std::vector<string>& branch_fn_names,
              const FunctionLibraryDefinition& flib, bool keep_node_fetchable,
              Graph* graph);

  // Constructs the basic conditional control flow using switch and merge nodes.
  Status CreatePivotNodes();

  // Adds the inputs from the Case node to the branch call nodes of the
  // lowered Case (via one _SwitchN node per input).
  Status AddInputs();

  // Adds the outputs from the Case node to the merge nodes of the lowered
  // Case. Note: no inputs can be added once outputs are added as the branch
  // call nodes are finalized while adding outputs.
  Status AddOutputs();

  // Builds an identity node with the same outputs as Case.
  Status BuildLoweredCaseOutput();

 private:
  // Returns unique name containing the name of the Case op being rewritten
  // (name_), infix and a suffix to ensure it is unique within the graph.
  string NewName(const string& infix);

  // Adds input to all branch call nodes from src:src_output.
  Status AddInput(Node* src, int src_output);

  // The merged outputs of the branch call nodes.
  std::vector<NodeOut> outputs_;

  // The node that dominates all execution of the branch body nodes.
  Node* control_predecessor_;
  // The original Case op.
  Node* case_op_;
  // The node with the same name as the original Case op:
  //   (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true'
  //       and if the original Case op had non-zero data outputs.
  //   (b) NoOp node with control edge from 'branch_executed_node_' otherwise.
  Node* lowered_case_output_;
  // The branch selector of the case.
  OutputTensor branch_index_;
  int num_branches_;
  // Pivot nodes, one per output of the branch_index _SwitchN; the i'th pivot
  // node dominates all nodes in the i'th branch.
  std::vector<Node*> pivots_;
  std::vector<Node*> call_nodes_;
  // Merge node that has inputs from each of pivots_ and control edges from
  // [^call_node for call_node in call_nodes_]. This node will guarantee that
  // even when branch functions do not have outputs, they still will be
  // executed for the side effects.
  Node* branch_executed_node_;
  Graph* graph_;
  const FunctionLibraryDefinition& flib_;
  string name_;
  bool keep_node_fetchable_;

  NodeDebugInfo debug_info_;
  std::vector<NodeBuilder> branch_call_builders_;
};

CaseBuilder::CaseBuilder(Node* case_op,
                         const std::vector<string>& branch_fn_names,
                         const FunctionLibraryDefinition& flib,
                         bool keep_node_fetchable, Graph* graph)
    : case_op_(case_op),
      num_branches_(branch_fn_names.size()),
      graph_(graph),
      flib_(flib),
      name_(case_op->name()),
      keep_node_fetchable_(keep_node_fetchable),
      debug_info_(*case_op_) {
  branch_call_builders_.reserve(num_branches_);
  for (int b = 0; b < num_branches_; b++) {
    branch_call_builders_.emplace_back(NewName(strings::StrCat("branch", b)),
                                       branch_fn_names[b], graph->op_registry(),
                                       &debug_info_);
    branch_call_builders_[b].Device(case_op_->requested_device());
    branch_call_builders_[b].Attr(kLowerAsMultiDeviceFunctionAttr, true);
  }
  TF_CHECK_OK(case_op_->input_tensor(0, &branch_index_));
}

Status CaseBuilder::CreatePivotNodes() {
  // Construct the basic case body (consisting of feeding in the branch index
  // to create pivot nodes).
  Node* branch_index;
  TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_index"), "_SwitchN",
                                 graph_->op_registry(), &debug_info_)
                         .Input(NodeOut(branch_index_))
                         .Input(NodeOut(branch_index_))
                         .Attr("num_outs", num_branches_)
                         .Device(case_op_->requested_device())
                         .Finalize(graph_, &branch_index));
  control_predecessor_ = branch_index;
  pivots_.resize(num_branches_, nullptr);
  for (int b = 0; b < num_branches_; b++) {
    TF_RETURN_IF_ERROR(NodeBuilder(NewName(strings::StrCat("pivot_", b)),
                                   "Identity", graph_->op_registry(),
                                   &debug_info_)
                           .Input(branch_index, b)
                           .Device(case_op_->requested_device())
                           .Finalize(graph_, &pivots_[b]));
  }
  return Status::OK();
}

string CaseBuilder::NewName(const string& infix) {
  return graph_->NewName(strings::StrCat(name_, "/", infix));
}

Status CaseBuilder::AddInput(Node* src, int src_output) {
  Node* input;
  NodeDebugInfo debug_info(*src);
  // Colocate the Switch node with the `src` node.
  //
  // This is to avoid unnecessary Host<->Device copies between src and the
  // _SwitchN node. This aligns with the implementation of legacy tf.cond in
  // control_flow_ops.py. The legacy impl colocates the Switch with the
  // input tensor which resets the device stack and forces the Switch to have
  // the same device as the input node (if set) and sets the colocation _class
  // attr. It also ignores the existing colocation constraints on the input
  // node using colocate_with(ignore_existing=True).
  TF_RETURN_IF_ERROR(NodeBuilder(NewName(src->name()), "_SwitchN",
                                 graph_->op_registry(), &debug_info)
                         .Input(src, src_output)
                         .Input(branch_index_)
                         .Device(src->requested_device())
                         .Attr("_class", {src->name()})
                         .Attr("num_outs", num_branches_)
                         .Finalize(graph_, &input));
  for (int b = 0; b < num_branches_; b++) {
    branch_call_builders_[b].Input(input, b);
  }
  return Status::OK();
}

Status CaseBuilder::AddInputs() {
  // Add input data edges.
  std::vector<const Edge*> edges;
  TF_RETURN_IF_ERROR(case_op_->input_edges(&edges));
  // Start at index 1 as the first input is the branch index.
  for (int i = 1; i < edges.size(); ++i) {
    const Edge* e = edges[i];
    TF_RETURN_IF_ERROR(AddInput(e->src(), e->src_output()));
  }
  // Add input control edges.
  for (const Edge* e : case_op_->in_edges()) {
    if (e->IsControlEdge()) {
      graph_->AddControlEdge(e->src(), control_predecessor_);
    }
  }
  return Status::OK();
}

Status CaseBuilder::AddOutputs() {
  // Construct the call nodes for each branch.
  call_nodes_.resize(num_branches_, nullptr);
  for (int b = 0; b < num_branches_; b++) {
    TF_RETURN_IF_ERROR(
        branch_call_builders_[b].Finalize(graph_, &call_nodes_[b]));
    graph_->AddControlEdge(pivots_[b], call_nodes_[b]);
  }

  // Merge the outputs from the N branches (all branches have matching outputs).
  const int num_outputs = call_nodes_[0]->num_outputs();
  std::vector<Node*> merges(num_outputs);
  outputs_.resize(merges.size());
  for (int i = 0; i < num_outputs; ++i) {
    std::vector<NodeOut> merge_input;
    merge_input.reserve(num_branches_);
    for (int j = 0; j < num_branches_; j++) {
      merge_input.emplace_back(call_nodes_[j], i);
    }
    TF_RETURN_IF_ERROR(NodeBuilder(NewName("merge"), "Merge",
                                   graph_->op_registry(), &debug_info_)
                           .Input(merge_input)
                           .Device(case_op_->requested_device())
                           .Finalize(graph_, &merges[i]));
    outputs_[i] = NodeOut(merges[i], 0);
  }

  // Add a Merge node that will be used as a control dependency source for the
  // lowered output node. This Merge node will guarantee that lowered branch
  // function calls will be executed even if they do not have data outputs.
  //
  // Furthermore it will guarantee that all function side effects will be
  // executed, if the function will be inlined into the graph. Having data
  // outputs is not enough, because they might become unused after inlining.
  //
  // We will use this node to rewrite outgoing control edges from lowered 'Case'
  // node. All data edges will read tensors directly from Merge nodes.
  std::vector<NodeOut> pivots(num_branches_);
  for (int j = 0; j < num_branches_; j++) {
    pivots[j] = NodeOut(pivots_[j]);
  }
  TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_executed"), "Merge",
                                 graph_->op_registry(), &debug_info_)
                         .Input(pivots)
                         .ControlInputs(call_nodes_)
                         .Device(case_op_->requested_device())
                         .Finalize(graph_, &branch_executed_node_));

  TF_RETURN_IF_ERROR(BuildLoweredCaseOutput());

  // Add outputs.
  for (const Edge* e : case_op_->out_edges()) {
    if (e->IsControlEdge()) {
      graph_->AddControlEdge(branch_executed_node_, e->dst());
    } else {
      // Feed the outputs directly from the merge nodes so that downstream ops
      // can start before all the outputs have been computed.
      graph_->AddEdge(merges[e->src_output()], 0, e->dst(), e->dst_input());
    }
  }
  return Status::OK();
}

Status CaseBuilder::BuildLoweredCaseOutput() {
  // If outputs are empty, it means that we might have only output control
  // edges (already connected to the `branch_executed_node`). Furthermore it's
  // illegal to have an IdentityN with empty inputs.
  //
  // We still must keep lowered Case node as a valid source of control edges,
  // because it might be a part of function control output set.
  NodeBuilder builder = keep_node_fetchable_ && !outputs_.empty()
                            ? NodeBuilder(name_, "IdentityN").Input(outputs_)
                            : NodeBuilder(name_, "NoOp");
  return builder.Device(case_op_->requested_device())
      .ControlInput(branch_executed_node_)
      .Finalize(graph_, &lowered_case_output_);
}

}  // namespace

Status RewriteCaseNode(Node* n, Graph* g, const FunctionLibraryDefinition& flib,
                       bool keep_node_fetchable) {
  VLOG(2) << "Lower Case node (keep_node_fetchable=" << keep_node_fetchable
          << "): " << SummarizeNode(*n);
  const AttrValue* branches_attr = n->attrs().Find("branches");
  if (branches_attr == nullptr) {
    return errors::InvalidArgument("branch functions missing");
  }

  int num_branches = branches_attr->list().func_size();
  std::vector<string> branch_fn_names;
  branch_fn_names.reserve(num_branches);
  for (int b = 0; b < num_branches; b++) {
    branch_fn_names.emplace_back(branches_attr->list().func(b).name());
  }
  CaseBuilder cb(n, branch_fn_names, flib, keep_node_fetchable, g);
  TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
  TF_RETURN_IF_ERROR(cb.AddInputs());
  TF_RETURN_IF_ERROR(cb.AddOutputs());
  g->RemoveNode(n);

  return Status::OK();
}

}  // namespace tensorflow
tensorflow/core/common_runtime/lower_case_op.h (new file, 30 lines)
@@ -0,0 +1,30 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_

#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Replaces Case node `n` with a lowered form that uses _SwitchN/Merge nodes.
Status RewriteCaseNode(Node* n, Graph* g, const FunctionLibraryDefinition& flib,
                       bool keep_node_fetchable);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_
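A minimal sketch of how a caller might drive this entry point (my illustration only; the real call sites are the LowerFunctionalOpsPass and grappler hunks further down, and the helper name here is made up):

    // Hypothetical helper: lower every Case node in `graph`, assuming `flib`
    // already contains the functions named in each node's "branches" attr.
    Status LowerAllCaseNodes(Graph* graph, const FunctionLibraryDefinition& flib) {
      // Iterate by node id so nodes appended by the rewrite are also visited,
      // skip ids freed by RemoveNode, and start at 2 to skip source/sink.
      for (int id = 2; id < graph->num_node_ids(); ++id) {
        Node* n = graph->FindNodeId(id);
        if (n == nullptr || n->type_string() != "Case") continue;
        TF_RETURN_IF_ERROR(
            RewriteCaseNode(n, graph, flib, /*keep_node_fetchable=*/true));
      }
      return Status::OK();
    }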
tensorflow/core/common_runtime/lower_case_op_test.cc (new file, 441 lines)
@@ -0,0 +1,441 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/lower_functional_ops.h"

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

AttrValue FuncListAttr(const absl::Span<const char* const> names) {
  AttrValue attr;
  for (const char* name : names) {
    attr.mutable_list()->add_func()->set_name(name);
  }
  return attr;
}

SessionOptions SessionOptionsWithInlining() {
  SessionOptions session_options;
  session_options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(true);
  return session_options;
}

Status Rewrite(std::unique_ptr<Graph>* graph) {
  FunctionLibraryDefinition flib_def((*graph)->flib_def());
  GraphOptimizationPassOptions opt_options;
  SessionOptions session_options = SessionOptionsWithInlining();
  opt_options.session_options = &session_options;
  opt_options.graph = graph;
  opt_options.flib_def = &flib_def;
  LowerFunctionalOpsPass pass;
  return pass.Run(opt_options);
}

TEST(LowerCaseOpTest, Simple) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // Add test functions for the Case branches.
  FunctionDefLibrary f_lib_proto;
  *(f_lib_proto.add_function()) = test::function::XTimesTwo();
  *(f_lib_proto.add_function()) = test::function::XTimesFour();
  *(f_lib_proto.add_function()) = test::function::XTimes16();

  // Construct a simple Case that switches on `branch_index` and operates only
  // on single input `A`.
  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
  auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
  auto branch_index =
      ops::Placeholder(root.WithOpName("branch_index"), DT_INT32);
  Node* written_if;
  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
  TF_ASSERT_OK(
      NodeBuilder("case", "Case", &root.graph()->flib_def())
          .Input(branch_index.node())
          .Input(inputs)
          .Attr("branches",
                FuncListAttr({"XTimesTwo", "XTimesFour", "XTimes16"}))
          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
          .Attr("Tout", {DT_INT32})
          .Finalize(root.graph(), &written_if));
  TF_ASSERT_OK(root.DoShapeInference(written_if));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // The input graph has no switch or merge nodes.
  int node_called_case_count = 0;
  for (const auto* op : graph->op_nodes()) {
    ASSERT_FALSE(op->IsSwitch());
    ASSERT_FALSE(op->IsMerge());
    if (op->name() == "case") {
      ++node_called_case_count;
    }
  }
  ASSERT_EQ(node_called_case_count, 1);

  TF_ASSERT_OK(Rewrite(&graph));

  // Verify the resultant graph has switch and merge nodes, and a node called
  // `case` (but no Case nodes).
  int switch_count = 0;
  int merge_count = 0;
  node_called_case_count = 0;
  for (const auto* op : graph->op_nodes()) {
    if (op->IsSwitch()) {
      ++switch_count;
    }
    if (op->IsMerge()) {
      ++merge_count;
    }
    ASSERT_NE(op->type_string(), "Case");
    if (op->name() == "case") {
      ++node_called_case_count;
    }
  }
  // One switch for the branch index and one for input (A).
  ASSERT_EQ(switch_count, 2);
  // One merge for the single output value of the branches, and one more merge
  // to enforce branch function call execution (`branch_executed` node).
  ASSERT_EQ(merge_count, 2);
  ASSERT_EQ(node_called_case_count, 1);

  // Verify execution.
  ClientSession session(root, SessionOptionsWithInlining());
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(branch_index.node()), Input::Initializer(-1));
    feeds.emplace(Output(a.node()), Input::Initializer(10));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
  }
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(branch_index.node()), Input::Initializer(0));
    feeds.emplace(Output(a.node()), Input::Initializer(10));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
  }
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(branch_index.node()), Input::Initializer(1));
    feeds.emplace(Output(a.node()), Input::Initializer(10));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 40);
  }
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(branch_index.node()), Input::Initializer(2));
    feeds.emplace(Output(a.node()), Input::Initializer(10));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
  }
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(branch_index.node()), Input::Initializer(20));
    feeds.emplace(Output(a.node()), Input::Initializer(10));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
  }
}

TEST(LowerCaseOpTest, BranchFunctionsWithoutOutputs) {
  using ::tensorflow::test::function::GDef;
  using ::tensorflow::test::function::NDef;
  using FDH = ::tensorflow::FunctionDefHelper;

  // Wrap AssignAddVariable + Const into a function.
  const auto assign_add = [](const string& fn_name, int v) {
    const Tensor tensor = test::AsScalar<int32>(v);
    return FDH::Create(
        fn_name, {"v: resource"}, {}, {},
        {
            {{"c"}, "Const", {}, {{"value", tensor}, {"dtype", DT_INT32}}},
            {{"upd"},
             "AssignAddVariableOp",
             {"v", "c:output"},
             {{"dtype", DT_INT32}}},
        },
        /*ret_def=*/{},
        /*control_ret_def=*/{{"side_effects", "upd"}});
  };

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  // Add test functions for the Case branches.
  FunctionDefLibrary f_lib_proto;
  *(f_lib_proto.add_function()) = assign_add("AddOne", 1);
  *(f_lib_proto.add_function()) = assign_add("AddTwo", 2);
  *(f_lib_proto.add_function()) = assign_add("AddTen", 10);

  // Construct a graph to represent the following program:
  //
  //   (branch_index: int32, initial_val: int32) -> (int32) {
  //     var = Variable(initial_value)
  //     switch (branch_index) {
  //       case 0:
  //         var += 1; break;    # AddOne function call
  //       case 1:
  //         var += 2; break;    # AddTwo function call
  //       default:
  //         var += 10; break;   # AddTen function call
  //     }
  //     return var
  //   }
  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));

  auto branch_index =
      ops::Placeholder(root.WithOpName("branch_index"), DT_INT32);
  auto initial_val = ops::Placeholder(root.WithOpName("initial_val"), DT_INT32);

  auto var = ops::VarHandleOp(root.WithOpName("var"), DT_INT32, {});
  auto init = ops::AssignVariableOp(root.WithOpName("init"), var, initial_val);

  Node* case_node;
  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(var.node())});
  TF_ASSERT_OK(
      NodeBuilder("case", "Case", &root.graph()->flib_def())
          .Input(branch_index.node())
          .Input(inputs)
          .ControlInput(init.operation.node())
          .Attr("branches", FuncListAttr({"AddOne", "AddTwo", "AddTen"}))
          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
          .Attr("Tout", DataTypeSlice{})
          .Finalize(root.graph(), &case_node));

  auto read = ops::ReadVariableOp(
      root.WithOpName("read").WithControlDependencies(Output(case_node)), var,
      DT_INT32);

  TF_ASSERT_OK(root.DoShapeInference(case_node));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(Rewrite(&graph));

  // Verify the resultant graph has switch, merge and function call nodes.
  int switch_count = 0;
  int merge_count = 0;
  int node_called_case_count = 0;
  for (const auto* op : graph->op_nodes()) {
    if (op->IsSwitch()) ++switch_count;
    if (op->IsMerge()) ++merge_count;
    if (op->name() == "case") ++node_called_case_count;
    ASSERT_NE(op->type_string(), "Case");
  }
  // One switch for the branch index and one for the variable input.
  ASSERT_EQ(switch_count, 2);
  // One merge to enforce branch function call execution (`branch_executed`
  // node).
  ASSERT_EQ(merge_count, 1);
  // We keep a NoOp with the same name as the original Case node.
  ASSERT_EQ(node_called_case_count, 1);

  // Verify execution.
  ClientSession session(root, SessionOptionsWithInlining());
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(branch_index.node()), Input::Initializer(-5));
    feeds.emplace(Output(initial_val.node()), Input::Initializer(10));

    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
  }
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(branch_index.node()), Input::Initializer(0));
    feeds.emplace(Output(initial_val.node()), Input::Initializer(10));

    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 11);
  }
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(branch_index.node()), Input::Initializer(1));
    feeds.emplace(Output(initial_val.node()), Input::Initializer(10));

    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 12);
  }
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(branch_index.node()), Input::Initializer(2));
    feeds.emplace(Output(initial_val.node()), Input::Initializer(10));

    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
  }
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(branch_index.node()), Input::Initializer(31));
    feeds.emplace(Output(initial_val.node()), Input::Initializer(10));

    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
  }
}

TEST(LowerCaseOpTest, DoNotInlineLoweredFunction) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  FunctionDef x_times_two = test::function::XTimesTwo();
  FunctionDef x_times_four = test::function::XTimesFour();
  FunctionDef x_times_16 = test::function::XTimes16();

  // Case branches can't be inlined.
  (*x_times_two.mutable_attr())["_noinline"].set_b(true);
  (*x_times_four.mutable_attr())["_noinline"].set_b(true);
  (*x_times_16.mutable_attr())["_noinline"].set_b(true);

  // Add test functions for the branches.
  FunctionDefLibrary f_lib_proto;
  *(f_lib_proto.add_function()) = x_times_two;
  *(f_lib_proto.add_function()) = x_times_four;
  *(f_lib_proto.add_function()) = x_times_16;

  // Construct a simple Case that switches on `branch_index` and operates only
  // on single input `A`.
  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
  auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
  auto branch_index =
      ops::Placeholder(root.WithOpName("branch_index"), DT_INT32);
  Node* written_case;
  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
  TF_ASSERT_OK(
      NodeBuilder("case", "Case", &root.graph()->flib_def())
          .Input(branch_index.node())
          .Input(inputs)
          .Attr("branches",
                FuncListAttr({"XTimesTwo", "XTimesFour", "XTimes16"}))
          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
          .Attr("Tout", {DT_INT32})
          .Finalize(root.graph(), &written_case));
  TF_ASSERT_OK(root.DoShapeInference(written_case));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(Rewrite(&graph));

  // Verify that Case node was lowered but branch functions were not inlined.
  int x_times_two_count = 0;
  int x_times_four_count = 0;
  int x_times_16_count = 0;

  for (const auto* op : graph->op_nodes()) {
    if (op->type_string() == x_times_two.signature().name()) {
      x_times_two_count++;
    }
    if (op->type_string() == x_times_four.signature().name()) {
      x_times_four_count++;
    }
    if (op->type_string() == x_times_16.signature().name()) {
      x_times_16_count++;
    }
    ASSERT_NE(op->type_string(), "Case");
  }

  // One function for each branch.
  ASSERT_EQ(x_times_two_count, 1);
  ASSERT_EQ(x_times_four_count, 1);
  ASSERT_EQ(x_times_16_count, 1);

  // Verify execution.
  ClientSession session(root, SessionOptionsWithInlining());
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(branch_index.node()), Input::Initializer(-2));
    feeds.emplace(Output(a.node()), Input::Initializer(10));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
  }
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(branch_index.node()), Input::Initializer(0));
    feeds.emplace(Output(a.node()), Input::Initializer(10));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
  }
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(branch_index.node()), Input::Initializer(1));
    feeds.emplace(Output(a.node()), Input::Initializer(10));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 40);
  }
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(branch_index.node()), Input::Initializer(2));
    feeds.emplace(Output(a.node()), Input::Initializer(10));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
  }
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(branch_index.node()), Input::Initializer(31));
    feeds.emplace(Output(a.node()), Input::Initializer(10));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors));
    EXPECT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 160);
  }
}

}  // namespace
}  // namespace tensorflow
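Given the tf_cc_tests target declared in the BUILD hunk above, the new test can presumably be run with something along the lines of the command below; the exact target name depends on how the tf_cc_tests macro expands, so treat this as an assumption rather than a documented invocation:

    bazel test //tensorflow/core:common_runtime_lower_case_op_test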
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/lower_functional_ops.h"

 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/lower_case_op.h"
 #include "tensorflow/core/common_runtime/lower_function_call_op.h"
 #include "tensorflow/core/common_runtime/lower_if_op.h"
 #include "tensorflow/core/common_runtime/lower_while_op.h"
@@ -114,13 +115,13 @@ Status LowerFunctionalOpsPass::Run(
           ? *keep_lowered_nodes_fetchable_
           : !HasArgsOrRetvals(*g);

-  // Lower all If and While ops that have the `kLowerUsingSwitchMergeAttr` attr
-  // set and inlines all function calls into the graph.
+  // Lower all If, Case, While ops that have the `kLowerUsingSwitchMergeAttr`
+  // attr set and inline all function calls into the graph.
   // We start at `i` = 2 to skip the source and sink nodes.
-  // Note that `g->num_node_ids()` may change in the for body if a matching If
-  // or While node is lowered. Since new graph nodes are always added to the
-  // end of the list of nodes it is ensured that nested If/While nodes will be
-  // lowered as well.
+  // Note that `g->num_node_ids()` may change in the for body if a matching If,
+  // Case, While node is lowered. Since new graph nodes are always added to the
+  // end of the list of nodes it is ensured that nested If/Case/While nodes will
+  // be lowered as well.
   for (int i = 2; i < g->num_node_ids(); ++i) {
     Node* n = g->FindNodeId(i);
     if (n == nullptr) continue;  // deleted node
@@ -139,6 +140,9 @@ Status LowerFunctionalOpsPass::Run(
     if (n->type_string() == "If") {
       TF_RETURN_IF_ERROR(
           RewriteIfNode(n, g, *flib_def, keep_lowered_nodes_fetchable));
+    } else if (n->type_string() == "Case") {
+      TF_RETURN_IF_ERROR(
+          RewriteCaseNode(n, g, *flib_def, keep_lowered_nodes_fetchable));
     } else if (n->type_string() == "While") {
       TF_RETURN_IF_ERROR(
           RewriteWhileNode(n, g, *flib_def, keep_lowered_nodes_fetchable));
@@ -14,9 +14,9 @@ limitations under the License.
 ==============================================================================*/

 #include "tensorflow/core/common_runtime/lower_if_op.h"
-#include "tensorflow/core/common_runtime/lower_functional_ops.h"

 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/lower_functional_ops.h"
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/node_builder.h"
@@ -89,9 +89,9 @@ class CondBuilder {
   Node* then_call_node_;
   Node* else_call_node_;
   // Merge node that has inputs from [pivot_t, pivot_f] and control edges from
-  // [^then_call_node_, ^else_call_node_]. This node will guarantee that if
-  // then/else branch functions do not have outputs, they still will be executed
-  // for the side effects.
+  // [^then_call_node_, ^else_call_node_]. This node will guarantee that even
+  // when then/else branch functions do not have outputs, they still will be
+  // executed for the side effects.
   Node* branch_executed_node_;
   Graph* graph_;
   const FunctionLibraryDefinition& flib_;
@@ -58,6 +58,7 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
     *new std::unordered_map<string, Node::NodeClass>({
         // Keep in same order as NodeClass values
         REF_CLASS("Switch", NC_SWITCH),
+        REF_CLASS("_SwitchN", NC_SWITCH),
        REF_CLASS("Merge", NC_MERGE),
        REF_CLASS("Enter", NC_ENTER),
        REF_CLASS("Exit", NC_EXIT),
@@ -150,7 +150,8 @@ bool IsControlFlow(const NodeDef& node) {
          node.op() == "LoopCond" ||
          node.op() == "Merge" ||
          node.op() == "NextIteration" ||
-         node.op() == "Switch";
+         node.op() == "Switch" ||
+         node.op() == "_SwitchN";
   // clang-format on
 }
@@ -523,7 +524,7 @@ bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }

 bool IsSwitch(const NodeDef& node) {
   const auto& op = node.op();
-  return op == "Switch" || op == "RefSwitch";
+  return op == "_SwitchN" || op == "Switch" || op == "RefSwitch";
 }

 bool IsSymbolicGradient(const NodeDef& node) {
@@ -1281,7 +1281,7 @@ Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) {
         node_map_->AddOutput(const_index->name(), output->name());
       } else {
         // This is a control dependency (or an invalid edge since the
-        // merge node has only 2 inputs): preserve them.
+        // merge node has only 2 outputs): preserve them.
       }
     }
   }
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/device_set.h"
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/lower_case_op.h"
 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
 #include "tensorflow/core/common_runtime/lower_if_op.h"
 #include "tensorflow/core/common_runtime/lower_while_op.h"
@@ -1157,6 +1158,8 @@ Status InlineFunctionCalls(const GrapplerItem& item,

     if (n->type_string() == "If") {
       TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), flib_def, false));
+    } else if (n->type_string() == "Case") {
+      TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), flib_def, false));
     } else if (n->type_string() == "While") {
       TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), flib_def, false));
     }
@@ -195,6 +195,7 @@ std::set<string> GetOpsFormatAgnostic() {
      "StridedSlice",
      "StridedSliceGrad",
      "Switch",
+     "_SwitchN",
      "Tile",
      "TruncateDiv",
      "TruncateMod",
@@ -1943,7 +1944,15 @@ class SwitchProcessor : public AgnosticNodeProcessor {
       : AgnosticNodeProcessor(opt_cxt) {}

  protected:
-  std::set<int> GetOutputPos() const override { return {0, 1}; }
+  std::set<int> GetOutputPos() const override {
+    std::set<int> output_pos;
+    const int num_outs =
+        node_->attr().count("num_outs") ? node_->attr().at("num_outs").i() : 2;
+    for (int i = 0; i < num_outs; i++) {
+      output_pos.insert(i);
+    }
+    return output_pos;
+  }
 };

 class TileProcessor : public AgnosticNodeProcessor {
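Concretely (my own worked example, not from the CL): a legacy Switch node has no "num_outs" attr, so the count() check falls back to 2 and GetOutputPos() still returns {0, 1}; a _SwitchN produced by the Case lowering with num_outs = 3 yields {0, 1, 2}, so the layout optimizer treats every branch output as format-agnostic, not just the first two.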
@@ -460,7 +460,7 @@ std::vector<int> GetStackPushNodesToConvert(

   const std::unordered_set<string> op_types_to_traverse(
       {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch",
-       "Identity", "RefIdentity"});
+       "_SwitchN", "Identity", "RefIdentity"});
   const auto is_op_to_traverse = [&](const NodeDef* node) -> bool {
     return op_types_to_traverse.find(node->op()) != op_types_to_traverse.end();
   };
@@ -753,6 +753,9 @@ Status LoopOptimizer::RemoveDeadBranches(
     if (!IsSwitch(node)) {
       continue;
     }
+    if (node.op() == "_SwitchN") {  // _SwitchN not used in loop control flow.
+      continue;
+    }
     if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
       continue;
     }
@@ -798,8 +801,8 @@ Status LoopOptimizer::RemoveDeadBranches(
       if (IsMerge(*dead.node)) {
         const int num_data_inputs = dead.node->attr().at("N").i();
         if (num_data_inputs > 2) {
-          // This never happens in practice, so we'll just skip these to
-          // simplify the code for now.
+          // This can happen with _SwitchN/Merge (Case lowering). We skip these
+          // to simplify the code for now.
           found_node_to_preserve = true;
           break;
         }
@@ -877,11 +880,15 @@ Status LoopOptimizer::RemoveDeadBranches(
     }
     // Remove dead data input.
     const std::set<int>& dead_inputs = itr.second;
+    CHECK_LE(dead_inputs.size(), 1);
+    // (This loop would delete >1 items possibly in the wrong order.)
     for (int index : dead_inputs) {
       dead_node->mutable_input()->DeleteSubrange(index, 1);
     }
-    // Turn Merge into Identity only if we deleted data inputs.
+    // Turn Merge into Identity only if we deleted the other data input.
     if (!dead_inputs.empty()) {
+      const int num_data_inputs = dead_node->attr().at("N").i();
+      CHECK_EQ(num_data_inputs, dead_inputs.size() + 1);
       dead_node->set_op("Identity");
       dead_node->mutable_attr()->erase("N");
     }
@@ -39,12 +39,30 @@ void SwitchOp::Compute(OpKernelContext* context) {
   }
 }

+void SwitchNOp::Compute(OpKernelContext* context) {
+  const Tensor& output_index_t = context->input(1);
+  OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_index_t.shape()),
+              errors::InvalidArgument("The second input must be a scalar, "
+                                      "but it has shape ",
+                                      output_index_t.shape().DebugString()));
+  int output_index = output_index_t.scalar<int>()();
+  if (output_index < 0 || output_index >= num_outputs()) {
+    output_index = num_outputs() - 1;
+  }
+  context->set_output(output_index, context->input(0));
+}
+
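To make the forwarding rule concrete (my own example, not from the CL), with num_outs = 3:

    output_index = 0, 1 or 2  -> input 0 is forwarded to outputs:0, :1 or :2 respectively
    output_index = -1 or 7    -> input 0 is forwarded to outputs:2 (clamped to the last output)

which is what lets the lowered Case treat the last branch as the default for out-of-range indices.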
 #define REGISTER_CPU_SWITCH(type)                         \
   REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                               .Device(DEVICE_CPU)         \
                               .HostMemory("pred")         \
                               .TypeConstraint<type>("T"), \
-                          SwitchOp)
+                          SwitchOp)                       \
+  REGISTER_KERNEL_BUILDER(Name("_SwitchN")                \
+                              .Device(DEVICE_CPU)         \
+                              .HostMemory("output_index") \
+                              .TypeConstraint<type>("T"), \
+                          SwitchNOp)

 #define REGISTER_CPU_REF_SWITCH(type)                     \
   REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \

@@ -58,7 +76,12 @@ void SwitchOp::Compute(OpKernelContext* context) {
                               .Device(DEVICE_GPU)         \
                               .HostMemory("pred")         \
                               .TypeConstraint<type>("T"), \
-                          SwitchOp)
+                          SwitchOp)                       \
+  REGISTER_KERNEL_BUILDER(Name("_SwitchN")                \
+                              .Device(DEVICE_GPU)         \
+                              .HostMemory("output_index") \
+                              .TypeConstraint<type>("T"), \
+                          SwitchNOp)

 #define REGISTER_GPU_REF_SWITCH(type)                     \
   REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \

@@ -96,7 +119,14 @@ TF_CALL_variant(REGISTER_GPU_SWITCH);
                               .HostMemory("output_false") \
                               .HostMemory("output_true")  \
                               .TypeConstraint<type>("T"), \
-                          SwitchOp)
+                          SwitchOp)                       \
+  REGISTER_KERNEL_BUILDER(Name("_SwitchN")                \
+                              .Device(DEVICE_GPU)         \
+                              .HostMemory("data")         \
+                              .HostMemory("output_index") \
+                              .HostMemory("outputs")      \
+                              .TypeConstraint<type>("T"), \
+                          SwitchNOp)

 #define REGISTER_GPU_HOST_REF_KERNEL(type)                \
   REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
@@ -46,6 +46,21 @@ class SwitchOp : public OpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(SwitchOp);
 };
 
+// An n-way switch op has two inputs and N outputs. It forwards the value of
+// Input:0 to the output specified by Input:1. Input:1 is an integer tensor.
+// Input:0 is forwarded to output:0 if Input:1 is 0, to output:1 if 1, and so
+// forth. If Input:1 is <0 or >=num_outputs(), Input:0 is forwarded to
+// output:num_outputs()-1.
+class SwitchNOp : public OpKernel {
+ public:
+  explicit SwitchNOp(OpKernelConstruction* context) : OpKernel(context) {}
+  void Compute(OpKernelContext* context) override;
+  bool IsExpensive() override { return false; }
+  ~SwitchNOp() override {}
+
+  TF_DISALLOW_COPY_AND_ASSIGN(SwitchNOp);
+};
+
 // A merge op has n inputs and two outputs. It forwards the value of the
 // first input that becomes available to its first output, and the
 // index of the first input to its second output.
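
The forwarding rule described in the SwitchNOp comment above can be summarized with a small, self-contained sketch (plain Python, independent of the kernel code). The clamping of out-of-range indices to the last output is the same behavior SwitchNOp::Compute implements.

def switch_n(data, output_index, num_outs):
  """Model of _SwitchN forwarding: returns (chosen index, list of outputs)."""
  if output_index < 0 or output_index >= num_outs:
    output_index = num_outs - 1  # out-of-range indices go to the last output
  outputs = [None] * num_outs
  outputs[output_index] = data   # only the selected output is produced
  return output_index, outputs

assert switch_n("x", 2, 5)[0] == 2
assert switch_n("x", -1, 5)[0] == 4  # <0  -> default (last) output
assert switch_n("x", 9, 5)[0] == 4   # >=N -> default (last) output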
@@ -24,6 +24,7 @@ using shape_inference::ShapeHandle;
 
 // --------------------------------------------------------------------------
 namespace {
+
 Status SwitchShape(InferenceContext* c) {
   ShapeHandle unused;
   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
@@ -39,6 +40,27 @@ Status SwitchShape(InferenceContext* c) {
   }
   return Status::OK();
 }
 
+Status SwitchNShape(InferenceContext* c) {
+  ShapeHandle unused;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+  ShapeHandle out = c->input(0);
+  int num_outs;
+  TF_RETURN_IF_ERROR(c->GetAttr("num_outs", &num_outs));
+  for (int i = 0; i < num_outs; i++) {
+    c->set_output(i, out);
+  }
+
+  // Handle resource shape / dtype.
+  auto* handle_data = c->input_handle_shapes_and_types(0);
+  if (handle_data != nullptr) {
+    for (int i = 0; i < num_outs; i++) {
+      c->set_output_handle_shapes_and_types(i, *handle_data);
+    }
+  }
+  return Status::OK();
+}
+
 }  // namespace
 
 REGISTER_OP("Switch")
@@ -58,6 +80,14 @@ REGISTER_OP("RefSwitch")
     .SetAllowsUninitializedInput()
     .SetShapeFn(SwitchShape);
 
+REGISTER_OP("_SwitchN")
+    .Input("data: T")
+    .Input("output_index: int32")
+    .Output("outputs: num_outs * T")
+    .Attr("num_outs: int >= 1")
+    .Attr("T: type")
+    .SetShapeFn(SwitchNShape);
+
 // --------------------------------------------------------------------------
 REGISTER_OP("RefSelect")
     .Input("index: int32")
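
For orientation, a minimal Python sketch of what the SwitchNShape function above computes (the function name here is illustrative, not part of the TensorFlow API): the output_index input must be a scalar, and each of the num_outs outputs simply inherits the data input's shape, exactly as exercised by the shape-function test below.

def switch_n_shape(input_shape, index_shape, num_outs):
  """Python model of SwitchNShape."""
  if len(index_shape) != 0:
    raise ValueError("Shape must be rank 0 but is rank %d" % len(index_shape))
  return [input_shape] * num_outs  # every output gets the data input's shape

# Mirrors the INFER_OK cases in the shape-function test below:
assert switch_n_shape([2, 3], [], num_outs=5) == [[2, 3]] * 5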
@@ -49,6 +49,27 @@ TEST(ControlFlowOpsTest, Merge_ShapeFn) {
   INFER_OK(op, "[2,1];[2,1];[2,1]", "in0;[]");
 }
 
+TEST(ControlFlowOpsTest, SwitchN_ShapeFn) {
+  ShapeInferenceTestOp op("_SwitchN");
+
+  int n = 5;
+  TF_ASSERT_OK(NodeDefBuilder("test", "_SwitchN")
+                   .Input({"d", 0, DT_FLOAT})
+                   .Input({"bi", 0, DT_INT32})
+                   .Attr("num_outs", n)
+                   .Finalize(&op.node_def));
+
+  // Non-scalar output_index.
+  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[2]");
+  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1]");
+  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[?]");
+  // The second input should always be scalar. Outputs are 5x the first input.
+  INFER_OK(op, "?;?", "in0;in0;in0;in0;in0");
+  INFER_OK(op, "[2,?];?", "in0;in0;in0;in0;in0");
+  INFER_OK(op, "[2,?];[]", "in0;in0;in0;in0;in0");
+  INFER_OK(op, "[2,3];[]", "in0;in0;in0;in0;in0");
+}
+
 TEST(ControlFlowOpsTest, RefSelect_ShapeFn) {
   ShapeInferenceTestOp op("RefSelect");
 
@@ -1177,6 +1177,84 @@ class FunctionalOpsCaseTest(test.TestCase):
     self.assertAllEqual(np.float32(4), self.evaluate(f(-1, one)))  # <0 default
     self.assertAllEqual(np.float32(4), self.evaluate(f(6, one)))  # >=N default
 
+  @test_util.run_deprecated_v1
+  @test_util.disable_xla("Don't lower for XLA")
+  def testSkipEagerCaseLoweringPreservesNameForFetch(self):
+    for use_gpu in (True, False):
+      def Run(branch, x, fetch_by_name, use_gpu=use_gpu):
+        with ops.Graph().as_default() as g:
+          @function.Defun(dtypes.float32)
+          def two(x):
+            return -1, x * 2
+
+          @function.Defun(dtypes.float32)
+          def three(x):
+            return 0, x * 3
+
+          @function.Defun(dtypes.float32)
+          def four(x):
+            return 1, x * 4
+
+          outputs = gen_functional_ops.case(branch, input=[x],
+                                            Tout=[dtypes.int32, dtypes.float32],
+                                            branches=[two, three, four],
+                                            name="my_case")
+
+          # `outputs` is the list of output tensors of the Case op. We
+          # arbitrarily choose the 0th tensor to get the Case op and set the
+          # lowering attribute on it.
+          outputs[0].op._set_attr("_lower_using_switch_merge",
+                                  attr_value_pb2.AttrValue(b=True))
+          outputs = array_ops.identity_n(outputs)
+          with self.session(graph=g, use_gpu=use_gpu) as sess:
+            return sess.run("my_case:1" if fetch_by_name else outputs[1])
+
+      self.assertAllEqual(2 * 1., Run(0, 1., False))
+      self.assertAllEqual(2 * 1., Run(0, 1., True))
+      self.assertAllEqual(3 * 7., Run(1, 7., False))
+      self.assertAllEqual(3 * 7., Run(1, 7., True))
+      self.assertAllEqual(4 * -3., Run(2, -3., False))
+      self.assertAllEqual(4 * -3., Run(2, -3., True))
+      self.assertAllEqual(4 * -4., Run(7, -4., False))  # >= N default
+      self.assertAllEqual(4 * -4., Run(7, -4., True))  # >= N default
+      self.assertAllEqual(4 * -5., Run(-1, -5., False))  # <0 default
+      self.assertAllEqual(4 * -5., Run(-1, -5., True))  # <0 default
+
+  @test_util.disable_xla("Don't lower for XLA")
+  def testCaseLowering(self):
+    for use_gpu in (True, False):
+      @eager_function.defun
+      def Run(branch, x):
+        @function.Defun(dtypes.float32)
+        def two(x):
+          return -1, x * 2
+
+        @function.Defun(dtypes.float32)
+        def three(x):
+          return 0, x * 3
+
+        @function.Defun(dtypes.float32)
+        def four(x):
+          return 1, x * 4
+
+        outputs = gen_functional_ops.case(branch, input=[x],
+                                          Tout=[dtypes.int32, dtypes.float32],
+                                          branches=[two, three, four])
+
+        # `outputs` is the list of output tensors of the Case op. We
+        # arbitrarily choose the 0th tensor to get the Case op and set the
+        # lowering attribute on it.
+        outputs[0].op._set_attr("_lower_using_switch_merge",
+                                attr_value_pb2.AttrValue(b=True))
+        outputs = array_ops.identity_n(outputs)
+        return outputs[1]
+
+      with ops.device(test.gpu_device_name() if use_gpu else "CPU:0"):
+        self.assertAllEqual(2 * 1., self.evaluate(Run(0, 1.)))
+        self.assertAllEqual(3 * 7., self.evaluate(Run(1, 7.)))
+        self.assertAllEqual(4 * -3., self.evaluate(Run(2, -3.)))
+        self.assertAllEqual(4 * -4., self.evaluate(Run(7, -4.)))  # >=N default
+        self.assertAllEqual(4 * -5., self.evaluate(Run(-1, -5.)))  # <0 default
+
 if __name__ == "__main__":
   test.main()
@@ -510,7 +510,7 @@ def _make_output_composite_tensors_match(op_type, branch_graphs):
       raise TypeError(
           "Cannot reconcile {op_name} {output_idx}-th outputs:\n"
           "  outputs from all branches: {outputs}".format(
-              op_name="tf.cond" if op_type == _COND else "tf.case",
+              op_name="tf.cond" if op_type == _COND else "tf.switch_case",
               output_idx=output_idx,
               outputs=branch_outs))
 
@@ -534,9 +534,9 @@ def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
   for output_idx, branch_outs in enumerate(
       zip(*branch_outputs_flat_with_composites)):
     if len(set(isinstance(out, ops.IndexedSlices) for out in branch_outs)) != 1:
-      raise TypeError("Cannot reconcile {op_name} {output_idx}-th outputs:\n"
+      raise TypeError("Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n"
                       "  branches returned: {outputs}".format(
-                          op_name="tf.cond" if op_type == _COND else "tf.case",
+                          op_name="cond" if op_type == _COND else "switch_case",
                           output_idx=output_idx,
                           outputs=branch_outs))
     if isinstance(branch_outs[0], ops.IndexedSlices):
@@ -632,7 +632,7 @@ def _check_same_outputs(op_type, graphs):
             b0_name="true_fn" if op_type == _COND else "branches[0]",
             bn_name=("false_fn" if op_type == _COND else
                      "branches[{}]".format(branch_idx)),
-            op_name="tf.cond" if op_type == _COND else "tf.case",
+            op_name="tf.cond" if op_type == _COND else "tf.switch_case",
             b0_out=graphs[0].structured_outputs,
             bn_out=graphs[branch_idx].structured_outputs,
             detail=error_detail))
@@ -817,8 +817,8 @@ def indexed_case(branch_index, branch_fns, name="indexed_case"):
 
 @ops.RegisterGradient("Case")
 def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
-  """The gradient of a Case op produced (w/ branch_index) by tf.case."""
-  # Get the if operator (this logic handles the case where op is a MockOp)
+  """The gradient of a Case op produced by tf.switch_case."""
+  # Get the Case operator (this logic handles the case where op is a MockOp)
   case_op = op.outputs[0].op
   branch_graphs = _get_func_graphs(case_op)
   assert branch_graphs
@@ -892,7 +892,7 @@ def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
   _make_output_composite_tensors_match(_CASE, branch_grad_graphs)
 
   outputs = _build_case(case_op.inputs[0], branch_grad_graphs,
-                        branches_grad_inputs)
+                        branches_grad_inputs, name="gradient")
 
   # The predicate has no gradient.
   return [None] + outputs
@@ -935,14 +935,13 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
       output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
       name=name)
 
-  # TODO(b/110167197) this requires Case to have at least 1 output
+  # TODO(b/110167197): this requires Case to have at least 1 output
   case_op = tensors[0].op
-  # TODO(b/131304144): Enable lowering Case to SwitchN/Merge for graph mode.
-  # util.maybe_set_lowering_attr(case_op)
+  util.maybe_set_lowering_attr(case_op)
   util.maybe_propagate_compile_time_consts_in_xla(case_op)
 
   # Return identities for each output of the Case op, rather than the output of
-  # the Case op directly. This makes pruning work if the output of select_case()
+  # the Case op directly. This makes pruning work if the output of switch_case()
   # is fetched: the lowering pass converts the Case outputs into IdentityN
   # outputs, which if fetched will cause all ops in the taken branch to be run
   # (since it takes all merge ops as input). After lowering, each output
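
For context, a minimal usage sketch (assuming the same internal API the tests in this change use; not an exhaustive recipe): switch_case builds a single Case op, and with maybe_set_lowering_attr now enabled above, the graph-mode lowering pass can rewrite that Case into _SwitchN plus Merge nodes.

from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops

def fn(k):
  return lambda: constant_op.constant(k * 10.)

# One Case op is created; with the lowering attr set above it can be
# rewritten to _SwitchN + Merge by the lowering pass in graph mode.
out = control_flow_ops.switch_case(
    constant_op.constant(2), {i: fn(i) for i in range(3)})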
@@ -19,11 +19,14 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+from absl.testing import parameterized
 import numpy as np
 
-from tensorflow.python import tf2
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import node_def_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import tf2
+from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -39,7 +42,9 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variable_scope
@@ -933,7 +938,10 @@ class DataTypesTest(test_util.TensorFlowTestCase):
 
 
 @test_util.run_all_in_graph_and_eager_modes
-class IndexedCaseTest(test_util.TensorFlowTestCase):
+class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase):
+
+  def make_name(self):
+    return self.id().split(".")[-1].replace("(", "_").replace(")", "")
 
   def disabled_testCase_ticklesGpuVsHostMemoryIssueWithInt32(self):
     nbranches = 5
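
The test refactor below relies on absl's parameterized test machinery; as a reminder of the pattern (a generic sketch, not specific to this TensorFlow change), each tuple passed to @parameterized.parameters becomes its own test case, which is what replaces the old in-test `for bi in ...` loops.

from absl.testing import parameterized

class SquareTest(parameterized.TestCase):

  @parameterized.parameters((0,), (2,), (3,))
  def testSquare(self, bi):
    # Runs once per parameter tuple, each as a separately named test.
    self.assertEqual(bi * bi, bi ** 2)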
@@ -947,53 +955,55 @@ class IndexedCaseTest(test_util.TensorFlowTestCase):
       case_out = control_flow_ops.switch_case(branch_index, branches)
       self.assertEqual(bi * 10, self.evaluate(case_out))
 
-  def testCase(self):
+  @parameterized.parameters((0,), (2,), (3,))
+  def testCase(self, bi):
     nbranches = 5
 
     def make_func(bi):
       return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
 
     branches = [(i, make_func(i)) for i in range(nbranches)]
-    for bi in 0, 2, 3:
-      branch_index = array_ops.placeholder_with_default(bi, [])
-      case_out = control_flow_ops.switch_case(branch_index, branches)
-      self.assertEqual(bi * 10., self.evaluate(case_out))
+    branch_index = array_ops.placeholder_with_default(bi, [])
+    case_out = control_flow_ops.switch_case(
+        branch_index, branches, name=self.make_name())
+    self.assertEqual(bi * 10., self.evaluate(case_out))
 
-  def testCase_withDefault(self):
+  @parameterized.parameters((-1,), (2,), (4,), (5,), (6,))
+  def testCase_withDefault(self, bi):
     nbranches = 5
 
     def make_func(bi):
       return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
 
     branches = [(i, make_func(i)) for i in range(nbranches)]
-    for bi in -1, 2, nbranches:
-      branch_index = array_ops.placeholder_with_default(bi, [])
-      case_out = control_flow_ops.switch_case(
-          branch_index, branches, default=make_func(6))
-      if bi < 0 or bi >= nbranches:
-        expected = 60.
-      else:
-        expected = bi * 10.
-      self.assertEqual(expected, self.evaluate(case_out))
+    branch_index = array_ops.placeholder_with_default(bi, [])
+    case_out = control_flow_ops.switch_case(
+        branch_index, branches, default=make_func(6), name=self.make_name())
+    if bi < 0 or bi >= nbranches:
+      expected = 60.
+    else:
+      expected = bi * 10.
+    self.assertEqual(expected, self.evaluate(case_out))
 
-  def testCase_dictWithDefault(self):
+  @parameterized.parameters((-1,), (0,), (3,), (5,))
+  def testCase_dictWithDefault(self, bi):
     nbranches = 5
 
    def make_func(bi):
      return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
 
     branches = {i: make_func(i) for i in range(nbranches)}
-    for bi in -1, 0, 3, nbranches:
-      branch_index = array_ops.placeholder_with_default(bi, [])
-      case_out = control_flow_ops.switch_case(
-          branch_index, branches, default=make_func(6))
-      if bi < 0 or bi >= nbranches:
-        expected = 60.
-      else:
-        expected = bi * 10.
-      self.assertEqual(expected, self.evaluate(case_out))
+    branch_index = array_ops.placeholder_with_default(bi, [])
+    case_out = control_flow_ops.switch_case(
+        branch_index, branches, default=make_func(6), name=self.make_name())
+    if bi < 0 or bi >= nbranches:
+      expected = 60.
+    else:
+      expected = bi * 10.
+    self.assertEqual(expected, self.evaluate(case_out))
 
-  def testCase_gradient(self):
+  @parameterized.parameters((-1,), (1,), (4,), (5,))
+  def testCase_gradient(self, bi):
     nbranches = 5
     inputs = [
         array_ops.constant(float(bi), name="br{}_in".format(bi))
@@ -1005,22 +1015,22 @@ class IndexedCaseTest(test_util.TensorFlowTestCase):
 
     branches = {bi: make_func(bi) for bi in range(nbranches)}
 
-    for bi in -1, 1, 4, nbranches:
-      branch_index = array_ops.placeholder_with_default(bi, [])
-      with backprop.GradientTape() as tape:
-        for x in inputs:
-          tape.watch(x)
-        case_out = control_flow_ops.switch_case(branch_index, branches)
-      out_grad = 3.
-      actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)
-      expected_grads = [None if context.executing_eagerly() else 0.] * nbranches
-      used_branch_idx = nbranches - 1 if bi < 0 or bi >= nbranches - 1 else bi
-      expected_grads[used_branch_idx] = out_grad * 2. * used_branch_idx
-      self.assertEqual(len(expected_grads), len(actual_grads))
-      for expected, actual in zip(expected_grads, actual_grads):
-        self.assertEqual(expected, self.evaluate(actual))
+    branch_index = array_ops.placeholder_with_default(bi, [])
+    with backprop.GradientTape() as tape:
+      for x in inputs:
+        tape.watch(x)
+      case_out = control_flow_ops.switch_case(branch_index, branches)
+    out_grad = 3.
+    actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)
+    expected_grads = [None if context.executing_eagerly() else 0.] * nbranches
+    used_branch_idx = nbranches - 1 if bi < 0 or bi >= nbranches - 1 else bi
+    expected_grads[used_branch_idx] = out_grad * 2. * used_branch_idx
+    self.assertEqual(len(expected_grads), len(actual_grads))
+    for expected, actual in zip(expected_grads, actual_grads):
+      self.assertEqual(expected, self.evaluate(actual))
 
-  def testCase_gradient_diffShapedIntermediates(self):
+  @parameterized.parameters((-2,), (2,), (5,))
+  def testCase_gradient_diffShapedIntermediates(self, bi):
     nbranches = 5
     inputs = [
         array_ops.constant(
@@ -1038,34 +1048,105 @@ class IndexedCaseTest(test_util.TensorFlowTestCase):
 
     branches = {bi: make_func(bi) for bi in range(nbranches)}
 
-    for bi in -1, 2, nbranches:
-      branch_index = array_ops.placeholder_with_default(bi, [])
-      with backprop.GradientTape() as tape:
-        for x in inputs:
-          tape.watch(x)
-        case_out = control_flow_ops.switch_case(branch_index, branches)
+    branch_index = array_ops.placeholder_with_default(bi, [])
+    with backprop.GradientTape() as tape:
+      for x in inputs:
+        tape.watch(x)
+      case_out = control_flow_ops.switch_case(
+          branch_index, branches, name=self.make_name())
     out_grad = 3.
     actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)
     used_bi = (nbranches - 1) if (bi < 0 or bi >= nbranches - 1) else bi
     expected_grads = []
     for input_idx in range(nbranches):
       if used_bi == input_idx:
         with backprop.GradientTape() as tape:
           tape.watch(inputs[used_bi])
           y = make_func(used_bi)()
         expected_grads.append(
             self.evaluate(
                 tape.gradient(y, inputs[used_bi], output_gradients=out_grad)))
       else:
         expected_grads.append(None if context.executing_eagerly() else [0.] *
                               (input_idx + 1))
 
     self.assertEqual(len(expected_grads), len(actual_grads))
     for expected, actual in zip(expected_grads, actual_grads):
       if expected is None:
         self.assertIsNone(actual)
       else:
         self.assertAllEqual(expected, self.evaluate(actual))
 
+  @test_util.run_gpu_only
+  @test_util.disable_xla("Wants RunMetadata")
+  def testParallelExecution(self):
+    """Verify disjoint branches across while iterations are run in parallel."""
+    with ops.Graph().as_default() as g:
+      nbranches = 7
+      matrices = array_ops.unstack(  # Ensure all are ready before while.
+          array_ops.matrix_diag(
+              random_ops.random_uniform([nbranches, 8, 512]) + 1e-3))
+
+      def make_branch(i, mat, name):
+        def branch_fn():
+          next_i = i + 1
+          with ops.device("gpu:0"):
+            return next_i, math_ops.reduce_sum(
+                linalg_ops.cholesky(mat, name=name + "_Cholesky"))
+        return branch_fn
+
+      def make_branches(i):
+        return [make_branch(i, matrices[bi], "br{}".format(bi))
+                for bi in range(nbranches)]
+
+      def cond(i, _):
+        return i < nbranches
+
+      def body(i, result):
+        with ops.device("cpu:0"):
+          next_i, branch_out = control_flow_ops.switch_case(i, make_branches(i))
+        return next_i, result + branch_out
+
+      _, result = control_flow_ops.while_loop(cond, body, [0, 0.])
+
+      run_metadata = config_pb2.RunMetadata()
+      run_options = config_pb2.RunOptions(
+          trace_level=config_pb2.RunOptions.FULL_TRACE)
+      config = config_pb2.ConfigProto(
+          allow_soft_placement=False, log_device_placement=True)
+
+      with session.Session(config=config, graph=g) as sess:
+        _ = sess.run(result, options=run_options, run_metadata=run_metadata)
+      chol_node_stats = []
+      for dev_stats in run_metadata.step_stats.dev_stats:
+        for node_stats in dev_stats.node_stats:
+          if (node_stats.node_name.endswith("Cholesky") and
+              node_stats.all_start_nanos > 0):
+            chol_node_stats.append(node_stats)
+
+      self.assertLen(chol_node_stats, nbranches)
+
+      chol_node_stats = sorted(chol_node_stats, key=lambda stats: stats.node_name)
+      op_start_nanos = [
+          stats.all_start_nanos for stats in chol_node_stats
+      ]
+      op_end_nanos = [
+          stats.all_start_nanos + stats.op_end_rel_nanos
+          for stats in chol_node_stats
+      ]
+
+      def overlap(range1, range2):
+        s1, e1 = range1
+        s2, e2 = range2
+        if s1 < s2:
+          return 0 if s2 > e1 else e1 - s2
+        return 0 if s1 > e2 else e2 - s1
+
+      timespans = list(zip(op_start_nanos, op_end_nanos))
+      overlaps_chol0 = [overlap(timespans[0], r2) for r2 in timespans[1:]]
+      # There are nbranches-1 overlaps, sometimes all nonzero, but we
+      # conservatively check for at least one here, to avoid test flakiness.
+      self.assertGreater(np.count_nonzero(overlaps_chol0), 0)
+
   def testCase_validateIndicesContiguous(self):
 