Add a function to dynamic choose and execute the proper implementation based on underlying device placement

PiperOrigin-RevId: 313605766
Change-Id: I877b684dcef782b375df0504c0250acd9e808ce9
This commit is contained in:
Yanhua Sun 2020-05-28 09:51:29 -07:00 committed by TensorFlower Gardener
parent 4594b8cfb5
commit 2de551ba87
13 changed files with 550 additions and 13 deletions

View File

@ -0,0 +1,5 @@
op {
graph_op_name: "DeviceIndex"
visibility: HIDDEN
summary: "Return the index of device the op runs."
}

View File

@ -1062,6 +1062,7 @@ cc_library(
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils", "//tensorflow/core/grappler:utils",

View File

@ -17,9 +17,12 @@ limitations under the License.
#include <string> #include <string>
#include "absl/strings/match.h"
#include "absl/strings/numbers.h" #include "absl/strings/numbers.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/op_types.h"
@ -36,6 +39,11 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace grappler { namespace grappler {
constexpr char kConstOp[] = "Const";
constexpr char kCaseOp[] = "Case";
constexpr char kDeviceIndexOp[] = "DeviceIndex";
// TODO(b/157615690): clean up function implementation swap code.
// The overall idea for the function swap is like below: // The overall idea for the function swap is like below:
// ----------- ----------- // ----------- -----------
// inp_1 ->| P_C | -> out_1 g_inp_1 ->| P_C | -> g_out_1 // inp_1 ->| P_C | -> out_1 g_inp_1 ->| P_C | -> g_out_1
@ -292,6 +300,74 @@ Status ImplementationSelector::MaybeOptimizeFunctionCall(
return Status::OK(); return Status::OK();
} }
// Finds the index of the device from the device name list.
Status FindDeviceIndex(const utils::MutableNodeView* device_index_node,
const string& device, int* index) {
DeviceNameUtils::ParsedName parsed_name;
if (!DeviceNameUtils::ParseFullName(device, &parsed_name) ||
!parsed_name.has_type) {
return errors::Internal("Could not parse device name:", device);
}
const auto& device_list =
device_index_node->GetAttr("device_names")->list().s();
auto it = absl::c_find(device_list, parsed_name.type);
if (it != device_list.end()) {
*index = it - device_list.begin();
} else {
// Sets *index to device_list.size() because the default_fn is guaranteed to
// be the final item in the case op branching list.
*index = device_list.size();
}
return Status::OK();
}
// Rewrites the device_index op to a const op with value of the index.
void RewriteDeviceIndexOp(utils::MutableNodeView* device_index_node,
int index) {
// Modifies the DeviceIndex node to be an Const op with correct device index.
auto node = device_index_node->node();
node->set_op(kConstOp);
node->clear_attr();
(*node->mutable_attr())["dtype"].set_type(DT_INT32);
auto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
tensor->set_dtype(DT_INT32);
tensor->add_int_val(index);
VLOG(2) << "Node after rewriting:" << node->DebugString();
}
Status ImplementationSelector::SelectDeviceIndex(GraphDef* graph) const {
Status status;
VLOG(2) << "graph before rewriting device index:" << graph->DebugString();
utils::MutableGraphView graph_view(graph, &status);
TF_RETURN_IF_ERROR(status);
const int num_nodes = graph_view.NumNodes();
for (int k = 0; k < num_nodes; ++k) {
auto* node_view = graph_view.GetNode(k);
if (node_view->GetOp() != kDeviceIndexOp) {
continue;
}
VLOG(2) << "Found a node to rewrite the device index";
// Find the case node with device index node as input, rewrite the
// DeviceIndex node to have the value of the index of device type of the
// case node.
for (const auto& fanouts : node_view->GetRegularFanouts()) {
for (const auto& fanout : fanouts) {
if (fanout.node_view()->GetOp() != kCaseOp) continue;
int index;
// If any error is thrown out during device parsing, we simply skip
// and do not modify the DeviceIndexNode.
Status status =
FindDeviceIndex(node_view, fanout.node_view()->GetDevice(), &index);
if (status.ok()) {
RewriteDeviceIndexOp(node_view, index);
}
}
}
}
return Status::OK();
}
Status ImplementationSelector::SelectImplementation(GraphDef* graph) const { Status ImplementationSelector::SelectImplementation(GraphDef* graph) const {
if (!graph->has_library()) { if (!graph->has_library()) {
VLOG(2) << "Skipping graph since it does not have function def"; VLOG(2) << "Skipping graph since it does not have function def";
@ -307,8 +383,9 @@ Status ImplementationSelector::SelectImplementation(GraphDef* graph) const {
TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(status);
const int num_nodes = graph_view.NumNodes(); const int num_nodes = graph_view.NumNodes();
for (int k = 0; k < num_nodes; ++k) for (int k = 0; k < num_nodes; ++k) {
TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph_view.GetNode(k))); TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph_view.GetNode(k)));
}
return Status::OK(); return Status::OK();
} }
@ -326,7 +403,13 @@ Status ImplementationSelector::Optimize(Cluster* cluster,
<< "libraries: " << status; << "libraries: " << status;
return errors::Aborted("Skipped Optimization"); return errors::Aborted("Skipped Optimization");
} }
*optimized_graph = item.graph; *optimized_graph = item.graph;
status = SelectDeviceIndex(optimized_graph);
if (!status.ok()) {
*optimized_graph = item.graph;
VLOG(2) << "Could not rewrite device index due to error:" << status;
}
return SelectImplementation(optimized_graph); return SelectImplementation(optimized_graph);
} }

