diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 1f23c0880db..e9a5257cac3 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -846,24 +846,42 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n, const Edge* pred_edge; TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge)); - Predicate* true_switch; - TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( - pred_edge->src(), pred_edge->src_output(), - /*must_be_true=*/true, &true_switch)); + if (n->num_outputs() == 2) { + Predicate* true_switch; + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + pred_edge->src(), pred_edge->src_output(), + /*must_be_true=*/true, &true_switch)); - Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch); + Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch); - // Output 0 is alive iff all inputs are alive and the condition is false. - input_preds.push_back(false_switch); - SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds), - should_revisit); - input_preds.pop_back(); + // Output 0 is alive iff all inputs are alive and the condition is false. + input_preds.push_back(false_switch); + SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); + input_preds.pop_back(); - // Output 1 is alive iff all inputs are alive and the condition is true. - input_preds.push_back(true_switch); - SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds), - should_revisit); - input_preds.pop_back(); + // Output 1 is alive iff all inputs are alive and the condition is true. + input_preds.push_back(true_switch); + SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); + input_preds.pop_back(); + } else { // N-way switch case. Exactly one of N branches is alive. + Predicate* branch_pred; + for (int i = 0; i < n->num_outputs() - 1; i++) { + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + n, i, /*must_be_true=*/false, &branch_pred)); + input_preds.push_back(branch_pred); + SetPredicate(n, i, predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); + input_preds.pop_back(); + input_preds.push_back(predicate_factory_.MakeNotPredicate(branch_pred)); + } + // The default (last) branch does not need its own symbol, is simply the + // nor of all other branches. + SetPredicate(n, n->num_outputs() - 1, + predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); + } // Control is alive iff all inputs are alive. SetPredicate(n, Graph::kControlSlot, diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index ffba9e040a1..74bb6266781 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -3171,6 +3171,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/local_device.h", "common_runtime/lower_function_call_op.h", "common_runtime/lower_if_op.h", + "common_runtime/lower_case_op.h", "common_runtime/lower_functional_ops.h", "common_runtime/lower_while_op.h", "common_runtime/memory_types.h", @@ -3232,6 +3233,7 @@ tf_cuda_library( "common_runtime/inspecting_placer.h", "common_runtime/isolate_placer_inspection_required_ops_pass.cc", "common_runtime/local_device.cc", + "common_runtime/lower_case_op.cc", "common_runtime/lower_function_call_op.cc", "common_runtime/lower_functional_ops.cc", "common_runtime/lower_if_op.cc", @@ -5326,6 +5328,30 @@ tf_cc_tests( ], ) +tf_cc_tests( + name = "common_runtime_lower_case_op_test", + size = "small", + srcs = ["common_runtime/lower_case_op_test.cc"], + deps = [ + ":all_kernels", + ":core_cpu", + ":core_cpu_internal", + ":direct_session", + ":framework", + ":framework_internal", + ":lib", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:client_session", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", + ], +) + tf_cc_tests( name = "common_runtime_lower_while_op_test", size = "small", diff --git a/tensorflow/core/common_runtime/lower_case_op.cc b/tensorflow/core/common_runtime/lower_case_op.cc new file mode 100644 index 00000000000..f85dc14231d --- /dev/null +++ b/tensorflow/core/common_runtime/lower_case_op.cc @@ -0,0 +1,300 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/lower_case_op.h" + +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/lower_functional_ops.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +namespace { + +using NodeOut = NodeBuilder::NodeOut; + +constexpr const char* const kLowerAsMultiDeviceFunctionAttr = + LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr; + +// Convenience builder to make it easy to construct a case with a single +// function call in each branch. This first converts the Case node +// into switches (for inputs) and merges (for outputs) around a function call +// per branch. +class CaseBuilder { + public: + // Create a CaseBuilder to create the lowered form of `case` with branch + // functions identified by `branch_fn_names` in the `graph`. The functions + // should be available in `flib`. + CaseBuilder(Node* case_op, const std::vector& branch_fn_names, + const FunctionLibraryDefinition& flib, bool keep_node_fetchable, + Graph* graph); + + // Constructs the basic conditional control flow using switch and merge nodes. + Status CreatePivotNodes(); + + // Adds the inputs from the if node to the merge nodes of the lowered if. + Status AddInputs(); + + // Adds the outputs from the if node to the merge nodes of the lowered if. + // Note: no inputs can be added once outputs are added as the then and else + // nodes are finalized while adding outputs. + Status AddOutputs(); + + // Builds an identity node with the same outputs as Case. + Status BuildLoweredCaseOutput(); + + private: + // Returns unique name containing the name of the Case op being rewritten + // (name_), infix and a suffix to ensure it is unique within the graph. + string NewName(const string& infix); + + // Adds input to both the then and else nodes from src:src_output. + Status AddInput(Node* src, int src_output); + + // The merged outputs of the then and else nodes. + std::vector outputs_; + + // The node that dominates all execution of the then and else body nodes. + Node* control_predecessor_; + // The original Case op. + Node* case_op_; + // The node with the same name as the original Case op: + // (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true' + // and if the original Case op had non-zero data outputs. + // (b) NoOp node with control edge from 'branch_executed_node_' otherwise. + Node* lowered_case_output_; + // The branch selector of the case. + OutputTensor branch_index_; + int num_branches_; + // Nodes corresponding to pivot branch of branch_index _SwitchN, which is + // the pivot node that dominates all nodes in the i'th branch. + std::vector pivots_; + std::vector call_nodes_; + // Merge node that has inputs from each of pivots_ and control edges from + // [^call_node for call_node in call_nodes_]. This node will guarantee that + // even when branch functions do not have outputs, they still will be executed + // for the side effects. + Node* branch_executed_node_; + Graph* graph_; + const FunctionLibraryDefinition& flib_; + string name_; + bool keep_node_fetchable_; + + NodeDebugInfo debug_info_; + std::vector branch_call_builders_; +}; + +CaseBuilder::CaseBuilder(Node* case_op, + const std::vector& branch_fn_names, + const FunctionLibraryDefinition& flib, + bool keep_node_fetchable, Graph* graph) + : case_op_(case_op), + num_branches_(branch_fn_names.size()), + graph_(graph), + flib_(flib), + name_(case_op->name()), + keep_node_fetchable_(keep_node_fetchable), + debug_info_(*case_op_) { + branch_call_builders_.reserve(num_branches_); + for (int b = 0; b < num_branches_; b++) { + branch_call_builders_.emplace_back(NewName(strings::StrCat("branch", b)), + branch_fn_names[b], graph->op_registry(), + &debug_info_); + branch_call_builders_[b].Device(case_op_->requested_device()); + branch_call_builders_[b].Attr(kLowerAsMultiDeviceFunctionAttr, true); + } + TF_CHECK_OK(case_op_->input_tensor(0, &branch_index_)); +} + +Status CaseBuilder::CreatePivotNodes() { + // Construct the basic case body (consisting of feeding in the val to + // create pivot nodes). + Node* branch_index; + TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_index"), "_SwitchN", + graph_->op_registry(), &debug_info_) + .Input(NodeOut(branch_index_)) + .Input(NodeOut(branch_index_)) + .Attr("num_outs", num_branches_) + .Device(case_op_->requested_device()) + .Finalize(graph_, &branch_index)); + control_predecessor_ = branch_index; + pivots_.resize(num_branches_, nullptr); + for (int b = 0; b < num_branches_; b++) { + TF_RETURN_IF_ERROR(NodeBuilder(NewName(strings::StrCat("pivot_", b)), + "Identity", graph_->op_registry(), + &debug_info_) + .Input(branch_index, b) + .Device(case_op_->requested_device()) + .Finalize(graph_, &pivots_[b])); + } + return Status::OK(); +} + +string CaseBuilder::NewName(const string& infix) { + return graph_->NewName(strings::StrCat(name_, "/", infix)); +} + +Status CaseBuilder::AddInput(Node* src, int src_output) { + Node* input; + NodeDebugInfo debug_info(*src); + // Colocate the Switch node with the `src` node. + // + // This is to avoid unnecessary Host<->Device copies between src and the + // _SwitchN node. This aligns with the implementation of legacy tf.cond in + // control_flow_ops.py. The legacy impl colocates the Switch with the + // input tensor which resets the device stack and forces the Switch to have + // the same device as the input node (if set) and sets the colocation _class + // attr. It also ignores the existing colocation constraints on the input node + // using colocate_with(ignore_existing=True). + TF_RETURN_IF_ERROR(NodeBuilder(NewName(src->name()), "_SwitchN", + graph_->op_registry(), &debug_info) + .Input(src, src_output) + .Input(branch_index_) + .Device(src->requested_device()) + .Attr("_class", {src->name()}) + .Attr("num_outs", num_branches_) + .Finalize(graph_, &input)); + for (int b = 0; b < num_branches_; b++) { + branch_call_builders_[b].Input(input, b); + } + return Status::OK(); +} + +Status CaseBuilder::AddInputs() { + // Add input data edges. + std::vector edges; + TF_RETURN_IF_ERROR(case_op_->input_edges(&edges)); + // Start at index 1 as the first input is the branch index. + for (int i = 1; i < edges.size(); ++i) { + const Edge* e = edges[i]; + TF_RETURN_IF_ERROR(AddInput(e->src(), e->src_output())); + } + // Add input control edges. + for (const Edge* e : case_op_->in_edges()) { + if (e->IsControlEdge()) { + graph_->AddControlEdge(e->src(), control_predecessor_); + } + } + return Status::OK(); +} + +Status CaseBuilder::AddOutputs() { + // Construct the call nodes for each branch. + call_nodes_.resize(num_branches_, nullptr); + for (int b = 0; b < num_branches_; b++) { + TF_RETURN_IF_ERROR( + branch_call_builders_[b].Finalize(graph_, &call_nodes_[b])); + graph_->AddControlEdge(pivots_[b], call_nodes_[b]); + } + + // Merge the outputs from the N branches (all branches have matching outputs). + const int num_outputs = call_nodes_[0]->num_outputs(); + std::vector merges(num_outputs); + outputs_.resize(merges.size()); + for (int i = 0; i < num_outputs; ++i) { + std::vector merge_input; + merge_input.reserve(num_branches_); + for (int j = 0; j < num_branches_; j++) { + merge_input.emplace_back(call_nodes_[j], i); + } + TF_RETURN_IF_ERROR(NodeBuilder(NewName("merge"), "Merge", + graph_->op_registry(), &debug_info_) + .Input(merge_input) + .Device(case_op_->requested_device()) + .Finalize(graph_, &merges[i])); + outputs_[i] = NodeOut(merges[i], 0); + } + + // Add a Merge node that will be used as a control dependency source for the + // lowered output node. This Merge node will guarantee that lowered else/then + // function calls will be executed even if they do not have data outputs. + // + // Furthermore it will guarantee that all function side effects will be + // executed, if the function will be inlined into the graph. Having data + // outputs is not enough, because they might become unused after inlining. + // + // We will use this node to rewrite outgoing control edges from lowered 'Case' + // node. All data edges will read tensors directly from Merge nodes. + std::vector pivots(num_branches_); + for (int j = 0; j < num_branches_; j++) { + pivots[j] = NodeOut(pivots_[j]); + } + TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_executed"), "Merge", + graph_->op_registry(), &debug_info_) + .Input(pivots) + .ControlInputs(call_nodes_) + .Device(case_op_->requested_device()) + .Finalize(graph_, &branch_executed_node_)); + + TF_RETURN_IF_ERROR(BuildLoweredCaseOutput()); + + // Add outputs. + for (const Edge* e : case_op_->out_edges()) { + if (e->IsControlEdge()) { + graph_->AddControlEdge(branch_executed_node_, e->dst()); + } else { + // Feed the outputs directly from the merge nodes so that downstream ops + // can start before all the outputs have been computed. + graph_->AddEdge(merges[e->src_output()], 0, e->dst(), e->dst_input()); + } + } + return Status::OK(); +} + +Status CaseBuilder::BuildLoweredCaseOutput() { + // If outputs are empty, it means that we might have only output control + // edges (already connected to the `branch_executed_node`). Furthermore it's + // illegal to have an IdentityN with empty inputs. + // + // We still must keep lowered Case node as a valid source of control edges, + // because it might be a part of function control output set. + NodeBuilder builder = keep_node_fetchable_ && !outputs_.empty() + ? NodeBuilder(name_, "IdentityN").Input(outputs_) + : NodeBuilder(name_, "NoOp"); + return builder.Device(case_op_->requested_device()) + .ControlInput(branch_executed_node_) + .Finalize(graph_, &lowered_case_output_); +} + +} // namespace + +Status RewriteCaseNode(Node* n, Graph* g, const FunctionLibraryDefinition& flib, + bool keep_node_fetchable) { + VLOG(2) << "Lower Case node (keep_node_fetchable=" << keep_node_fetchable + << "): " << SummarizeNode(*n); + const AttrValue* branches_attr = n->attrs().Find("branches"); + if (branches_attr == nullptr) { + return errors::InvalidArgument("branch functions missing"); + } + + int num_branches = branches_attr->list().func_size(); + std::vector branch_fn_names; + branch_fn_names.reserve(num_branches); + for (int b = 0; b < num_branches; b++) { + branch_fn_names.emplace_back(branches_attr->list().func(b).name()); + } + CaseBuilder cb(n, branch_fn_names, flib, keep_node_fetchable, g); + TF_RETURN_IF_ERROR(cb.CreatePivotNodes()); + TF_RETURN_IF_ERROR(cb.AddInputs()); + TF_RETURN_IF_ERROR(cb.AddOutputs()); + g->RemoveNode(n); + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/lower_case_op.h b/tensorflow/core/common_runtime/lower_case_op.h new file mode 100644 index 00000000000..fc46a1f34b6 --- /dev/null +++ b/tensorflow/core/common_runtime/lower_case_op.h @@ -0,0 +1,30 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Replaces Case node `n` with a lowered form that uses _SwitchN/Merge nodes. +Status RewriteCaseNode(Node* n, Graph* g, const FunctionLibraryDefinition& flib, + bool keep_node_fetchable); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_ diff --git a/tensorflow/core/common_runtime/lower_case_op_test.cc b/tensorflow/core/common_runtime/lower_case_op_test.cc new file mode 100644 index 00000000000..ce34a21f0ca --- /dev/null +++ b/tensorflow/core/common_runtime/lower_case_op_test.cc @@ -0,0 +1,441 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/lower_functional_ops.h" + +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +AttrValue FuncListAttr(const absl::Span names) { + AttrValue attr; + for (const char* name : names) { + attr.mutable_list()->add_func()->set_name(name); + } + return attr; +} + +SessionOptions SessionOptionsWithInlining() { + SessionOptions session_options; + session_options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_do_function_inlining(true); + return session_options; +} + +Status Rewrite(std::unique_ptr* graph) { + FunctionLibraryDefinition flib_def((*graph)->flib_def()); + GraphOptimizationPassOptions opt_options; + SessionOptions session_options = SessionOptionsWithInlining(); + opt_options.session_options = &session_options; + opt_options.graph = graph; + opt_options.flib_def = &flib_def; + LowerFunctionalOpsPass pass; + return pass.Run(opt_options); +} + +TEST(LowerCaseOpTest, Simple) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + // Add test functions for then and else branch. + FunctionDefLibrary f_lib_proto; + *(f_lib_proto.add_function()) = test::function::XTimesTwo(); + *(f_lib_proto.add_function()) = test::function::XTimesFour(); + *(f_lib_proto.add_function()) = test::function::XTimes16(); + + // Construct simple conditional that switches on `pred` and operates only on + // single input `A`. + Scope root = Scope::NewRootScope().ExitOnError(); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto)); + auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32); + auto branch_index = + ops::Placeholder(root.WithOpName("branch_index"), DT_INT32); + Node* written_if; + std::vector inputs({NodeBuilder::NodeOut(a.node())}); + TF_ASSERT_OK( + NodeBuilder("case", "Case", &root.graph()->flib_def()) + .Input(branch_index.node()) + .Input(inputs) + .Attr("branches", + FuncListAttr({"XTimesTwo", "XTimesFour", "XTimes16"})) + .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true) + .Attr("Tout", {DT_INT32}) + .Finalize(root.graph(), &written_if)); + TF_ASSERT_OK(root.DoShapeInference(written_if)); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + // The input graph has no switch or merge nodes. + int node_called_case_count = 0; + for (const auto* op : graph->op_nodes()) { + ASSERT_FALSE(op->IsSwitch()); + ASSERT_FALSE(op->IsMerge()); + if (op->name() == "case") { + ++node_called_case_count; + } + } + ASSERT_EQ(node_called_case_count, 1); + + TF_ASSERT_OK(Rewrite(&graph)); + + // Verify the resultant graph has switch and merge nodes, and a node called + // `if` (but not If nodes). + int switch_count = 0; + int merge_count = 0; + node_called_case_count = 0; + for (const auto* op : graph->op_nodes()) { + if (op->IsSwitch()) { + ++switch_count; + } + if (op->IsMerge()) { + ++merge_count; + } + ASSERT_NE(op->type_string(), "Case"); + if (op->name() == "case") { + ++node_called_case_count; + } + } + // One switch for predicate and one for input (A). + ASSERT_EQ(switch_count, 2); + // One merge for the single output value of then and else, and one more merge + // to enforce then and else function call execution (`branch_executed` node). + ASSERT_EQ(merge_count, 2); + ASSERT_EQ(node_called_case_count, 1); + + // Verify execution. + ClientSession session(root, SessionOptionsWithInlining()); + { + ClientSession::FeedType feeds; + feeds.emplace(Output(branch_index.node()), Input::Initializer(-1)); + feeds.emplace(Output(a.node()), Input::Initializer(10)); + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 160); + } + { + ClientSession::FeedType feeds; + feeds.emplace(Output(branch_index.node()), Input::Initializer(0)); + feeds.emplace(Output(a.node()), Input::Initializer(10)); + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 20); + } + { + ClientSession::FeedType feeds; + feeds.emplace(Output(branch_index.node()), Input::Initializer(1)); + feeds.emplace(Output(a.node()), Input::Initializer(10)); + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 40); + } + { + ClientSession::FeedType feeds; + feeds.emplace(Output(branch_index.node()), Input::Initializer(2)); + feeds.emplace(Output(a.node()), Input::Initializer(10)); + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 160); + } + { + ClientSession::FeedType feeds; + feeds.emplace(Output(branch_index.node()), Input::Initializer(20)); + feeds.emplace(Output(a.node()), Input::Initializer(10)); + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 160); + } +} + +TEST(LowerCaseOpTest, BranchFunctionsWithoutOutputs) { + using ::tensorflow::test::function::GDef; + using ::tensorflow::test::function::NDef; + using FDH = ::tensorflow::FunctionDefHelper; + + // Wrap AssignAddVariable + Const into a function. + const auto assign_add = [](const string& fn_name, int v) { + const Tensor tensor = test::AsScalar(v); + return FDH::Create( + fn_name, {"v: resource"}, {}, {}, + { + {{"c"}, "Const", {}, {{"value", tensor}, {"dtype", DT_INT32}}}, + {{"upd"}, + "AssignAddVariableOp", + {"v", "c:output"}, + {{"dtype", DT_INT32}}}, + }, + /*ret_def=*/{}, + /*control_ret_def=*/{{"side_effects", "upd"}}); + }; + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + // Add test functions for then and else branch. + FunctionDefLibrary f_lib_proto; + *(f_lib_proto.add_function()) = assign_add("AddOne", 1); + *(f_lib_proto.add_function()) = assign_add("AddTwo", 2); + *(f_lib_proto.add_function()) = assign_add("AddTen", 10); + + // Construct a graph to represent following program: + // + // (branch_index: int32, initial_val: int32) -> (int32) { + // var = Variable(initial_value) + // switch (branch_index) { + // case 0: + // var += 1; break; # AddOne function call + // case 1: + // var += 2; break; # AddTwo function call + // default: + // var += 10; break; # AddTen function call + // } + // return var + // } + Scope root = Scope::NewRootScope().ExitOnError(); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto)); + + auto branch_index = + ops::Placeholder(root.WithOpName("branch_index"), DT_INT32); + auto initial_val = ops::Placeholder(root.WithOpName("initial_val"), DT_INT32); + + auto var = ops::VarHandleOp(root.WithOpName("var"), DT_INT32, {}); + auto init = ops::AssignVariableOp(root.WithOpName("init"), var, initial_val); + + Node* case_node; + std::vector inputs({NodeBuilder::NodeOut(var.node())}); + TF_ASSERT_OK( + NodeBuilder("case", "Case", &root.graph()->flib_def()) + .Input(branch_index.node()) + .Input(inputs) + .ControlInput(init.operation.node()) + .Attr("branches", FuncListAttr({"AddOne", "AddTwo", "AddTen"})) + .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true) + .Attr("Tout", DataTypeSlice{}) + .Finalize(root.graph(), &case_node)); + + auto read = ops::ReadVariableOp( + root.WithOpName("read").WithControlDependencies(Output(case_node)), var, + DT_INT32); + + TF_ASSERT_OK(root.DoShapeInference(case_node)); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(Rewrite(&graph)); + + // Verify the resultant graph has switch, merge and function call nodes. + int switch_count = 0; + int merge_count = 0; + int node_called_case_count = 0; + for (const auto* op : graph->op_nodes()) { + if (op->IsSwitch()) ++switch_count; + if (op->IsMerge()) ++merge_count; + if (op->name() == "case") ++node_called_case_count; + ASSERT_NE(op->type_string(), "Case"); + } + // One switch for predicate and one for input (A). + ASSERT_EQ(switch_count, 2); + // One merge for the else/then branch (`branch_executed` node). + ASSERT_EQ(merge_count, 1); + // We keep a NoOp with the same name as original If node. + ASSERT_EQ(node_called_case_count, 1); + + // Verify execution. + ClientSession session(root, SessionOptionsWithInlining()); + { + ClientSession::FeedType feeds; + feeds.emplace(Output(branch_index.node()), Input::Initializer(-5)); + feeds.emplace(Output(initial_val.node()), Input::Initializer(10)); + + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 20); + } + { + ClientSession::FeedType feeds; + feeds.emplace(Output(branch_index.node()), Input::Initializer(0)); + feeds.emplace(Output(initial_val.node()), Input::Initializer(10)); + + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 11); + } + { + ClientSession::FeedType feeds; + feeds.emplace(Output(branch_index.node()), Input::Initializer(1)); + feeds.emplace(Output(initial_val.node()), Input::Initializer(10)); + + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 12); + } + { + ClientSession::FeedType feeds; + feeds.emplace(Output(branch_index.node()), Input::Initializer(2)); + feeds.emplace(Output(initial_val.node()), Input::Initializer(10)); + + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 20); + } + { + ClientSession::FeedType feeds; + feeds.emplace(Output(branch_index.node()), Input::Initializer(31)); + feeds.emplace(Output(initial_val.node()), Input::Initializer(10)); + + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(read)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 20); + } +} + +TEST(LowerCaseOpTest, DoNotInlineLoweredFunction) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + FunctionDef x_times_two = test::function::XTimesTwo(); + FunctionDef x_times_four = test::function::XTimesFour(); + FunctionDef x_times_16 = test::function::XTimes16(); + + // Case branches can't be inlined. + (*x_times_two.mutable_attr())["_noinline"].set_b(true); + (*x_times_four.mutable_attr())["_noinline"].set_b(true); + (*x_times_16.mutable_attr())["_noinline"].set_b(true); + + // Add test functions for the branches. + FunctionDefLibrary f_lib_proto; + *(f_lib_proto.add_function()) = x_times_two; + *(f_lib_proto.add_function()) = x_times_four; + *(f_lib_proto.add_function()) = x_times_16; + + // Construct simple conditional that switches on `pred` and operates only on + // single input `A`. + Scope root = Scope::NewRootScope().ExitOnError(); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto)); + auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32); + auto branch_index = + ops::Placeholder(root.WithOpName("branch_index"), DT_INT32); + Node* written_case; + std::vector inputs({NodeBuilder::NodeOut(a.node())}); + TF_ASSERT_OK( + NodeBuilder("case", "Case", &root.graph()->flib_def()) + .Input(branch_index.node()) + .Input(inputs) + .Attr("branches", + FuncListAttr({"XTimesTwo", "XTimesFour", "XTimes16"})) + .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true) + .Attr("Tout", {DT_INT32}) + .Finalize(root.graph(), &written_case)); + TF_ASSERT_OK(root.DoShapeInference(written_case)); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(Rewrite(&graph)); + + // Verify that Case node was lowered but branch functions were not inlined. + int x_times_two_count = 0; + int x_times_four_count = 0; + int x_times_16_count = 0; + + for (const auto* op : graph->op_nodes()) { + if (op->type_string() == x_times_two.signature().name()) { + x_times_two_count++; + } + if (op->type_string() == x_times_four.signature().name()) { + x_times_four_count++; + } + if (op->type_string() == x_times_16.signature().name()) { + x_times_16_count++; + } + ASSERT_NE(op->type_string(), "Case"); + } + + // One function for each branch. + ASSERT_EQ(x_times_two_count, 1); + ASSERT_EQ(x_times_four_count, 1); + ASSERT_EQ(x_times_16_count, 1); + + // Verify execution. + ClientSession session(root, SessionOptionsWithInlining()); + { + ClientSession::FeedType feeds; + feeds.emplace(Output(branch_index.node()), Input::Initializer(-2)); + feeds.emplace(Output(a.node()), Input::Initializer(10)); + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 160); + } + { + ClientSession::FeedType feeds; + feeds.emplace(Output(branch_index.node()), Input::Initializer(0)); + feeds.emplace(Output(a.node()), Input::Initializer(10)); + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 20); + } + { + ClientSession::FeedType feeds; + feeds.emplace(Output(branch_index.node()), Input::Initializer(1)); + feeds.emplace(Output(a.node()), Input::Initializer(10)); + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 40); + } + { + ClientSession::FeedType feeds; + feeds.emplace(Output(branch_index.node()), Input::Initializer(2)); + feeds.emplace(Output(a.node()), Input::Initializer(10)); + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 160); + } + { + ClientSession::FeedType feeds; + feeds.emplace(Output(branch_index.node()), Input::Initializer(31)); + feeds.emplace(Output(a.node()), Input::Initializer(10)); + std::vector out_tensors; + TF_ASSERT_OK(session.Run(feeds, {Output(written_case)}, &out_tensors)); + EXPECT_EQ(out_tensors.size(), 1); + EXPECT_EQ(out_tensors[0].scalar()(), 160); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/lower_functional_ops.cc b/tensorflow/core/common_runtime/lower_functional_ops.cc index 5b7276266fd..45f48a6fb1f 100644 --- a/tensorflow/core/common_runtime/lower_functional_ops.cc +++ b/tensorflow/core/common_runtime/lower_functional_ops.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/lower_functional_ops.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/lower_case_op.h" #include "tensorflow/core/common_runtime/lower_function_call_op.h" #include "tensorflow/core/common_runtime/lower_if_op.h" #include "tensorflow/core/common_runtime/lower_while_op.h" @@ -114,13 +115,13 @@ Status LowerFunctionalOpsPass::Run( ? *keep_lowered_nodes_fetchable_ : !HasArgsOrRetvals(*g); - // Lower all If and While ops that have the `kLowerUsingSwitchMergeAttr` attr - // set and inlines all function calls into the graph. + // Lower all If, Case, While ops that have the `kLowerUsingSwitchMergeAttr` + // attr set and inline all function calls into the graph. // We start at `i` = 2 to skip the source and sink nodes. - // Note that `g->num_node_ids()` may change in the for body if a matching If - // or While node is lowered. Since new graph nodes are always added to the - // end of the list of nodes it is ensured that nested If/While nodes will be - // lowered as well. + // Note that `g->num_node_ids()` may change in the for body if a matching If, + // Case, While node is lowered. Since new graph nodes are always added to the + // end of the list of nodes it is ensured that nested If/Case/While nodes will + // be lowered as well. for (int i = 2; i < g->num_node_ids(); ++i) { Node* n = g->FindNodeId(i); if (n == nullptr) continue; // deleted node @@ -139,6 +140,9 @@ Status LowerFunctionalOpsPass::Run( if (n->type_string() == "If") { TF_RETURN_IF_ERROR( RewriteIfNode(n, g, *flib_def, keep_lowered_nodes_fetchable)); + } else if (n->type_string() == "Case") { + TF_RETURN_IF_ERROR( + RewriteCaseNode(n, g, *flib_def, keep_lowered_nodes_fetchable)); } else if (n->type_string() == "While") { TF_RETURN_IF_ERROR( RewriteWhileNode(n, g, *flib_def, keep_lowered_nodes_fetchable)); diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc index 43418f166af..ec37d72faab 100644 --- a/tensorflow/core/common_runtime/lower_if_op.cc +++ b/tensorflow/core/common_runtime/lower_if_op.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/lower_if_op.h" -#include "tensorflow/core/common_runtime/lower_functional_ops.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/lower_functional_ops.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" @@ -89,9 +89,9 @@ class CondBuilder { Node* then_call_node_; Node* else_call_node_; // Merge node that has inputs from [pivot_t, pivot_f] and control edges from - // [^then_call_node_, ^else_call_node_]. This node will guarantee that if - // then/else branch functions do not have outputs, they still will be executed - // for the side effects. + // [^then_call_node_, ^else_call_node_]. This node will guarantee that even + // when then/else branch functions do not have outputs, they still will be + // executed for the side effects. Node* branch_executed_node_; Graph* graph_; const FunctionLibraryDefinition& flib_; diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 6574f3bf622..ee6962d7abc 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -58,6 +58,7 @@ const std::unordered_map& Node::kNodeClassTable = *new std::unordered_map({ // Keep in same order as NodeClass values REF_CLASS("Switch", NC_SWITCH), + REF_CLASS("_SwitchN", NC_SWITCH), REF_CLASS("Merge", NC_MERGE), REF_CLASS("Enter", NC_ENTER), REF_CLASS("Exit", NC_EXIT), diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index a414fc63749..3479246eeb2 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -150,7 +150,8 @@ bool IsControlFlow(const NodeDef& node) { node.op() == "LoopCond" || node.op() == "Merge" || node.op() == "NextIteration" || - node.op() == "Switch"; + node.op() == "Switch" || + node.op() == "_SwitchN"; // clang-format on } @@ -523,7 +524,7 @@ bool IsSum(const NodeDef& node) { return node.op() == "Sum"; } bool IsSwitch(const NodeDef& node) { const auto& op = node.op(); - return op == "Switch" || op == "RefSwitch"; + return op == "_SwitchN" || op == "Switch" || op == "RefSwitch"; } bool IsSymbolicGradient(const NodeDef& node) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index fb6aaf7082e..c607da6ab87 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1281,7 +1281,7 @@ Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) { node_map_->AddOutput(const_index->name(), output->name()); } else { // This is a control dependency (or an invalid edge since the - // merge node has only 2 inputs): preserve them. + // merge node has only 2 outputs): preserve them. } } } diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index 1c2908ee9d5..16f8d440c5d 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/lower_case_op.h" #include "tensorflow/core/common_runtime/lower_functional_ops.h" #include "tensorflow/core/common_runtime/lower_if_op.h" #include "tensorflow/core/common_runtime/lower_while_op.h" @@ -1157,6 +1158,8 @@ Status InlineFunctionCalls(const GrapplerItem& item, if (n->type_string() == "If") { TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), flib_def, false)); + } else if (n->type_string() == "Case") { + TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), flib_def, false)); } else if (n->type_string() == "While") { TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), flib_def, false)); } diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 02675328b71..4f3028f49cc 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -195,6 +195,7 @@ std::set GetOpsFormatAgnostic() { "StridedSlice", "StridedSliceGrad", "Switch", + "_SwitchN", "Tile", "TruncateDiv", "TruncateMod", @@ -1943,7 +1944,15 @@ class SwitchProcessor : public AgnosticNodeProcessor { : AgnosticNodeProcessor(opt_cxt) {} protected: - std::set GetOutputPos() const override { return {0, 1}; } + std::set GetOutputPos() const override { + std::set output_pos; + const int num_outs = + node_->attr().count("num_outs") ? node_->attr().at("num_outs").i() : 2; + for (int i = 0; i < num_outs; i++) { + output_pos.insert(i); + } + return output_pos; + } }; class TileProcessor : public AgnosticNodeProcessor { diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc index 94209aabea2..3ffc6ad2d46 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -460,7 +460,7 @@ std::vector GetStackPushNodesToConvert( const std::unordered_set op_types_to_traverse( {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch", - "Identity", "RefIdentity"}); + "_SwitchN", "Identity", "RefIdentity"}); const auto is_op_to_traverse = [&](const NodeDef* node) -> bool { return op_types_to_traverse.find(node->op()) != op_types_to_traverse.end(); }; @@ -753,6 +753,9 @@ Status LoopOptimizer::RemoveDeadBranches( if (!IsSwitch(node)) { continue; } + if (node.op() == "_SwitchN") { // _SwitchN not used in loop control flow. + continue; + } if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) { continue; } @@ -798,8 +801,8 @@ Status LoopOptimizer::RemoveDeadBranches( if (IsMerge(*dead.node)) { const int num_data_inputs = dead.node->attr().at("N").i(); if (num_data_inputs > 2) { - // This never happens in practice, so we'll just skip these to - // simplify the code for now. + // This can happen with _SwitchN/Merge (Case lowering). We skip these + // to simplify the code for now. found_node_to_preserve = true; break; } @@ -877,11 +880,15 @@ Status LoopOptimizer::RemoveDeadBranches( } // Remove dead data input. const std::set& dead_inputs = itr.second; + CHECK_LE(dead_inputs.size(), 1); + // (This loop would delete >1 items possibly in the wrong order.) for (int index : dead_inputs) { dead_node->mutable_input()->DeleteSubrange(index, 1); } - // Turn Merge into Identity only if we deleted data inputs. + // Turn Merge into Identity only if we deleted the other data input. if (!dead_inputs.empty()) { + const int num_data_inputs = dead_node->attr().at("N").i(); + CHECK_EQ(num_data_inputs, dead_inputs.size() + 1); dead_node->set_op("Identity"); dead_node->mutable_attr()->erase("N"); } diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index c0981805bbe..75a2ea5a5d0 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -39,12 +39,30 @@ void SwitchOp::Compute(OpKernelContext* context) { } } +void SwitchNOp::Compute(OpKernelContext* context) { + const Tensor& output_index_t = context->input(1); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_index_t.shape()), + errors::InvalidArgument("The second input must be a scalar, " + "but it has shape ", + output_index_t.shape().DebugString())); + int output_index = output_index_t.scalar()(); + if (output_index < 0 || output_index >= num_outputs()) { + output_index = num_outputs() - 1; + } + context->set_output(output_index, context->input(0)); +} + #define REGISTER_CPU_SWITCH(type) \ REGISTER_KERNEL_BUILDER(Name("Switch") \ .Device(DEVICE_CPU) \ .HostMemory("pred") \ .TypeConstraint("T"), \ - SwitchOp) + SwitchOp) \ + REGISTER_KERNEL_BUILDER(Name("_SwitchN") \ + .Device(DEVICE_CPU) \ + .HostMemory("output_index") \ + .TypeConstraint("T"), \ + SwitchNOp) #define REGISTER_CPU_REF_SWITCH(type) \ REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ @@ -58,7 +76,12 @@ void SwitchOp::Compute(OpKernelContext* context) { .Device(DEVICE_GPU) \ .HostMemory("pred") \ .TypeConstraint("T"), \ - SwitchOp) + SwitchOp) \ + REGISTER_KERNEL_BUILDER(Name("_SwitchN") \ + .Device(DEVICE_GPU) \ + .HostMemory("output_index") \ + .TypeConstraint("T"), \ + SwitchNOp) #define REGISTER_GPU_REF_SWITCH(type) \ REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ @@ -96,7 +119,14 @@ TF_CALL_variant(REGISTER_GPU_SWITCH); .HostMemory("output_false") \ .HostMemory("output_true") \ .TypeConstraint("T"), \ - SwitchOp) + SwitchOp) \ + REGISTER_KERNEL_BUILDER(Name("_SwitchN") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output_index") \ + .HostMemory("outputs") \ + .TypeConstraint("T"), \ + SwitchNOp) #define REGISTER_GPU_HOST_REF_KERNEL(type) \ REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ diff --git a/tensorflow/core/kernels/control_flow_ops.h b/tensorflow/core/kernels/control_flow_ops.h index c607fcf298f..37d561ca98b 100644 --- a/tensorflow/core/kernels/control_flow_ops.h +++ b/tensorflow/core/kernels/control_flow_ops.h @@ -46,6 +46,21 @@ class SwitchOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(SwitchOp); }; +// An n-way switch op has two inputs and N outputs. It forwards the value of +// Input:0 to the output specified by Input:1. Input:1 is an integer tensor. +// Input:0 is forwarded to output:0 if Input:1 is 0, to output:1 if 1, and so +// forth. If Input:1 is <0 or >=num_outputs(), Input:0 is forwarded to +// output:num_outputs()-1. +class SwitchNOp : public OpKernel { + public: + explicit SwitchNOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~SwitchNOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(SwitchNOp); +}; + // A merge op has n inputs and two outputs. It forwards the value of the // first input that becomes available to its first output, and the // index of the first input to its second output. diff --git a/tensorflow/core/ops/control_flow_ops.cc b/tensorflow/core/ops/control_flow_ops.cc index b8028291b40..c05a73bfd8f 100644 --- a/tensorflow/core/ops/control_flow_ops.cc +++ b/tensorflow/core/ops/control_flow_ops.cc @@ -24,6 +24,7 @@ using shape_inference::ShapeHandle; // -------------------------------------------------------------------------- namespace { + Status SwitchShape(InferenceContext* c) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); @@ -39,6 +40,27 @@ Status SwitchShape(InferenceContext* c) { } return Status::OK(); } + +Status SwitchNShape(InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + ShapeHandle out = c->input(0); + int num_outs; + TF_RETURN_IF_ERROR(c->GetAttr("num_outs", &num_outs)); + for (int i = 0; i < num_outs; i++) { + c->set_output(i, out); + } + + // Handle resource shape / dtype. + auto* handle_data = c->input_handle_shapes_and_types(0); + if (handle_data != nullptr) { + for (int i = 0; i < num_outs; i++) { + c->set_output_handle_shapes_and_types(i, *handle_data); + } + } + return Status::OK(); +} + } // namespace REGISTER_OP("Switch") @@ -58,6 +80,14 @@ REGISTER_OP("RefSwitch") .SetAllowsUninitializedInput() .SetShapeFn(SwitchShape); +REGISTER_OP("_SwitchN") + .Input("data: T") + .Input("output_index: int32") + .Output("outputs: num_outs * T") + .Attr("num_outs: int >= 1") + .Attr("T: type") + .SetShapeFn(SwitchNShape); + // -------------------------------------------------------------------------- REGISTER_OP("RefSelect") .Input("index: int32") diff --git a/tensorflow/core/ops/control_flow_ops_test.cc b/tensorflow/core/ops/control_flow_ops_test.cc index 2c0736c8bcc..e3175efa44d 100644 --- a/tensorflow/core/ops/control_flow_ops_test.cc +++ b/tensorflow/core/ops/control_flow_ops_test.cc @@ -49,6 +49,27 @@ TEST(ControlFlowOpsTest, Merge_ShapeFn) { INFER_OK(op, "[2,1];[2,1];[2,1]", "in0;[]"); } +TEST(ControlFlowOpsTest, SwitchN_ShapeFn) { + ShapeInferenceTestOp op("_SwitchN"); + + int n = 5; + TF_ASSERT_OK(NodeDefBuilder("test", "_SwitchN") + .Input({"d", 0, DT_FLOAT}) + .Input({"bi", 0, DT_INT32}) + .Attr("num_outs", n) + .Finalize(&op.node_def)); + + // Non-scalar output_index. + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[2]"); + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1]"); + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[?]"); + // The second input should always be scalar. Outputs are 5x the first input. + INFER_OK(op, "?;?", "in0;in0;in0;in0;in0"); + INFER_OK(op, "[2,?];?", "in0;in0;in0;in0;in0"); + INFER_OK(op, "[2,?];[]", "in0;in0;in0;in0;in0"); + INFER_OK(op, "[2,3];[]", "in0;in0;in0;in0;in0"); +} + TEST(ControlFlowOpsTest, RefSelect_ShapeFn) { ShapeInferenceTestOp op("RefSelect"); diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index 29e06534b72..d1751f9b1e5 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -1177,6 +1177,84 @@ class FunctionalOpsCaseTest(test.TestCase): self.assertAllEqual(np.float32(4), self.evaluate(f(-1, one))) # <0 default self.assertAllEqual(np.float32(4), self.evaluate(f(6, one))) # >=N default + @test_util.run_deprecated_v1 + @test_util.disable_xla("Don't lower for XLA") + def testSkipEagerCaseLoweringPreservesNameForFetch(self): + for use_gpu in (True, False): + def Run(branch, x, fetch_by_name, use_gpu=use_gpu): + with ops.Graph().as_default() as g: + @function.Defun(dtypes.float32) + def two(x): + return -1, x * 2 + + @function.Defun(dtypes.float32) + def three(x): + return 0, x * 3 + + @function.Defun(dtypes.float32) + def four(x): + return 1, x * 4 + + outputs = gen_functional_ops.case(branch, input=[x], + Tout=[dtypes.int32, dtypes.float32], + branches=[two, three, four], + name="my_case") + + # `outputs` is the list of output tensors of the Case op. We + # arbitrarily choose the 0th tensor to get the Case op and set the + # lowering attribute on it. + outputs[0].op._set_attr("_lower_using_switch_merge", + attr_value_pb2.AttrValue(b=True)) + outputs = array_ops.identity_n(outputs) + with self.session(graph=g, use_gpu=use_gpu) as sess: + return sess.run("my_case:1" if fetch_by_name else outputs[1]) + + self.assertAllEqual(2 * 1., Run(0, 1., False)) + self.assertAllEqual(2 * 1., Run(0, 1., True)) + self.assertAllEqual(3 * 7., Run(1, 7., False)) + self.assertAllEqual(3 * 7., Run(1, 7., True)) + self.assertAllEqual(4 * -3., Run(2, -3., False)) + self.assertAllEqual(4 * -3., Run(2, -3., True)) + self.assertAllEqual(4 * -4., Run(7, -4., False)) # >= N default + self.assertAllEqual(4 * -4., Run(7, -4., True)) # >= N default + self.assertAllEqual(4 * -5., Run(-1, -5., False)) # <0 default + self.assertAllEqual(4 * -5., Run(-1, -5., True)) # <0 default + + @test_util.disable_xla("Don't lower for XLA") + def testCaseLowering(self): + for use_gpu in (True, False): + @eager_function.defun + def Run(branch, x): + @function.Defun(dtypes.float32) + def two(x): + return -1, x * 2 + + @function.Defun(dtypes.float32) + def three(x): + return 0, x * 3 + + @function.Defun(dtypes.float32) + def four(x): + return 1, x * 4 + + outputs = gen_functional_ops.case(branch, input=[x], + Tout=[dtypes.int32, dtypes.float32], + branches=[two, three, four]) + + # `outputs` is the list of output tensors of the Case op. We + # arbitrarily choose the 0th tensor to get the Case op and set the + # lowering attribute on it. + outputs[0].op._set_attr("_lower_using_switch_merge", + attr_value_pb2.AttrValue(b=True)) + outputs = array_ops.identity_n(outputs) + return outputs[1] + + with ops.device(test.gpu_device_name() if use_gpu else "CPU:0"): + self.assertAllEqual(2 * 1., self.evaluate(Run(0, 1.))) + self.assertAllEqual(3 * 7., self.evaluate(Run(1, 7.))) + self.assertAllEqual(4 * -3., self.evaluate(Run(2, -3.))) + self.assertAllEqual(4 * -4., self.evaluate(Run(7, -4.))) # >=N default + self.assertAllEqual(4 * -5., self.evaluate(Run(-1, -5.))) # <0 default if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py index 69364d70477..1b975cd8590 100644 --- a/tensorflow/python/ops/cond_v2.py +++ b/tensorflow/python/ops/cond_v2.py @@ -510,7 +510,7 @@ def _make_output_composite_tensors_match(op_type, branch_graphs): raise TypeError( "Cannot reconcile {op_name} {output_idx}-th outputs:\n" " outputs from all branches: {outputs}".format( - op_name="tf.cond" if op_type == _COND else "tf.case", + op_name="tf.cond" if op_type == _COND else "tf.switch_case", output_idx=output_idx, outputs=branch_outs)) @@ -534,9 +534,9 @@ def _make_indexed_slices_indices_types_match(op_type, branch_graphs): for output_idx, branch_outs in enumerate( zip(*branch_outputs_flat_with_composites)): if len(set(isinstance(out, ops.IndexedSlices) for out in branch_outs)) != 1: - raise TypeError("Cannot reconcile {op_name} {output_idx}-th outputs:\n" + raise TypeError("Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n" " branches returned: {outputs}".format( - op_name="tf.cond" if op_type == _COND else "tf.case", + op_name="cond" if op_type == _COND else "switch_case", output_idx=output_idx, outputs=branch_outs)) if isinstance(branch_outs[0], ops.IndexedSlices): @@ -632,7 +632,7 @@ def _check_same_outputs(op_type, graphs): b0_name="true_fn" if op_type == _COND else "branches[0]", bn_name=("false_fn" if op_type == _COND else "branches[{}]".format(branch_idx)), - op_name="tf.cond" if op_type == _COND else "tf.case", + op_name="tf.cond" if op_type == _COND else "tf.switch_case", b0_out=graphs[0].structured_outputs, bn_out=graphs[branch_idx].structured_outputs, detail=error_detail)) @@ -817,8 +817,8 @@ def indexed_case(branch_index, branch_fns, name="indexed_case"): @ops.RegisterGradient("Case") def _CaseGrad(op, *grads): # pylint: disable=invalid-name - """The gradient of a Case op produced (w/ branch_index) by tf.case.""" - # Get the if operator (this logic handles the case where op is a MockOp) + """The gradient of a Case op produced by tf.switch_case.""" + # Get the Case operator (this logic handles the case where op is a MockOp) case_op = op.outputs[0].op branch_graphs = _get_func_graphs(case_op) assert branch_graphs @@ -892,7 +892,7 @@ def _CaseGrad(op, *grads): # pylint: disable=invalid-name _make_output_composite_tensors_match(_CASE, branch_grad_graphs) outputs = _build_case(case_op.inputs[0], branch_grad_graphs, - branches_grad_inputs) + branches_grad_inputs, name="gradient") # The predicate has no gradient. return [None] + outputs @@ -935,14 +935,13 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None): output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]), name=name) - # TODO(b/110167197) this requires Case to have at least 1 output + # TODO(b/110167197): this requires Case to have at least 1 output case_op = tensors[0].op - # TODO(b/131304144): Enable lowering Case to SwitchN/Merge for graph mode. - # util.maybe_set_lowering_attr(case_op) + util.maybe_set_lowering_attr(case_op) util.maybe_propagate_compile_time_consts_in_xla(case_op) # Return identities for each output of the Case op, rather than the output of - # the Case op directly. This makes pruning work if the output of select_case() + # the Case op directly. This makes pruning work if the output of switch_case() # is fetched: the lowering pass converts the Case outputs into IdentityN # outputs, which if fetched will cause all ops in the taken branch to be run # (since it takes all merge ops as input). After lowering, each output diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index f67d785fc0e..c0fa1af6ec2 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -19,11 +19,14 @@ from __future__ import division from __future__ import print_function import collections +from absl.testing import parameterized import numpy as np -from tensorflow.python import tf2 from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import tf2 +from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function @@ -39,7 +42,9 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope @@ -933,7 +938,10 @@ class DataTypesTest(test_util.TensorFlowTestCase): @test_util.run_all_in_graph_and_eager_modes -class IndexedCaseTest(test_util.TensorFlowTestCase): +class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase): + + def make_name(self): + return self.id().split(".")[-1].replace("(", "_").replace(")", "") def disabled_testCase_ticklesGpuVsHostMemoryIssueWithInt32(self): nbranches = 5 @@ -947,53 +955,55 @@ class IndexedCaseTest(test_util.TensorFlowTestCase): case_out = control_flow_ops.switch_case(branch_index, branches) self.assertEqual(bi * 10, self.evaluate(case_out)) - def testCase(self): + @parameterized.parameters((0,), (2,), (3,)) + def testCase(self, bi): nbranches = 5 def make_func(bi): return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) branches = [(i, make_func(i)) for i in range(nbranches)] - for bi in 0, 2, 3: - branch_index = array_ops.placeholder_with_default(bi, []) - case_out = control_flow_ops.switch_case(branch_index, branches) - self.assertEqual(bi * 10., self.evaluate(case_out)) + branch_index = array_ops.placeholder_with_default(bi, []) + case_out = control_flow_ops.switch_case( + branch_index, branches, name=self.make_name()) + self.assertEqual(bi * 10., self.evaluate(case_out)) - def testCase_withDefault(self): + @parameterized.parameters((-1,), (2,), (4,), (5,), (6,)) + def testCase_withDefault(self, bi): nbranches = 5 def make_func(bi): return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) branches = [(i, make_func(i)) for i in range(nbranches)] - for bi in -1, 2, nbranches: - branch_index = array_ops.placeholder_with_default(bi, []) - case_out = control_flow_ops.switch_case( - branch_index, branches, default=make_func(6)) - if bi < 0 or bi >= nbranches: - expected = 60. - else: - expected = bi * 10. - self.assertEqual(expected, self.evaluate(case_out)) + branch_index = array_ops.placeholder_with_default(bi, []) + case_out = control_flow_ops.switch_case( + branch_index, branches, default=make_func(6), name=self.make_name()) + if bi < 0 or bi >= nbranches: + expected = 60. + else: + expected = bi * 10. + self.assertEqual(expected, self.evaluate(case_out)) - def testCase_dictWithDefault(self): + @parameterized.parameters((-1,), (0,), (3,), (5,)) + def testCase_dictWithDefault(self, bi): nbranches = 5 def make_func(bi): return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) branches = {i: make_func(i) for i in range(nbranches)} - for bi in -1, 0, 3, nbranches: - branch_index = array_ops.placeholder_with_default(bi, []) - case_out = control_flow_ops.switch_case( - branch_index, branches, default=make_func(6)) - if bi < 0 or bi >= nbranches: - expected = 60. - else: - expected = bi * 10. - self.assertEqual(expected, self.evaluate(case_out)) + branch_index = array_ops.placeholder_with_default(bi, []) + case_out = control_flow_ops.switch_case( + branch_index, branches, default=make_func(6), name=self.make_name()) + if bi < 0 or bi >= nbranches: + expected = 60. + else: + expected = bi * 10. + self.assertEqual(expected, self.evaluate(case_out)) - def testCase_gradient(self): + @parameterized.parameters((-1,), (1,), (4,), (5,)) + def testCase_gradient(self, bi): nbranches = 5 inputs = [ array_ops.constant(float(bi), name="br{}_in".format(bi)) @@ -1005,22 +1015,22 @@ class IndexedCaseTest(test_util.TensorFlowTestCase): branches = {bi: make_func(bi) for bi in range(nbranches)} - for bi in -1, 1, 4, nbranches: - branch_index = array_ops.placeholder_with_default(bi, []) - with backprop.GradientTape() as tape: - for x in inputs: - tape.watch(x) - case_out = control_flow_ops.switch_case(branch_index, branches) - out_grad = 3. - actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad) - expected_grads = [None if context.executing_eagerly() else 0.] * nbranches - used_branch_idx = nbranches - 1 if bi < 0 or bi >= nbranches - 1 else bi - expected_grads[used_branch_idx] = out_grad * 2. * used_branch_idx - self.assertEqual(len(expected_grads), len(actual_grads)) - for expected, actual in zip(expected_grads, actual_grads): - self.assertEqual(expected, self.evaluate(actual)) + branch_index = array_ops.placeholder_with_default(bi, []) + with backprop.GradientTape() as tape: + for x in inputs: + tape.watch(x) + case_out = control_flow_ops.switch_case(branch_index, branches) + out_grad = 3. + actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad) + expected_grads = [None if context.executing_eagerly() else 0.] * nbranches + used_branch_idx = nbranches - 1 if bi < 0 or bi >= nbranches - 1 else bi + expected_grads[used_branch_idx] = out_grad * 2. * used_branch_idx + self.assertEqual(len(expected_grads), len(actual_grads)) + for expected, actual in zip(expected_grads, actual_grads): + self.assertEqual(expected, self.evaluate(actual)) - def testCase_gradient_diffShapedIntermediates(self): + @parameterized.parameters((-2,), (2,), (5,)) + def testCase_gradient_diffShapedIntermediates(self, bi): nbranches = 5 inputs = [ array_ops.constant( @@ -1038,34 +1048,105 @@ class IndexedCaseTest(test_util.TensorFlowTestCase): branches = {bi: make_func(bi) for bi in range(nbranches)} - for bi in -1, 2, nbranches: - branch_index = array_ops.placeholder_with_default(bi, []) - with backprop.GradientTape() as tape: - for x in inputs: - tape.watch(x) - case_out = control_flow_ops.switch_case(branch_index, branches) - out_grad = 3. - actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad) - used_bi = (nbranches - 1) if (bi < 0 or bi >= nbranches - 1) else bi - expected_grads = [] - for input_idx in range(nbranches): - if used_bi == input_idx: - with backprop.GradientTape() as tape: - tape.watch(inputs[used_bi]) - y = make_func(used_bi)() - expected_grads.append( - self.evaluate( - tape.gradient(y, inputs[used_bi], output_gradients=out_grad))) - else: - expected_grads.append(None if context.executing_eagerly() else [0.] * - (input_idx + 1)) + branch_index = array_ops.placeholder_with_default(bi, []) + with backprop.GradientTape() as tape: + for x in inputs: + tape.watch(x) + case_out = control_flow_ops.switch_case( + branch_index, branches, name=self.make_name()) + out_grad = 3. + actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad) + used_bi = (nbranches - 1) if (bi < 0 or bi >= nbranches - 1) else bi + expected_grads = [] + for input_idx in range(nbranches): + if used_bi == input_idx: + with backprop.GradientTape() as tape: + tape.watch(inputs[used_bi]) + y = make_func(used_bi)() + expected_grads.append( + self.evaluate( + tape.gradient(y, inputs[used_bi], output_gradients=out_grad))) + else: + expected_grads.append(None if context.executing_eagerly() else [0.] * + (input_idx + 1)) - self.assertEqual(len(expected_grads), len(actual_grads)) - for expected, actual in zip(expected_grads, actual_grads): - if expected is None: - self.assertIsNone(actual) - else: - self.assertAllEqual(expected, self.evaluate(actual)) + self.assertEqual(len(expected_grads), len(actual_grads)) + for expected, actual in zip(expected_grads, actual_grads): + if expected is None: + self.assertIsNone(actual) + else: + self.assertAllEqual(expected, self.evaluate(actual)) + + @test_util.run_gpu_only + @test_util.disable_xla("Wants RunMetadata") + def testParallelExecution(self): + """Verify disjoint branches across while iterations are run in parallel.""" + with ops.Graph().as_default() as g: + nbranches = 7 + matrices = array_ops.unstack( # Ensure all are ready before while. + array_ops.matrix_diag( + random_ops.random_uniform([nbranches, 8, 512]) + 1e-3)) + + def make_branch(i, mat, name): + def branch_fn(): + next_i = i + 1 + with ops.device("gpu:0"): + return next_i, math_ops.reduce_sum( + linalg_ops.cholesky(mat, name=name + "_Cholesky")) + return branch_fn + + def make_branches(i): + return [make_branch(i, matrices[bi], "br{}".format(bi)) + for bi in range(nbranches)] + + def cond(i, _): + return i < nbranches + + def body(i, result): + with ops.device("cpu:0"): + next_i, branch_out = control_flow_ops.switch_case(i, make_branches(i)) + return next_i, result + branch_out + + _, result = control_flow_ops.while_loop(cond, body, [0, 0.]) + + run_metadata = config_pb2.RunMetadata() + run_options = config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE) + config = config_pb2.ConfigProto( + allow_soft_placement=False, log_device_placement=True) + + with session.Session(config=config, graph=g) as sess: + _ = sess.run(result, options=run_options, run_metadata=run_metadata) + chol_node_stats = [] + for dev_stats in run_metadata.step_stats.dev_stats: + for node_stats in dev_stats.node_stats: + if (node_stats.node_name.endswith("Cholesky") and + node_stats.all_start_nanos > 0): + chol_node_stats.append(node_stats) + + self.assertLen(chol_node_stats, nbranches) + + chol_node_stats = sorted(chol_node_stats, key=lambda stats: stats.node_name) + op_start_nanos = [ + stats.all_start_nanos for stats in chol_node_stats + ] + op_end_nanos = [ + stats.all_start_nanos + stats.op_end_rel_nanos + for stats in chol_node_stats + ] + + def overlap(range1, range2): + s1, e1 = range1 + s2, e2 = range2 + if s1 < s2: + return 0 if s2 > e1 else e1 - s2 + return 0 if s1 > e2 else e2 - s1 + + timespans = list(zip(op_start_nanos, op_end_nanos)) + overlaps_chol0 = [overlap(timespans[0], r2) for r2 in timespans[1:]] + # There are nbranches-1 overlaps, sometimes all nonzero, but we + # conservatively check for at least one here, to avoid test flakiness. + self.assertGreater(np.count_nonzero(overlaps_chol0), 0) def testCase_validateIndicesContiguous(self):