From 2de551ba87cde98325a5d6a3c3b3c7d092af4376 Mon Sep 17 00:00:00 2001
From: Yanhua Sun
Date: Thu, 28 May 2020 09:51:29 -0700
Subject: [PATCH] Add a function to dynamically choose and execute the proper
 implementation based on the underlying device placement

PiperOrigin-RevId: 313605766
Change-Id: I877b684dcef782b375df0504c0250acd9e808ce9
---
 .../base_api/api_def_DeviceIndex.pbtxt        |   5 +
 tensorflow/core/grappler/optimizers/BUILD     |   1 +
 .../optimizers/implementation_selector.cc     |  85 ++++++++-
 .../optimizers/implementation_selector.h      |  87 +++++++++-
 .../implementation_selector_test.cc           | 162 ++++++++++++++++++
 tensorflow/core/kernels/functional_ops.cc     |  32 ++++
 tensorflow/core/ops/functional_ops.cc         |   6 +
 tensorflow/python/ops/cond_v2.py              |  17 +-
 tensorflow/python/ops/control_flow_ops.py     |  57 +++++-
 .../python/ops/control_flow_ops_test.py      |  86 ++++++++++
 tensorflow/python/ops/control_flow_util_v2.py |  17 +-
 .../api/golden/v1/tensorflow.raw_ops.pbtxt    |   4 +
 .../api/golden/v2/tensorflow.raw_ops.pbtxt    |   4 +
 13 files changed, 550 insertions(+), 13 deletions(-)
 create mode 100644 tensorflow/core/api_def/base_api/api_def_DeviceIndex.pbtxt

diff --git a/tensorflow/core/api_def/base_api/api_def_DeviceIndex.pbtxt b/tensorflow/core/api_def/base_api/api_def_DeviceIndex.pbtxt
new file mode 100644
index 00000000000..9a4e5abd110
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_DeviceIndex.pbtxt
@@ -0,0 +1,5 @@
+op {
+  graph_op_name: "DeviceIndex"
+  visibility: HIDDEN
+  summary: "Return the index of the device on which the op runs."
+}
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 030064e49fb..7432e2d54ea 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -1062,6 +1062,7 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
diff --git a/tensorflow/core/grappler/optimizers/implementation_selector.cc b/tensorflow/core/grappler/optimizers/implementation_selector.cc
index 9c4f74d7268..2b0a27aaa2d 100644
--- a/tensorflow/core/grappler/optimizers/implementation_selector.cc
+++ b/tensorflow/core/grappler/optimizers/implementation_selector.cc
@@ -17,9 +17,12 @@ limitations under the License.

 #include

+#include "absl/strings/match.h"
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_split.h"
+#include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/grappler/costs/graph_properties.h"
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/op_types.h"
@@ -36,6 +39,11 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {

+constexpr char kConstOp[] = "Const";
+constexpr char kCaseOp[] = "Case";
+constexpr char kDeviceIndexOp[] = "DeviceIndex";
+
+// TODO(b/157615690): clean up function implementation swap code.
 // The overall idea for the function swap is like below:
 //          -----------                          -----------
 //  inp_1 ->|   P_C   | -> out_1      g_inp_1 ->|   P_C   | -> g_out_1
@@ -292,6 +300,74 @@ Status ImplementationSelector::MaybeOptimizeFunctionCall(
   return Status::OK();
 }

+// Finds the index of the device type in the device name list.
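+// For example, with device_names = {"CPU", "TPU_REPLICATED_CORE", "GPU"},
+// device = "/device:GPU:0" yields *index = 2, while an unlisted type such as
+// "/device:XLA_CPU:0" yields *index = 3, which selects the default branch.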
+Status FindDeviceIndex(const utils::MutableNodeView* device_index_node,
+                       const string& device, int* index) {
+  DeviceNameUtils::ParsedName parsed_name;
+  if (!DeviceNameUtils::ParseFullName(device, &parsed_name) ||
+      !parsed_name.has_type) {
+    return errors::Internal("Could not parse device name: ", device);
+  }
+  const auto& device_list =
+      device_index_node->GetAttr("device_names")->list().s();
+  auto it = absl::c_find(device_list, parsed_name.type);
+  if (it != device_list.end()) {
+    *index = it - device_list.begin();
+  } else {
+    // Sets *index to device_list.size() because the default_fn is guaranteed
+    // to be the final item in the Case op branching list.
+    *index = device_list.size();
+  }
+  return Status::OK();
+}
+
+// Rewrites the DeviceIndex op to a Const op holding the computed index.
+void RewriteDeviceIndexOp(utils::MutableNodeView* device_index_node,
+                          int index) {
+  // Modifies the DeviceIndex node to be a Const op with the correct device
+  // index as its value.
+  auto node = device_index_node->node();
+  node->set_op(kConstOp);
+  node->clear_attr();
+  (*node->mutable_attr())["dtype"].set_type(DT_INT32);
+  auto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
+  tensor->set_dtype(DT_INT32);
+  tensor->add_int_val(index);
+  VLOG(2) << "Node after rewriting: " << node->DebugString();
+}
+
+Status ImplementationSelector::SelectDeviceIndex(GraphDef* graph) const {
+  Status status;
+  VLOG(2) << "Graph before rewriting device index: " << graph->DebugString();
+  utils::MutableGraphView graph_view(graph, &status);
+  TF_RETURN_IF_ERROR(status);
+  const int num_nodes = graph_view.NumNodes();
+  for (int k = 0; k < num_nodes; ++k) {
+    auto* node_view = graph_view.GetNode(k);
+    if (node_view->GetOp() != kDeviceIndexOp) {
+      continue;
+    }
+    VLOG(2) << "Found a DeviceIndex node to rewrite";
+
+    // Find each Case node that takes this DeviceIndex node as input, and
+    // rewrite the DeviceIndex node to a Const holding the index of the Case
+    // node's device type in the device_names list.
+    for (const auto& fanouts : node_view->GetRegularFanouts()) {
+      for (const auto& fanout : fanouts) {
+        if (fanout.node_view()->GetOp() != kCaseOp) continue;
+        int index;
+        // If device parsing fails, simply skip and leave the DeviceIndex
+        // node unmodified.
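+        // (For example, a Case op with an empty device string fails the
+        // device-name parse; see the SelectDeviceIndexError test.)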
+        Status status =
+            FindDeviceIndex(node_view, fanout.node_view()->GetDevice(), &index);
+        if (status.ok()) {
+          RewriteDeviceIndexOp(node_view, index);
+        }
+      }
+    }
+  }
+  return Status::OK();
+}
+
 Status ImplementationSelector::SelectImplementation(GraphDef* graph) const {
   if (!graph->has_library()) {
     VLOG(2) << "Skipping graph since it does not have function def";
@@ -307,8 +383,9 @@ Status ImplementationSelector::SelectImplementation(GraphDef* graph) const {
   TF_RETURN_IF_ERROR(status);

   const int num_nodes = graph_view.NumNodes();
-  for (int k = 0; k < num_nodes; ++k)
+  for (int k = 0; k < num_nodes; ++k) {
     TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph_view.GetNode(k)));
+  }

   return Status::OK();
 }
@@ -326,7 +403,13 @@ Status ImplementationSelector::Optimize(Cluster* cluster,
            << "libraries: " << status;
     return errors::Aborted("Skipped Optimization");
   }
+
   *optimized_graph = item.graph;
+  status = SelectDeviceIndex(optimized_graph);
+  if (!status.ok()) {
+    *optimized_graph = item.graph;
+    VLOG(2) << "Could not rewrite device index due to error: " << status;
+  }
   return SelectImplementation(optimized_graph);
 }
diff --git a/tensorflow/core/grappler/optimizers/implementation_selector.h b/tensorflow/core/grappler/optimizers/implementation_selector.h
index 57d19fe7046..f6962e0a10d 100644
--- a/tensorflow/core/grappler/optimizers/implementation_selector.h
+++ b/tensorflow/core/grappler/optimizers/implementation_selector.h
@@ -34,6 +34,28 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {

+// Motivation: To achieve the same high-level functionality, the underlying
+// implementations sometimes differ between the various devices on which a
+// function may run. In order to achieve the correct result and the best
+// performance, the proper implementation needs to be picked dynamically.
+//
+// Currently there are two approaches to do this.
+// (1) Use a Case op and dynamically change the branch index.
+// (2) Swap the function implementation; this approach will be deprecated.
+//
+// Idea for approach 1.
+// This transformation rewrites the DeviceIndex op with a Const op whose value
+// is the index of the device on which the associated Case op runs.
+// Example:
+//   def plus_one_gpu(x): return x + 1.0
+//   def plus_one_reference_implementation(x): return x + 1.0
+//   input = tf.constant(2.0, dtype=tf.float32)
+//   cpu_fn = lambda: plus_one_reference_implementation(input)
+//   gpu_fn = lambda: plus_one_gpu(input)
+//   control_flow_ops.execute_fn_for_device(
+//       {"CPU": cpu_fn, "GPU": gpu_fn}, default_fn=cpu_fn)
+//
+// Idea for approach 2.
 // This transformation replaces function calls by the appropriate function
 // definition based on properties of the runtime system. For instance,
 // we may choose one implementation over another if we have a GPU with
@@ -58,7 +80,8 @@
 //   z = plus_one_gpu(input)
 //   print(sess.run(z))
 //
-// At runtime, we will trim either `plus_one_gpu` or
+
+// At runtime, we will select either `plus_one_gpu` or
 // `plus_one_reference_implementation` based on the availability of the GPU.
 //
 // Available annotations:
@@ -106,6 +129,68 @@ class ImplementationSelector : public CustomGraphOptimizer {
   // gradients.
   Status SelectImplementation(GraphDef* graph) const;

+  // Rewrites the DeviceIndex op with a Const op whose value is the index of
+  // the device on which the associated Case op runs.
+  //
+  // This function first looks up all the DeviceIndex ops.
+  // Then, for each of these ops, it finds the device of the
+  // associated Case op that takes the DeviceIndex op as its input, and
+  // calculates the index of that device in the device list of the
+  // DeviceIndex op. Lastly, it rewrites the DeviceIndex op with a Const op
+  // and sets the value to be the index.
+  //
+  // Example input nodes:
+  // node {
+  //   name: "x"
+  //   op: "DeviceIndex"
+  //   device: "/device:CPU:0"
+  //   attr {
+  //     key: "device_names"
+  //     value {
+  //       list {
+  //         s: "CPU"
+  //         s: "TPU_REPLICATED_CORE"
+  //         s: "GPU"
+  //       }
+  //     }
+  //   }
+  // }
+  // node {
+  //   name: "case"
+  //   op: "Case"
+  //   input: "x"
+  //   device: "/device:GPU:0"
+  //   ...
+  // }
+  // Example output nodes:
+  // node {
+  //   name: "x"
+  //   op: "Const"
+  //   device: "/device:CPU:0"
+  //   attr {
+  //     key: "dtype"
+  //     value {
+  //       type: DT_INT32
+  //     }
+  //   }
+  //   attr {
+  //     key: "value"
+  //     value {
+  //       tensor {
+  //         dtype: DT_INT32
+  //         int_val: 2
+  //       }
+  //     }
+  //   }
+  // }
+  // node {
+  //   name: "case"
+  //   op: "Case"
+  //   input: "x"
+  //   device: "/device:GPU:0"
+  //   ...
+  // }
+  Status SelectDeviceIndex(GraphDef* graph) const;
+
   std::unique_ptr lib_info_;

   TF_DISALLOW_COPY_AND_ASSIGN(ImplementationSelector);
diff --git a/tensorflow/core/grappler/optimizers/implementation_selector_test.cc b/tensorflow/core/grappler/optimizers/implementation_selector_test.cc
index 914570fcadb..2ef8bb878cc 100644
--- a/tensorflow/core/grappler/optimizers/implementation_selector_test.cc
+++ b/tensorflow/core/grappler/optimizers/implementation_selector_test.cc
@@ -19,6 +19,7 @@ limitations under the License.

 #include
 #include

+#include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
@@ -58,6 +59,167 @@ TEST_F(ImplementationSelectorTest, NoUpdate) {
   EXPECT_EQ(item.graph.node_size(), output.node_size());
 }

+TEST_F(ImplementationSelectorTest, SelectDeviceIndex) {
+  using test::function::NDef;
+  ImplementationSelector optimizer;
+  GraphDef output;
+  GrapplerItem item;
+  AttrValue device_names;
+  device_names.mutable_list()->add_s("CPU");
+  device_names.mutable_list()->add_s("GPU");
+  item.graph = test::function::GDef(
+      {NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
+            CpuDevice),
+       NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice)});
+
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "x") {
+      // Rewrite DeviceIndex op to a Const op with value of GPU index 1.
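+      // (The device list is {"CPU", "GPU"} and the Case op is placed on
+      // GpuDevice, so the expected constant is 1.)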
+ EXPECT_EQ("Const", node.op()); + EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0)); + } + } +} + +TEST_F(ImplementationSelectorTest, SelectDeviceIndexMultiOps) { + using test::function::NDef; + ImplementationSelector optimizer; + GraphDef output; + GrapplerItem item; + AttrValue device_names; + device_names.mutable_list()->add_s("CPU"); + device_names.mutable_list()->add_s("TPU_REPLICATED_CORE"); + device_names.mutable_list()->add_s("GPU"); + item.graph = test::function::GDef( + {NDef("x", "DeviceIndex", {}, {{"device_names", device_names}}, + CpuDevice), + NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice), + NDef("y", "DeviceIndex", {}, {{"device_names", device_names}}, + GpuDevice), + NDef("case_y", "Case", {"y"}, {{"T", DT_FLOAT}}, TpuDevice)}); + + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + for (const NodeDef& node : output.node()) { + if (node.name() == "x") { + // Rewrite DeviceIndex op to a Const op with value of GPU index 1. + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0)); + } + if (node.name() == "y") { + // Rewrite DeviceIndex op to a Const op with value of CPU index 0. + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0)); + } + } +} + +TEST_F(ImplementationSelectorTest, SelectDeviceIndexNotFound) { + using test::function::NDef; + ImplementationSelector optimizer; + GraphDef output; + GrapplerItem item; + AttrValue device_names; + device_names.mutable_list()->add_s("CPU"); + device_names.mutable_list()->add_s("GPU"); + item.graph = test::function::GDef( + {NDef("x", "DeviceIndex", {}, {{"device_names", device_names}}, + CpuDevice), + NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, TpuDevice)}); + + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + for (const NodeDef& node : output.node()) { + if (node.name() == "x") { + // Rewrite DeviceIndex op to a Const op with value of device names length. + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0)); + } + } +} + +TEST_F(ImplementationSelectorTest, SelectDeviceIndexError) { + using test::function::NDef; + ImplementationSelector optimizer; + GraphDef output; + GrapplerItem item; + AttrValue device_names; + device_names.mutable_list()->add_s("CPU"); + device_names.mutable_list()->add_s("GPU"); + item.graph = test::function::GDef( + {NDef("x", "DeviceIndex", {}, {{"device_names", device_names}}, + CpuDevice), + NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, "")}); + + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + for (const NodeDef& node : output.node()) { + if (node.name() == "x") { + // Device parse has error, do not rewrite the DeviceIndexNode. + EXPECT_EQ("DeviceIndex", node.op()); + } + } +} + +TEST_F(ImplementationSelectorTest, TwoTypesOfSwapImplementation) { + using test::function::NDef; + ImplementationSelector optimizer; + GraphDef output; + GrapplerItem item; + // DeviceIndex op based implementation selector. + AttrValue device_names; + device_names.mutable_list()->add_s("CPU"); + device_names.mutable_list()->add_s("TPU_REPLICATED_CORE"); + device_names.mutable_list()->add_s("GPU"); + + // Function swap based implementation selector. 
+  auto cpu_def = test::function::XTimesTwo();
+  auto* func_attr = cpu_def.mutable_attr();
+  (*func_attr)["api_implements"].set_s("times_two");
+  (*func_attr)["api_preferred_device"].set_s("CPU");
+
+  auto gpu_def = test::function::XAddX();
+  auto* func2_attr = gpu_def.mutable_attr();
+  (*func2_attr)["api_implements"].set_s("times_two");
+  (*func2_attr)["api_preferred_device"].set_s("GPU");
+
+  item.graph = test::function::GDef(
+      {NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
+            CpuDevice),
+       NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
+       NDef("y", "DeviceIndex", {}, {{"device_names", device_names}},
+            GpuDevice),
+       NDef("case_y", "Case", {"y"}, {{"T", DT_FLOAT}}, TpuDevice),
+       NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
+       NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice),
+       NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
+       NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)},
+      // FunctionLib
+      {cpu_def, gpu_def});
+
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "x") {
+      // Rewrite DeviceIndex op to a Const op with value of GPU index 2.
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
+    }
+    if (node.name() == "y") {
+      // Rewrite DeviceIndex op to a Const op with value of the
+      // TPU_REPLICATED_CORE index 1.
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
+    }
+    if (node.name() == "y1") {
+      // Make sure the implementation has been swapped to use the GPU version.
+      EXPECT_EQ("XAddX", node.op());
+    } else if (node.name() == "y2") {
+      // Make sure the implementation is not changed.
+      EXPECT_EQ("XTimesTwo", node.op());
+    }
+  }
+}
+
 TEST_F(ImplementationSelectorTest, SwapImplementation) {
   using test::function::NDef;
   auto cpu_def = test::function::XTimesTwo();
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc
index 7f4d1144cb2..96c0a3d6bdc 100644
--- a/tensorflow/core/kernels/functional_ops.cc
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -924,5 +924,37 @@ class FakeParamOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp);
 REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_GPU), FakeParamOp);

+// DeviceIndexOp returns the current device's index in the given device list.
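+// The kernel parses the name of the device it is placed on (for example
+// "/job:localhost/replica:0/task:0/device:GPU:0"), extracts the device type,
+// and outputs that type's position in the "device_names" attr, or
+// device_names.size() if the type is not listed.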
+class DeviceIndexOp : public OpKernel {
+ public:
+  explicit DeviceIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("device_names", &device_names_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    Tensor* device_name_t;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output(0, TensorShape({}), &device_name_t));
+    DeviceNameUtils::ParsedName parsed_name;
+    int index = device_names_.size();
+    if (DeviceNameUtils::ParseFullName(ctx->device()->name(), &parsed_name) &&
+        parsed_name.has_type) {
+      auto it = absl::c_find(device_names_, parsed_name.type);
+      if (it != device_names_.end()) {
+        index = it - device_names_.begin();
+      }
+    }
+    device_name_t->scalar<int32>()() = index;
+  }
+
+ private:
+  std::vector<string> device_names_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("DeviceIndex").Device(DEVICE_CPU), DeviceIndexOp);
+REGISTER_KERNEL_BUILDER(
+    Name("DeviceIndex").Device(DEVICE_GPU).HostMemory("index"), DeviceIndexOp);
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index 0a08925f7e1..11b10f3c504 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -299,4 +299,10 @@ REGISTER_OP("FakeParam")
       return Status::OK();
     });

+// Returns the device index.
+REGISTER_OP("DeviceIndex")
+    .Output("index: int32")
+    .Attr("device_names: list(string)")
+    .SetShapeFn(shape_inference::ScalarShape);
+
 }  // end namespace tensorflow
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index dababc7615e..479d1122742 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -942,7 +942,10 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
     return captured_tensor


-def indexed_case(branch_index, branch_fns, name="indexed_case"):
+def indexed_case(branch_index,
+                 branch_fns,
+                 name="indexed_case",
+                 lower_using_switch_merge=None):
   """Like cond_v2, except emits a Case op instead of an If."""
   if isinstance(branch_index, int):
     raise TypeError("branch_index must not be a Python int", branch_index)
@@ -976,7 +979,8 @@ def indexed_case(branch_index, branch_fns, name="indexed_case"):
   return _build_case(
       branch_index, branch_graphs,
       [g.external_captures for g in branch_graphs],
-      name=scope)
+      name=scope,
+      lower_using_switch_merge=lower_using_switch_merge)


 @ops.RegisterGradient("Case")
@@ -1064,7 +1068,11 @@ def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
   return [None] + outputs


-def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
+def _build_case(branch_index,
+                branch_graphs,
+                branch_inputs,
+                name=None,
+                lower_using_switch_merge=None):
   """Creates a `Case` op from `branch_index`, branch graphs and inputs.

   Note that this modifies `branch_graphs` to make the inputs match, and to
@@ -1080,6 +1088,7 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
     branch_inputs: List of lists of Tensors to be passed to corresponding
       branch_graph as input.
     name: the name for the Case op.
+    lower_using_switch_merge: Lower this op using switch-merge ops (optional).

   Returns:
     A list of Tensors which are the outputs of the Case op.
Does not include
@@ -1105,7 +1114,7 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
   case_op, tensors = _get_op_and_outputs(tensors)

   if case_op is not None:
-    util.maybe_set_lowering_attr(case_op)
+    util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
     util.maybe_propagate_compile_time_consts_in_xla(case_op)
     _set_read_only_resource_inputs_attr(case_op, branch_graphs)
     # Prevent fetching since the variant outputs can't be fetched directly.
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 918c989432d..3398308d42e 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -43,6 +43,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_util as util
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import gen_control_flow_ops
+from tensorflow.python.ops import gen_functional_ops
 from tensorflow.python.ops import gen_logging_ops
 from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import math_ops
@@ -3283,7 +3284,11 @@ def _indexed_case_verify_and_canonicalize_args(branch_fns, default,
   return actions


-def _indexed_case_helper(branch_fns, default, branch_index, name):
+def _indexed_case_helper(branch_fns,
+                         default,
+                         branch_index,
+                         name,
+                         lower_using_switch_merge=None):
   """Implementation of case that emits the n-way indexed Case op.

   Args:
@@ -3293,6 +3298,7 @@ def _indexed_case_helper(branch_fns, default, branch_index, name):
     branch_index: Optional int `Tensor`, which selects for the corresponding
       pred_fn_pair.
     name: A name for this operation (optional).
+    lower_using_switch_merge: Lower this op using switch-merge ops (optional).

   Returns:
     The tensors returned by the pair whose key matched branch_index, or
@@ -3314,7 +3320,10 @@ def _indexed_case_helper(branch_fns, default, branch_index, name):
         | math_ops.greater_equal(branch_index, len(branch_fns)),
         len(branch_fns) - 1, branch_index)
     return branch_fns[int(branch_index)]()
-  return cond_v2.indexed_case(branch_index, branch_fns)
+  return cond_v2.indexed_case(
+      branch_index,
+      branch_fns,
+      lower_using_switch_merge=lower_using_switch_merge)


 @tf_export("case", v1=[])
@@ -3607,6 +3616,50 @@ def switch_case(branch_index,
   return _indexed_case_helper(branch_fns, default, branch_index, name)


+def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"):
+  """Executes one of the provided callables based on the device placement.
+
+  This API is used when the implementation of a high-level function depends on
+  the underlying device placement. It takes a dictionary mapping device types
+  to callables. The device types include "CPU", "GPU", "TPU", etc. When the
+  type of the device on which this op runs matches a key in
+  'device_branch_fns', the corresponding callable is executed; 'default_fn' is
+  executed if no key matches.
+
+  **Example:**
+  ```python
+  def f1(): return tf.constant(1)
+  def f2(): return tf.constant(2)
+  r = tf.execute_fn_for_device({"CPU": f1, "GPU": f2}, default_fn=f1)
+  ```
+  'r' evaluates to 1 when running on a CPU, 2 when running on a GPU, and 1
+  when running on any other device type.
+
+  Args:
+    device_branch_fns: a dictionary mapping device types to callables. Each
+      callable must return a matching structure of tensors.
+    default_fn: fallback callable invoked when the underlying device does not
+      match any key in 'device_branch_fns'.
+    name: A name for this operation (optional).
+
+  Returns:
+    The tensors returned by the callable identified by the device type during
+    execution, or those returned by 'default_fn' if no key matches.
+  """
+
+  device_branch_fns_upper = {k.upper(): v for k, v in device_branch_fns.items()}
+  branch_fns = list(device_branch_fns_upper.values())
+  devices = list(device_branch_fns_upper.keys())
+  device_index = gen_functional_ops.device_index(device_names=devices)
+  return _indexed_case_helper(
+      branch_fns,
+      default_fn,
+      device_index,
+      name,
+      lower_using_switch_merge=False)
+
+
 class XLAControlFlowContext(ControlFlowContext):
   """Base class for XLA and TPU control flow contexts."""

diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 2979eb79bfd..f4459d8e34a 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -1211,6 +1211,92 @@ class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       control_flow_ops.switch_case(array_ops.constant(1), branches)


+class ExecuteFnForDeviceTest(test_util.TensorFlowTestCase):
+
+  def testCommonCases(self):
+
+    def cpu_fn(x):
+      return x + x
+
+    def gpu_fn(x):
+      return x * x
+
+    def flexible_fn(a):
+      branches = {"CPU": lambda: cpu_fn(a), "GPU": lambda: gpu_fn(a)}
+      return control_flow_ops.execute_fn_for_device(
+          branches, lambda: cpu_fn(a))
+
+    @def_function.function
+    def flexible_defun(a):
+      return flexible_fn(a)
+
+    def run_defun_and_tape(a):
+      with backprop.GradientTape() as tape:
+        tape.watch(a)
+        result = flexible_defun(a)
+      grad = tape.gradient(result, a)
+      r = flexible_fn(a)
+      return r, result, grad
+
+    a = array_ops.constant(3.)
+    with ops.device("cpu:0"):
+      r, result, grad = run_defun_and_tape(a)
+      self.assertEqual(6., self.evaluate(r))
+      self.assertEqual(6., self.evaluate(result))
+      self.assertEqual(2., self.evaluate(grad))
+
+    if test_util.is_gpu_available():
+      with ops.device("gpu:0"):
+        r, result, grad = run_defun_and_tape(a)
+        self.assertEqual(9., self.evaluate(r))
+        self.assertEqual(9., self.evaluate(result))
+        self.assertEqual(6., self.evaluate(grad))
+
+    # No device annotation.
+    r, result, grad = run_defun_and_tape(a)
+    if test_util.is_gpu_available():
+      self.assertEqual(9., self.evaluate(r))
+      self.assertEqual(9., self.evaluate(result))
+      self.assertEqual(6., self.evaluate(grad))
+    else:
+      self.assertEqual(6., self.evaluate(r))
+      self.assertEqual(6., self.evaluate(result))
+      self.assertEqual(2., self.evaluate(grad))
+
+  def testFallBack(self):
+
+    def default_fn(x):
+      return x
+
+    def tpu_fn(x):
+      return x * x * x
+
+    def flexible_fn(a):
+      branches = {"TPU": lambda: tpu_fn(a)}
+      return control_flow_ops.execute_fn_for_device(
+          branches, default_fn=lambda: default_fn(a))
+
+    @def_function.function
+    def flexible_defun(a):
+      return flexible_fn(a)
+
+    a = array_ops.constant(3.)
+    with ops.device("cpu:0"):
+      result_defun = flexible_defun(a)
+      self.assertEqual(3., self.evaluate(result_defun))
+      # execute_fn_for_device is also run outside a tf.function.
+      result = flexible_fn(a)
+      self.assertEqual(3., self.evaluate(result))
+
+    if test_util.is_gpu_available():
+      with ops.device("gpu:0"):
+        result_defun = flexible_defun(a)
+        self.assertEqual(3., self.evaluate(result_defun))
+        # execute_fn_for_device is also run outside a tf.function.
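+        # Only a TPU branch is registered, so nothing matches on GPU either
+        # and default_fn is selected.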
+        result = flexible_fn(a)
+        self.assertEqual(3., self.evaluate(result))
+
+
 class CaseTest(test_util.TensorFlowTestCase):

   @test_util.run_deprecated_v1
diff --git a/tensorflow/python/ops/control_flow_util_v2.py b/tensorflow/python/ops/control_flow_util_v2.py
index 7e87d25fe99..4fc464c545c 100644
--- a/tensorflow/python/ops/control_flow_util_v2.py
+++ b/tensorflow/python/ops/control_flow_util_v2.py
@@ -92,7 +92,7 @@ def unique_grad_fn_name(forward_name):
   return "%s_grad_%s" % (forward_name, ops.uid())


-def maybe_set_lowering_attr(op):
+def maybe_set_lowering_attr(op, lower_using_switch_merge=None):
   """Sets the flag to enable lowering on `op` if necessary.

   Lowering allows cond_v2 and while_v2 to avoid some of the limitations of
@@ -108,14 +108,21 @@ def maybe_set_lowering_attr(op):
   - When the eager execution context specifies the executor of functions to be
     the single threaded executor (see context.function_executor_type()).
     Because the single threaded executor does not support v1 control flow ops.
+  - When 'lower_using_switch_merge' is explicitly set to False.

   Args:
     op: An `If`, `Case` or `While` Operation.
+    lower_using_switch_merge: Explicitly force lowering on or off (optional).
   """
-  if (not _DISABLE_LOWER_USING_SWITCH_MERGE and
-      not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
-      context.context().function_call_options.executor_type !=
-      "SINGLE_THREADED_EXECUTOR"):
+  if lower_using_switch_merge is not None:
+    # pylint: disable=protected-access
+    op._set_attr("_lower_using_switch_merge",
+                 attr_value_pb2.AttrValue(b=lower_using_switch_merge))
+    # pylint: enable=protected-access
+  elif (not _DISABLE_LOWER_USING_SWITCH_MERGE and
+        not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
+        context.context().function_call_options.executor_type !=
+        "SINGLE_THREADED_EXECUTOR"):
     # pylint: disable=protected-access
     op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
     # pylint: enable=protected-access
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index a8efb9e59b5..25ae132c775 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1140,6 +1140,10 @@ tf_module {
     name: "DestroyTemporaryVariable"
     argspec: "args=[\'ref\', \'var_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "DeviceIndex"
+    argspec: "args=[\'device_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "Diag"
     argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index a8efb9e59b5..25ae132c775 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1140,6 +1140,10 @@ tf_module {
     name: "DestroyTemporaryVariable"
     argspec: "args=[\'ref\', \'var_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "DeviceIndex"
+    argspec: "args=[\'device_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "Diag"
    argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
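A minimal end-to-end sketch of the new API, mirroring the tests above (illustrative only: `execute_fn_for_device` is added by this patch as a private symbol in `tensorflow.python.ops.control_flow_ops` and is not exported as `tf.execute_fn_for_device`):

```python
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops


def cpu_impl(x):
  return x + x  # reference implementation


def gpu_impl(x):
  return x * x  # device-specialized implementation


@tf.function
def f(x):
  # Grappler rewrites the underlying DeviceIndex op into a constant based on
  # where the generated Case op is actually placed, so only one branch runs.
  return control_flow_ops.execute_fn_for_device(
      {"CPU": lambda: cpu_impl(x), "GPU": lambda: gpu_impl(x)},
      default_fn=lambda: cpu_impl(x))


print(f(tf.constant(3.0)))  # 6.0 on a CPU-only host, 9.0 when placed on a GPU
```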