View File

@ -34,6 +34,28 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace grappler { namespace grappler {
// Motivation: To achieve the same high level functionality, the underlying
// implementations sometimes are different for various devices where the
// function runs. In order to achieve the correct result and best performance,
// the proper implementation needs to be picked dynamically.
//
// Currently there are two approaches to do this.
// (1) Utilize case op and dynamacically change the branch index.
// (2) Swap function implementation, it will be deprecated.
//
// Idea for approach 1.
// This transformation rewrites the DeviceIndex op with a Const op with value
// of the index of the device the associcated Case op runs.
// Example:
// def plus_one_gpu(x): return x + 1.0
// def plus_one_reference_implementation(x): return x + 1.0
// input = tf.constant(2.0, dtype=tf.float32)
// cpu_fn = lambda:plus_one_reference_implementation(input)
// gpu_fn = lambda:plus_one_gpu(input)
// control_flow_ops.execute_fn_for_device(
// {"CPU": cpu_fn, "GPU":gpu_fn)}, default_fn=cpu_fn)
//
// Idea for approach 2.
// This transformation replaces function calls by the appropriate function // This transformation replaces function calls by the appropriate function
// definition based on properties of the runtime system. For instance, // definition based on properties of the runtime system. For instance,
// we may choose one implementation over another if we have a GPU with // we may choose one implementation over another if we have a GPU with
@ -58,7 +80,8 @@ namespace grappler {
// z = plus_one_gpu(input) // z = plus_one_gpu(input)
// print(sess.run(z)) // print(sess.run(z))
// //
// At runtime, we will trim either `plus_one_gpu` or
// At runtime, we will select either `plus_one_gpu` or
// `plus_one_reference_implementation` based on the availability of the GPU. // `plus_one_reference_implementation` based on the availability of the GPU.
// //
// Available annotations: // Available annotations:
@ -106,6 +129,68 @@ class ImplementationSelector : public CustomGraphOptimizer {
// gradients. // gradients.
Status SelectImplementation(GraphDef* graph) const; Status SelectImplementation(GraphDef* graph) const;
// Rewrites the DeviceIndex op with a Const op with value of the index of the
// device the associcated Case op runs.
// This function first looks up all the DeviceIndex ops.
// Then for each of these ops, it finds the device of the
// associated Case op that takes the DeviceIndex op as the input, and
// caculates the index of the device in the device list of DeviceIndex op.
// Lastly, it rewrites the DeviceIndex op with a Const op and sets the value
// to be the index.
//
// Example input nodes:
// node {
// name: "x"
// op: "DeviceIndex"
// device: "/device:CPU:0"
// attr {
// key: "device_names"
// value {
// list {
// s: "CPU"
// s: "TPU_REPLICATED_CORE"
// s: "GPU"
// }
// }
// }
// }
// node {
// name: "case"
// op: "Case"
// input: "x"
// device: "/device:GPU:0"
// ...
// }
// Example output nodes:
//
// name: "x"
// op: "Const"
// device: "/device:CPU:0"
// attr {
// key: "dtype"
// value {
// type: DT_INT32
// }
// }
// attr {
// key: "value"
// value {
// tensor {
// dtype: DT_INT32
// int_val: 2
// }
// }
// }
// node {
// name: "case"
// op: "Case"
// input: "x"
// device: "/device:GPU:0"
// ...
// }
Status SelectDeviceIndex(GraphDef* graph) const;
std::unique_ptr<FunctionLibraryApiInfo> lib_info_; std::unique_ptr<FunctionLibraryApiInfo> lib_info_;
TF_DISALLOW_COPY_AND_ASSIGN(ImplementationSelector); TF_DISALLOW_COPY_AND_ASSIGN(ImplementationSelector);

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <string> #include <string>
#include <vector> #include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/tensor_testutil.h"
@ -58,6 +59,167 @@ TEST_F(ImplementationSelectorTest, NoUpdate) {
EXPECT_EQ(item.graph.node_size(), output.node_size()); EXPECT_EQ(item.graph.node_size(), output.node_size());
} }
TEST_F(ImplementationSelectorTest, SelectDeviceIndex) {
using test::function::NDef;
ImplementationSelector optimizer;
GraphDef output;
GrapplerItem item;
AttrValue device_names;
device_names.mutable_list()->add_s("CPU");
device_names.mutable_list()->add_s("GPU");
item.graph = test::function::GDef(
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
CpuDevice),
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice)});
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
for (const NodeDef& node : output.node()) {
if (node.name() == "x") {
// Rewrite DeviceIndex op to a Const op with value of GPU index 1.
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
}
}
}
TEST_F(ImplementationSelectorTest, SelectDeviceIndexMultiOps) {
using test::function::NDef;
ImplementationSelector optimizer;
GraphDef output;
GrapplerItem item;
AttrValue device_names;
device_names.mutable_list()->add_s("CPU");
device_names.mutable_list()->add_s("TPU_REPLICATED_CORE");
device_names.mutable_list()->add_s("GPU");
item.graph = test::function::GDef(
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
CpuDevice),
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
NDef("y", "DeviceIndex", {}, {{"device_names", device_names}},
GpuDevice),
NDef("case_y", "Case", {"y"}, {{"T", DT_FLOAT}}, TpuDevice)});
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
for (const NodeDef& node : output.node()) {
if (node.name() == "x") {
// Rewrite DeviceIndex op to a Const op with value of GPU index 1.
EXPECT_EQ("Const", node.op());
EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
}
if (node.name() == "y") {
// Rewrite DeviceIndex op to a Const op with value of CPU index 0.
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
}
}
}
TEST_F(ImplementationSelectorTest, SelectDeviceIndexNotFound) {
using test::function::NDef;
ImplementationSelector optimizer;
GraphDef output;
GrapplerItem item;
AttrValue device_names;
device_names.mutable_list()->add_s("CPU");
device_names.mutable_list()->add_s("GPU");
item.graph = test::function::GDef(
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
CpuDevice),
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, TpuDevice)});
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
for (const NodeDef& node : output.node()) {
if (node.name() == "x") {
// Rewrite DeviceIndex op to a Const op with value of device names length.
EXPECT_EQ("Const", node.op());
EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
}
}
}
TEST_F(ImplementationSelectorTest, SelectDeviceIndexError) {
using test::function::NDef;
ImplementationSelector optimizer;
GraphDef output;
GrapplerItem item;
AttrValue device_names;
device_names.mutable_list()->add_s("CPU");
device_names.mutable_list()->add_s("GPU");
item.graph = test::function::GDef(
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
CpuDevice),
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, "")});
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
for (const NodeDef& node : output.node()) {
if (node.name() == "x") {
// Device parse has error, do not rewrite the DeviceIndexNode.
EXPECT_EQ("DeviceIndex", node.op());
}
}
}
TEST_F(ImplementationSelectorTest, TwoTypesOfSwapImplementation) {
using test::function::NDef;
ImplementationSelector optimizer;
GraphDef output;
GrapplerItem item;
// DeviceIndex op based implementation selector.
AttrValue device_names;
device_names.mutable_list()->add_s("CPU");
device_names.mutable_list()->add_s("TPU_REPLICATED_CORE");
device_names.mutable_list()->add_s("GPU");
// Function swap based implementation selector.
auto cpu_def = test::function::XTimesTwo();
auto* func_attr = cpu_def.mutable_attr();
(*func_attr)["api_implements"].set_s("times_two");
(*func_attr)["api_preferred_device"].set_s("CPU");
auto gpu_def = test::function::XAddX();
auto* func2_attr = gpu_def.mutable_attr();
(*func2_attr)["api_implements"].set_s("times_two");
(*func2_attr)["api_preferred_device"].set_s("GPU");
item.graph = test::function::GDef(
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
CpuDevice),
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
NDef("y", "DeviceIndex", {}, {{"device_names", device_names}},
GpuDevice),
NDef("case_y", "Case", {"y"}, {{"T", DT_FLOAT}}, TpuDevice),
NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice),
NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)},
// FunctionLib
{cpu_def, gpu_def});
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
for (const NodeDef& node : output.node()) {
if (node.name() == "x") {
// Rewrite DeviceIndex op to a Const op with value of GPU index 1.
EXPECT_EQ("Const", node.op());
EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
}
if (node.name() == "y") {
// Rewrite DeviceIndex op to a Const op with value of CPU index 0.
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
}
if (node.name() == "y1") {
// Make sure the implementation has been swapped to use the GPU version.
EXPECT_EQ("XAddX", node.op());
} else if (node.name() == "y2") {
// Make sure the implementation is not changed.
EXPECT_EQ("XTimesTwo", node.op());
}
}
}
TEST_F(ImplementationSelectorTest, SwapImplementation) { TEST_F(ImplementationSelectorTest, SwapImplementation) {
using test::function::NDef; using test::function::NDef;
auto cpu_def = test::function::XTimesTwo(); auto cpu_def = test::function::XTimesTwo();

View File

@ -924,5 +924,37 @@ class FakeParamOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp); REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp);
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_GPU), FakeParamOp); REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_GPU), FakeParamOp);
// DeviceIndexOP returns the current device index.
class DeviceIndexOp : public OpKernel {
public:
explicit DeviceIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("device_names", &device_names_));
}
void Compute(OpKernelContext* ctx) override {
Tensor* device_name_t;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(0, TensorShape({}), &device_name_t));
DeviceNameUtils::ParsedName parsed_name;
int index = device_names_.size();
if (DeviceNameUtils::ParseFullName(ctx->device()->name(), &parsed_name) &&
parsed_name.has_type) {
auto it = absl::c_find(device_names_, parsed_name.type);
if (it != device_names_.end()) {
index = it - device_names_.begin();
}
}
device_name_t->scalar<int32>()() = index;
}
private:
PersistentTensor value_handle_;
std::vector<string> device_names_;
};
REGISTER_KERNEL_BUILDER(Name("DeviceIndex").Device(DEVICE_CPU), DeviceIndexOp);
REGISTER_KERNEL_BUILDER(
Name("DeviceIndex").Device(DEVICE_GPU).HostMemory("index"), DeviceIndexOp);
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -299,4 +299,10 @@ REGISTER_OP("FakeParam")
return Status::OK(); return Status::OK();
}); });
// Returns the device index.
REGISTER_OP("DeviceIndex")
.Output("index: int32")
.Attr("device_names: list(string)")
.SetShapeFn(shape_inference::ScalarShape);
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -942,7 +942,10 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
return captured_tensor return captured_tensor
def indexed_case(branch_index, branch_fns, name="indexed_case"): def indexed_case(branch_index,
branch_fns,
name="indexed_case",
lower_using_switch_merge=None):
"""Like conv_v2, except emits a Case op instead of an If.""" """Like conv_v2, except emits a Case op instead of an If."""
if isinstance(branch_index, int): if isinstance(branch_index, int):
raise TypeError("branch_index must not be a Python int", branch_index) raise TypeError("branch_index must not be a Python int", branch_index)
@ -976,7 +979,8 @@ def indexed_case(branch_index, branch_fns, name="indexed_case"):
return _build_case( return _build_case(
branch_index, branch_index,
branch_graphs, [g.external_captures for g in branch_graphs], branch_graphs, [g.external_captures for g in branch_graphs],
name=scope) name=scope,
lower_using_switch_merge=lower_using_switch_merge)
@ops.RegisterGradient("Case") @ops.RegisterGradient("Case")
@ -1064,7 +1068,11 @@ def _CaseGrad(op, *grads): # pylint: disable=invalid-name
return [None] + outputs return [None] + outputs
def _build_case(branch_index, branch_graphs, branch_inputs, name=None): def _build_case(branch_index,
branch_graphs,
branch_inputs,
name=None,
lower_using_switch_merge=None):
"""Creates an `Case` op from `branch_index`, branch graphs and inputs. """Creates an `Case` op from `branch_index`, branch graphs and inputs.
Note that this modifies `branch_graphs` to make the inputs match, and to Note that this modifies `branch_graphs` to make the inputs match, and to
@ -1080,6 +1088,7 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
branch_inputs: List of lists of Tensors to be passed to corresponding branch_inputs: List of lists of Tensors to be passed to corresponding
branch_graph as input. branch_graph as input.
name: the name for the Case op. name: the name for the Case op.
lower_using_switch_merge: Lower this op using switch merge ops (optional).
Returns: Returns:
A list of Tensors which are the outputs of the Case op. Does not include A list of Tensors which are the outputs of the Case op. Does not include
@ -1105,7 +1114,7 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
case_op, tensors = _get_op_and_outputs(tensors) case_op, tensors = _get_op_and_outputs(tensors)
if case_op is not None: if case_op is not None:
util.maybe_set_lowering_attr(case_op) util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
util.maybe_propagate_compile_time_consts_in_xla(case_op) util.maybe_propagate_compile_time_consts_in_xla(case_op)
_set_read_only_resource_inputs_attr(case_op, branch_graphs) _set_read_only_resource_inputs_attr(case_op, branch_graphs)
# Prevent fetching since the variant outputs can't be fetched directly. # Prevent fetching since the variant outputs can't be fetched directly.

View File

@ -43,6 +43,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util as util from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_logging_ops from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -3283,7 +3284,11 @@ def _indexed_case_verify_and_canonicalize_args(branch_fns, default,
return actions return actions
def _indexed_case_helper(branch_fns, default, branch_index, name): def _indexed_case_helper(branch_fns,
default,
branch_index,
name,
lower_using_switch_merge=None):
"""Implementation of case that emits the n-way indexed Case op. """Implementation of case that emits the n-way indexed Case op.
Args: Args:
@ -3293,6 +3298,7 @@ def _indexed_case_helper(branch_fns, default, branch_index, name):
branch_index: Optional int `Tensor`, which selects for the corresponding branch_index: Optional int `Tensor`, which selects for the corresponding
pred_fn_pair. pred_fn_pair.
name: A name for this operation (optional). name: A name for this operation (optional).
lower_using_switch_merge: Lower this op using switch merge ops (optional).
Returns: Returns:
The tensors returned by the pair whose key matched branch_index, or The tensors returned by the pair whose key matched branch_index, or
@ -3314,7 +3320,10 @@ def _indexed_case_helper(branch_fns, default, branch_index, name):
| math_ops.greater_equal(branch_index, len(branch_fns)), | math_ops.greater_equal(branch_index, len(branch_fns)),
len(branch_fns) - 1, branch_index) len(branch_fns) - 1, branch_index)
return branch_fns[int(branch_index)]() return branch_fns[int(branch_index)]()
return cond_v2.indexed_case(branch_index, branch_fns) return cond_v2.indexed_case(
branch_index,
branch_fns,
lower_using_switch_merge=lower_using_switch_merge)
@tf_export("case", v1=[]) @tf_export("case", v1=[])
@ -3607,6 +3616,50 @@ def switch_case(branch_index,
return _indexed_case_helper(branch_fns, default, branch_index, name) return _indexed_case_helper(branch_fns, default, branch_index, name)
def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"):
"""Executes one of the provided callables based on the device placement.
This API is used when the implementations for high level function depend on
the underlying device placement. It takes a dictionary of device type to
callables. The device type includes "CPU", "GPU", "TPU", etc. When the type of
the device where to run this op matches the key in 'device_branch_fns',
the corresponding callable is executed, falling back to 'default_fn' if none
matches.
**Example:**
```python
def f1(): return tf.constant(1)
def f2(): return tf.constant(2)
r = tf.execute_fn_for_device({"CPU": f1, "GPU": f2}, default_fn=f1)
```
'r' is evaluated as 1 when it runs on CPU, 2 running on GPU, 1 running on
any other device types.
Args:
device_branch_fns: a dictionary of device types to the callables. Each
callable must return a matching structure of tensors.
default_fn: fallback callable when the underlying device does not match any
key in the 'device_branch_fns'.
name: A name for this operation (optional).
Returns:
The tensors returned by the callable identified by device type during
execution, or those returned by 'default_fn' if no key matches.
"""
device_branch_fns_upper = {k.upper(): v for k, v in device_branch_fns.items()}
branch_fns = list(device_branch_fns_upper.values())
devices = list(device_branch_fns_upper.keys())
device_index = gen_functional_ops.device_index(device_names=devices)
return _indexed_case_helper(
branch_fns,
default_fn,
device_index,
name,
lower_using_switch_merge=False)
class XLAControlFlowContext(ControlFlowContext): class XLAControlFlowContext(ControlFlowContext):
"""Base class for XLA and TPU control flow contexts.""" """Base class for XLA and TPU control flow contexts."""

View File

@ -1211,6 +1211,92 @@ class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase):
control_flow_ops.switch_case(array_ops.constant(1), branches) control_flow_ops.switch_case(array_ops.constant(1), branches)
class ExecuteFnForDeviceTest(test_util.TensorFlowTestCase):
def testCommonCases(self):
def cpu_fn(x):
return x + x
def gpu_fn(x):
return x * x
def flexible_fn(a):
branches = {"CPU": lambda: cpu_fn(a), "GPU": lambda: gpu_fn(a)}
return control_flow_ops.execute_fn_for_device(branches, lambda: cpu_fn(a))
@def_function.function
def flexible_defun(a):
return flexible_fn(a)
def run_defun_and_tape(a):
with backprop.GradientTape() as tape:
tape.watch(a)
result = flexible_defun(a)
grad = tape.gradient(result, a)
r = flexible_fn(a)
return r, result, grad
a = array_ops.constant(3.)
with ops.device("cpu:0"):
r, result, grad = run_defun_and_tape(a)
self.assertEqual(6., self.evaluate(r))
self.assertEqual(6., self.evaluate(result))
self.assertEqual([2.], self.evaluate(grad))
if test_util.is_gpu_available():
with ops.device("gpu:0"):
r, result, grad = run_defun_and_tape(a)
self.assertEqual(9., self.evaluate(r))
self.assertEqual(9., self.evaluate(result))
self.assertEqual([6.], self.evaluate(grad))
# no device annotation
r, result, grad = run_defun_and_tape(a)
if test_util.is_gpu_available():
self.assertEqual(9., self.evaluate(r))
self.assertEqual(9., self.evaluate(result))
self.assertEqual([6.], self.evaluate(grad))
else:
self.assertEqual(6., self.evaluate(r))
self.assertEqual(6., self.evaluate(result))
self.assertEqual([2.], self.evaluate(grad))
def testFallBack(self):
def default_fn(x):
return x
def tpu_fn(x):
return x * x * x
def flexible_fn(a):
branches = {"TPU": lambda: tpu_fn(a)}
return control_flow_ops.execute_fn_for_device(
branches, default_fn=lambda: default_fn(a))
@def_function.function
def flexible_defun(a):
return flexible_fn(a)
a = array_ops.constant(3.)
with ops.device("cpu:0"):
result_defun = flexible_defun(a)
result_defun = flexible_fn(a)
self.assertEqual(3., self.evaluate(result_defun))
# execute_fn_for_device is not inside defun_function.
result = flexible_fn(a)
self.assertEqual(3., self.evaluate(result))
if test_util.is_gpu_available():
with ops.device("gpu:0"):
result_defun = flexible_defun(a)
self.assertEqual(3., self.evaluate(result_defun))
# execute_fn_for_device is not inside defun_function.
result = flexible_fn(a)
self.assertEqual(3., self.evaluate(result))
class CaseTest(test_util.TensorFlowTestCase): class CaseTest(test_util.TensorFlowTestCase):
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1

View File

@ -92,7 +92,7 @@ def unique_grad_fn_name(forward_name):
return "%s_grad_%s" % (forward_name, ops.uid()) return "%s_grad_%s" % (forward_name, ops.uid())
def maybe_set_lowering_attr(op): def maybe_set_lowering_attr(op, lower_using_switch_merge=None):
"""Sets the flag to enable lowering on `op` if necessary. """Sets the flag to enable lowering on `op` if necessary.
Lowering allows cond_v2 and while_v2 to avoid some of the limitations of Lowering allows cond_v2 and while_v2 to avoid some of the limitations of
@ -108,11 +108,18 @@ def maybe_set_lowering_attr(op):
- When the eager execution context specifies the executor of functions to - When the eager execution context specifies the executor of functions to
be the single threaded executor (see context.function_executor_type()). be the single threaded executor (see context.function_executor_type()).
Because the single threaded executor does not support v1 control flow ops. Because the single threaded executor does not support v1 control flow ops.
- When 'lower_using_switch_merge' is explicitly set to False.
Args: Args:
op: An `If` or `While` Operation. op: An `If` or `While` Operation.
lower_using_switch_merge: Explicit value to lower or not (optional).
""" """
if (not _DISABLE_LOWER_USING_SWITCH_MERGE and if lower_using_switch_merge is not None:
# pylint: disable=protected-access
op._set_attr("_lower_using_switch_merge",
attr_value_pb2.AttrValue(b=lower_using_switch_merge))
# pylint: enable=protected-access
elif (not _DISABLE_LOWER_USING_SWITCH_MERGE and
not control_flow_util.GraphOrParentsInXlaContext(op.graph) and not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
context.context().function_call_options.executor_type != context.context().function_call_options.executor_type !=
"SINGLE_THREADED_EXECUTOR"): "SINGLE_THREADED_EXECUTOR"):

View File

@ -1140,6 +1140,10 @@ tf_module {
name: "DestroyTemporaryVariable" name: "DestroyTemporaryVariable"
argspec: "args=[\'ref\', \'var_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'ref\', \'var_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "DeviceIndex"
argspec: "args=[\'device_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method { member_method {
name: "Diag" name: "Diag"
argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -1140,6 +1140,10 @@ tf_module {
name: "DestroyTemporaryVariable" name: "DestroyTemporaryVariable"
argspec: "args=[\'ref\', \'var_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'ref\', \'var_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "DeviceIndex"
argspec: "args=[\'device_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method { member_method {
name: "Diag" name: "Diag"
argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "