Add a function to dynamic choose and execute the proper implementation based on underlying device placement
PiperOrigin-RevId: 313605766 Change-Id: I877b684dcef782b375df0504c0250acd9e808ce9
This commit is contained in:
parent
4594b8cfb5
commit
2de551ba87
|
@ -0,0 +1,5 @@
|
|||
op {
|
||||
graph_op_name: "DeviceIndex"
|
||||
visibility: HIDDEN
|
||||
summary: "Return the index of device the op runs."
|
||||
}
|
|
@ -1062,6 +1062,7 @@ cc_library(
|
|||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
|
|
|
@ -17,9 +17,12 @@ limitations under the License.
|
|||
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
|
@ -36,6 +39,11 @@ limitations under the License.
|
|||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
constexpr char kConstOp[] = "Const";
|
||||
constexpr char kCaseOp[] = "Case";
|
||||
constexpr char kDeviceIndexOp[] = "DeviceIndex";
|
||||
|
||||
// TODO(b/157615690): clean up function implementation swap code.
|
||||
// The overall idea for the function swap is like below:
|
||||
// ----------- -----------
|
||||
// inp_1 ->| P_C | -> out_1 g_inp_1 ->| P_C | -> g_out_1
|
||||
|
@ -292,6 +300,74 @@ Status ImplementationSelector::MaybeOptimizeFunctionCall(
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Finds the index of the device from the device name list.
|
||||
Status FindDeviceIndex(const utils::MutableNodeView* device_index_node,
|
||||
const string& device, int* index) {
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
if (!DeviceNameUtils::ParseFullName(device, &parsed_name) ||
|
||||
!parsed_name.has_type) {
|
||||
return errors::Internal("Could not parse device name:", device);
|
||||
}
|
||||
const auto& device_list =
|
||||
device_index_node->GetAttr("device_names")->list().s();
|
||||
auto it = absl::c_find(device_list, parsed_name.type);
|
||||
if (it != device_list.end()) {
|
||||
*index = it - device_list.begin();
|
||||
} else {
|
||||
// Sets *index to device_list.size() because the default_fn is guaranteed to
|
||||
// be the final item in the case op branching list.
|
||||
*index = device_list.size();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Rewrites the device_index op to a const op with value of the index.
|
||||
void RewriteDeviceIndexOp(utils::MutableNodeView* device_index_node,
|
||||
int index) {
|
||||
// Modifies the DeviceIndex node to be an Const op with correct device index.
|
||||
auto node = device_index_node->node();
|
||||
node->set_op(kConstOp);
|
||||
node->clear_attr();
|
||||
(*node->mutable_attr())["dtype"].set_type(DT_INT32);
|
||||
auto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
|
||||
tensor->set_dtype(DT_INT32);
|
||||
tensor->add_int_val(index);
|
||||
VLOG(2) << "Node after rewriting:" << node->DebugString();
|
||||
}
|
||||
|
||||
Status ImplementationSelector::SelectDeviceIndex(GraphDef* graph) const {
|
||||
Status status;
|
||||
VLOG(2) << "graph before rewriting device index:" << graph->DebugString();
|
||||
utils::MutableGraphView graph_view(graph, &status);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
const int num_nodes = graph_view.NumNodes();
|
||||
for (int k = 0; k < num_nodes; ++k) {
|
||||
auto* node_view = graph_view.GetNode(k);
|
||||
if (node_view->GetOp() != kDeviceIndexOp) {
|
||||
continue;
|
||||
}
|
||||
VLOG(2) << "Found a node to rewrite the device index";
|
||||
|
||||
// Find the case node with device index node as input, rewrite the
|
||||
// DeviceIndex node to have the value of the index of device type of the
|
||||
// case node.
|
||||
for (const auto& fanouts : node_view->GetRegularFanouts()) {
|
||||
for (const auto& fanout : fanouts) {
|
||||
if (fanout.node_view()->GetOp() != kCaseOp) continue;
|
||||
int index;
|
||||
// If any error is thrown out during device parsing, we simply skip
|
||||
// and do not modify the DeviceIndexNode.
|
||||
Status status =
|
||||
FindDeviceIndex(node_view, fanout.node_view()->GetDevice(), &index);
|
||||
if (status.ok()) {
|
||||
RewriteDeviceIndexOp(node_view, index);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ImplementationSelector::SelectImplementation(GraphDef* graph) const {
|
||||
if (!graph->has_library()) {
|
||||
VLOG(2) << "Skipping graph since it does not have function def";
|
||||
|
@ -307,8 +383,9 @@ Status ImplementationSelector::SelectImplementation(GraphDef* graph) const {
|
|||
TF_RETURN_IF_ERROR(status);
|
||||
|
||||
const int num_nodes = graph_view.NumNodes();
|
||||
for (int k = 0; k < num_nodes; ++k)
|
||||
for (int k = 0; k < num_nodes; ++k) {
|
||||
TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph_view.GetNode(k)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -326,7 +403,13 @@ Status ImplementationSelector::Optimize(Cluster* cluster,
|
|||
<< "libraries: " << status;
|
||||
return errors::Aborted("Skipped Optimization");
|
||||
}
|
||||
|
||||
*optimized_graph = item.graph;
|
||||
status = SelectDeviceIndex(optimized_graph);
|
||||
if (!status.ok()) {
|
||||
*optimized_graph = item.graph;
|
||||
VLOG(2) << "Could not rewrite device index due to error:" << status;
|
||||
}
|
||||
return SelectImplementation(optimized_graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -34,6 +34,28 @@ limitations under the License.
|
|||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Motivation: To achieve the same high level functionality, the underlying
|
||||
// implementations sometimes are different for various devices where the
|
||||
// function runs. In order to achieve the correct result and best performance,
|
||||
// the proper implementation needs to be picked dynamically.
|
||||
//
|
||||
// Currently there are two approaches to do this.
|
||||
// (1) Utilize case op and dynamacically change the branch index.
|
||||
// (2) Swap function implementation, it will be deprecated.
|
||||
//
|
||||
// Idea for approach 1.
|
||||
// This transformation rewrites the DeviceIndex op with a Const op with value
|
||||
// of the index of the device the associcated Case op runs.
|
||||
// Example:
|
||||
// def plus_one_gpu(x): return x + 1.0
|
||||
// def plus_one_reference_implementation(x): return x + 1.0
|
||||
// input = tf.constant(2.0, dtype=tf.float32)
|
||||
// cpu_fn = lambda:plus_one_reference_implementation(input)
|
||||
// gpu_fn = lambda:plus_one_gpu(input)
|
||||
// control_flow_ops.execute_fn_for_device(
|
||||
// {"CPU": cpu_fn, "GPU":gpu_fn)}, default_fn=cpu_fn)
|
||||
//
|
||||
// Idea for approach 2.
|
||||
// This transformation replaces function calls by the appropriate function
|
||||
// definition based on properties of the runtime system. For instance,
|
||||
// we may choose one implementation over another if we have a GPU with
|
||||
|
@ -58,7 +80,8 @@ namespace grappler {
|
|||
// z = plus_one_gpu(input)
|
||||
// print(sess.run(z))
|
||||
//
|
||||
// At runtime, we will trim either `plus_one_gpu` or
|
||||
|
||||
// At runtime, we will select either `plus_one_gpu` or
|
||||
// `plus_one_reference_implementation` based on the availability of the GPU.
|
||||
//
|
||||
// Available annotations:
|
||||
|
@ -106,6 +129,68 @@ class ImplementationSelector : public CustomGraphOptimizer {
|
|||
// gradients.
|
||||
Status SelectImplementation(GraphDef* graph) const;
|
||||
|
||||
// Rewrites the DeviceIndex op with a Const op with value of the index of the
|
||||
// device the associcated Case op runs.
|
||||
|
||||
// This function first looks up all the DeviceIndex ops.
|
||||
// Then for each of these ops, it finds the device of the
|
||||
// associated Case op that takes the DeviceIndex op as the input, and
|
||||
// caculates the index of the device in the device list of DeviceIndex op.
|
||||
// Lastly, it rewrites the DeviceIndex op with a Const op and sets the value
|
||||
// to be the index.
|
||||
//
|
||||
// Example input nodes:
|
||||
// node {
|
||||
// name: "x"
|
||||
// op: "DeviceIndex"
|
||||
// device: "/device:CPU:0"
|
||||
// attr {
|
||||
// key: "device_names"
|
||||
// value {
|
||||
// list {
|
||||
// s: "CPU"
|
||||
// s: "TPU_REPLICATED_CORE"
|
||||
// s: "GPU"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// node {
|
||||
// name: "case"
|
||||
// op: "Case"
|
||||
// input: "x"
|
||||
// device: "/device:GPU:0"
|
||||
// ...
|
||||
// }
|
||||
// Example output nodes:
|
||||
//
|
||||
// name: "x"
|
||||
// op: "Const"
|
||||
// device: "/device:CPU:0"
|
||||
// attr {
|
||||
// key: "dtype"
|
||||
// value {
|
||||
// type: DT_INT32
|
||||
// }
|
||||
// }
|
||||
// attr {
|
||||
// key: "value"
|
||||
// value {
|
||||
// tensor {
|
||||
// dtype: DT_INT32
|
||||
// int_val: 2
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// node {
|
||||
// name: "case"
|
||||
// op: "Case"
|
||||
// input: "x"
|
||||
// device: "/device:GPU:0"
|
||||
// ...
|
||||
// }
|
||||
Status SelectDeviceIndex(GraphDef* graph) const;
|
||||
|
||||
std::unique_ptr<FunctionLibraryApiInfo> lib_info_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ImplementationSelector);
|
||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
|
@ -58,6 +59,167 @@ TEST_F(ImplementationSelectorTest, NoUpdate) {
|
|||
EXPECT_EQ(item.graph.node_size(), output.node_size());
|
||||
}
|
||||
|
||||
TEST_F(ImplementationSelectorTest, SelectDeviceIndex) {
|
||||
using test::function::NDef;
|
||||
ImplementationSelector optimizer;
|
||||
GraphDef output;
|
||||
GrapplerItem item;
|
||||
AttrValue device_names;
|
||||
device_names.mutable_list()->add_s("CPU");
|
||||
device_names.mutable_list()->add_s("GPU");
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
CpuDevice),
|
||||
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice)});
|
||||
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "x") {
|
||||
// Rewrite DeviceIndex op to a Const op with value of GPU index 1.
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ImplementationSelectorTest, SelectDeviceIndexMultiOps) {
|
||||
using test::function::NDef;
|
||||
ImplementationSelector optimizer;
|
||||
GraphDef output;
|
||||
GrapplerItem item;
|
||||
AttrValue device_names;
|
||||
device_names.mutable_list()->add_s("CPU");
|
||||
device_names.mutable_list()->add_s("TPU_REPLICATED_CORE");
|
||||
device_names.mutable_list()->add_s("GPU");
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
CpuDevice),
|
||||
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
|
||||
NDef("y", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
GpuDevice),
|
||||
NDef("case_y", "Case", {"y"}, {{"T", DT_FLOAT}}, TpuDevice)});
|
||||
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "x") {
|
||||
// Rewrite DeviceIndex op to a Const op with value of GPU index 1.
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
|
||||
}
|
||||
if (node.name() == "y") {
|
||||
// Rewrite DeviceIndex op to a Const op with value of CPU index 0.
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ImplementationSelectorTest, SelectDeviceIndexNotFound) {
|
||||
using test::function::NDef;
|
||||
ImplementationSelector optimizer;
|
||||
GraphDef output;
|
||||
GrapplerItem item;
|
||||
AttrValue device_names;
|
||||
device_names.mutable_list()->add_s("CPU");
|
||||
device_names.mutable_list()->add_s("GPU");
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
CpuDevice),
|
||||
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, TpuDevice)});
|
||||
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "x") {
|
||||
// Rewrite DeviceIndex op to a Const op with value of device names length.
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ImplementationSelectorTest, SelectDeviceIndexError) {
|
||||
using test::function::NDef;
|
||||
ImplementationSelector optimizer;
|
||||
GraphDef output;
|
||||
GrapplerItem item;
|
||||
AttrValue device_names;
|
||||
device_names.mutable_list()->add_s("CPU");
|
||||
device_names.mutable_list()->add_s("GPU");
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
CpuDevice),
|
||||
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, "")});
|
||||
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "x") {
|
||||
// Device parse has error, do not rewrite the DeviceIndexNode.
|
||||
EXPECT_EQ("DeviceIndex", node.op());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ImplementationSelectorTest, TwoTypesOfSwapImplementation) {
|
||||
using test::function::NDef;
|
||||
ImplementationSelector optimizer;
|
||||
GraphDef output;
|
||||
GrapplerItem item;
|
||||
// DeviceIndex op based implementation selector.
|
||||
AttrValue device_names;
|
||||
device_names.mutable_list()->add_s("CPU");
|
||||
device_names.mutable_list()->add_s("TPU_REPLICATED_CORE");
|
||||
device_names.mutable_list()->add_s("GPU");
|
||||
|
||||
// Function swap based implementation selector.
|
||||
auto cpu_def = test::function::XTimesTwo();
|
||||
auto* func_attr = cpu_def.mutable_attr();
|
||||
(*func_attr)["api_implements"].set_s("times_two");
|
||||
(*func_attr)["api_preferred_device"].set_s("CPU");
|
||||
|
||||
auto gpu_def = test::function::XAddX();
|
||||
auto* func2_attr = gpu_def.mutable_attr();
|
||||
(*func2_attr)["api_implements"].set_s("times_two");
|
||||
(*func2_attr)["api_preferred_device"].set_s("GPU");
|
||||
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
CpuDevice),
|
||||
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
|
||||
NDef("y", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
GpuDevice),
|
||||
NDef("case_y", "Case", {"y"}, {{"T", DT_FLOAT}}, TpuDevice),
|
||||
NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
|
||||
NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice),
|
||||
NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
|
||||
NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)},
|
||||
// FunctionLib
|
||||
{cpu_def, gpu_def});
|
||||
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "x") {
|
||||
// Rewrite DeviceIndex op to a Const op with value of GPU index 1.
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
|
||||
}
|
||||
if (node.name() == "y") {
|
||||
// Rewrite DeviceIndex op to a Const op with value of CPU index 0.
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
|
||||
}
|
||||
if (node.name() == "y1") {
|
||||
// Make sure the implementation has been swapped to use the GPU version.
|
||||
EXPECT_EQ("XAddX", node.op());
|
||||
} else if (node.name() == "y2") {
|
||||
// Make sure the implementation is not changed.
|
||||
EXPECT_EQ("XTimesTwo", node.op());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ImplementationSelectorTest, SwapImplementation) {
|
||||
using test::function::NDef;
|
||||
auto cpu_def = test::function::XTimesTwo();
|
||||
|
|
|
@ -924,5 +924,37 @@ class FakeParamOp : public OpKernel {
|
|||
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_GPU), FakeParamOp);
|
||||
|
||||
// DeviceIndexOP returns the current device index.
|
||||
class DeviceIndexOp : public OpKernel {
|
||||
public:
|
||||
explicit DeviceIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("device_names", &device_names_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
Tensor* device_name_t;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->allocate_output(0, TensorShape({}), &device_name_t));
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
int index = device_names_.size();
|
||||
if (DeviceNameUtils::ParseFullName(ctx->device()->name(), &parsed_name) &&
|
||||
parsed_name.has_type) {
|
||||
auto it = absl::c_find(device_names_, parsed_name.type);
|
||||
if (it != device_names_.end()) {
|
||||
index = it - device_names_.begin();
|
||||
}
|
||||
}
|
||||
device_name_t->scalar<int32>()() = index;
|
||||
}
|
||||
|
||||
private:
|
||||
PersistentTensor value_handle_;
|
||||
std::vector<string> device_names_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("DeviceIndex").Device(DEVICE_CPU), DeviceIndexOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("DeviceIndex").Device(DEVICE_GPU).HostMemory("index"), DeviceIndexOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -299,4 +299,10 @@ REGISTER_OP("FakeParam")
|
|||
return Status::OK();
|
||||
});
|
||||
|
||||
// Returns the device index.
|
||||
REGISTER_OP("DeviceIndex")
|
||||
.Output("index: int32")
|
||||
.Attr("device_names: list(string)")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
|
|
@ -942,7 +942,10 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
|
|||
return captured_tensor
|
||||
|
||||
|
||||
def indexed_case(branch_index, branch_fns, name="indexed_case"):
|
||||
def indexed_case(branch_index,
|
||||
branch_fns,
|
||||
name="indexed_case",
|
||||
lower_using_switch_merge=None):
|
||||
"""Like conv_v2, except emits a Case op instead of an If."""
|
||||
if isinstance(branch_index, int):
|
||||
raise TypeError("branch_index must not be a Python int", branch_index)
|
||||
|
@ -976,7 +979,8 @@ def indexed_case(branch_index, branch_fns, name="indexed_case"):
|
|||
return _build_case(
|
||||
branch_index,
|
||||
branch_graphs, [g.external_captures for g in branch_graphs],
|
||||
name=scope)
|
||||
name=scope,
|
||||
lower_using_switch_merge=lower_using_switch_merge)
|
||||
|
||||
|
||||
@ops.RegisterGradient("Case")
|
||||
|
@ -1064,7 +1068,11 @@ def _CaseGrad(op, *grads): # pylint: disable=invalid-name
|
|||
return [None] + outputs
|
||||
|
||||
|
||||
def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
|
||||
def _build_case(branch_index,
|
||||
branch_graphs,
|
||||
branch_inputs,
|
||||
name=None,
|
||||
lower_using_switch_merge=None):
|
||||
"""Creates an `Case` op from `branch_index`, branch graphs and inputs.
|
||||
|
||||
Note that this modifies `branch_graphs` to make the inputs match, and to
|
||||
|
@ -1080,6 +1088,7 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
|
|||
branch_inputs: List of lists of Tensors to be passed to corresponding
|
||||
branch_graph as input.
|
||||
name: the name for the Case op.
|
||||
lower_using_switch_merge: Lower this op using switch merge ops (optional).
|
||||
|
||||
Returns:
|
||||
A list of Tensors which are the outputs of the Case op. Does not include
|
||||
|
@ -1105,7 +1114,7 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
|
|||
case_op, tensors = _get_op_and_outputs(tensors)
|
||||
|
||||
if case_op is not None:
|
||||
util.maybe_set_lowering_attr(case_op)
|
||||
util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
|
||||
util.maybe_propagate_compile_time_consts_in_xla(case_op)
|
||||
_set_read_only_resource_inputs_attr(case_op, branch_graphs)
|
||||
# Prevent fetching since the variant outputs can't be fetched directly.
|
||||
|
|
|
@ -43,6 +43,7 @@ from tensorflow.python.ops import array_ops
|
|||
from tensorflow.python.ops import control_flow_util as util
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gen_control_flow_ops
|
||||
from tensorflow.python.ops import gen_functional_ops
|
||||
from tensorflow.python.ops import gen_logging_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
@ -3283,7 +3284,11 @@ def _indexed_case_verify_and_canonicalize_args(branch_fns, default,
|
|||
return actions
|
||||
|
||||
|
||||
def _indexed_case_helper(branch_fns, default, branch_index, name):
|
||||
def _indexed_case_helper(branch_fns,
|
||||
default,
|
||||
branch_index,
|
||||
name,
|
||||
lower_using_switch_merge=None):
|
||||
"""Implementation of case that emits the n-way indexed Case op.
|
||||
|
||||
Args:
|
||||
|
@ -3293,6 +3298,7 @@ def _indexed_case_helper(branch_fns, default, branch_index, name):
|
|||
branch_index: Optional int `Tensor`, which selects for the corresponding
|
||||
pred_fn_pair.
|
||||
name: A name for this operation (optional).
|
||||
lower_using_switch_merge: Lower this op using switch merge ops (optional).
|
||||
|
||||
Returns:
|
||||
The tensors returned by the pair whose key matched branch_index, or
|
||||
|
@ -3314,7 +3320,10 @@ def _indexed_case_helper(branch_fns, default, branch_index, name):
|
|||
| math_ops.greater_equal(branch_index, len(branch_fns)),
|
||||
len(branch_fns) - 1, branch_index)
|
||||
return branch_fns[int(branch_index)]()
|
||||
return cond_v2.indexed_case(branch_index, branch_fns)
|
||||
return cond_v2.indexed_case(
|
||||
branch_index,
|
||||
branch_fns,
|
||||
lower_using_switch_merge=lower_using_switch_merge)
|
||||
|
||||
|
||||
@tf_export("case", v1=[])
|
||||
|
@ -3607,6 +3616,50 @@ def switch_case(branch_index,
|
|||
return _indexed_case_helper(branch_fns, default, branch_index, name)
|
||||
|
||||
|
||||
def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"):
|
||||
"""Executes one of the provided callables based on the device placement.
|
||||
|
||||
This API is used when the implementations for high level function depend on
|
||||
the underlying device placement. It takes a dictionary of device type to
|
||||
callables. The device type includes "CPU", "GPU", "TPU", etc. When the type of
|
||||
the device where to run this op matches the key in 'device_branch_fns',
|
||||
the corresponding callable is executed, falling back to 'default_fn' if none
|
||||
matches.
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
def f1(): return tf.constant(1)
|
||||
def f2(): return tf.constant(2)
|
||||
r = tf.execute_fn_for_device({"CPU": f1, "GPU": f2}, default_fn=f1)
|
||||
```
|
||||
'r' is evaluated as 1 when it runs on CPU, 2 running on GPU, 1 running on
|
||||
any other device types.
|
||||
|
||||
|
||||
Args:
|
||||
device_branch_fns: a dictionary of device types to the callables. Each
|
||||
callable must return a matching structure of tensors.
|
||||
default_fn: fallback callable when the underlying device does not match any
|
||||
key in the 'device_branch_fns'.
|
||||
name: A name for this operation (optional).
|
||||
|
||||
Returns:
|
||||
The tensors returned by the callable identified by device type during
|
||||
execution, or those returned by 'default_fn' if no key matches.
|
||||
"""
|
||||
|
||||
device_branch_fns_upper = {k.upper(): v for k, v in device_branch_fns.items()}
|
||||
branch_fns = list(device_branch_fns_upper.values())
|
||||
devices = list(device_branch_fns_upper.keys())
|
||||
device_index = gen_functional_ops.device_index(device_names=devices)
|
||||
return _indexed_case_helper(
|
||||
branch_fns,
|
||||
default_fn,
|
||||
device_index,
|
||||
name,
|
||||
lower_using_switch_merge=False)
|
||||
|
||||
|
||||
class XLAControlFlowContext(ControlFlowContext):
|
||||
"""Base class for XLA and TPU control flow contexts."""
|
||||
|
||||
|
|
|
@ -1211,6 +1211,92 @@ class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
control_flow_ops.switch_case(array_ops.constant(1), branches)
|
||||
|
||||
|
||||
class ExecuteFnForDeviceTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testCommonCases(self):
|
||||
|
||||
def cpu_fn(x):
|
||||
return x + x
|
||||
|
||||
def gpu_fn(x):
|
||||
return x * x
|
||||
|
||||
def flexible_fn(a):
|
||||
branches = {"CPU": lambda: cpu_fn(a), "GPU": lambda: gpu_fn(a)}
|
||||
return control_flow_ops.execute_fn_for_device(branches, lambda: cpu_fn(a))
|
||||
|
||||
@def_function.function
|
||||
def flexible_defun(a):
|
||||
return flexible_fn(a)
|
||||
|
||||
def run_defun_and_tape(a):
|
||||
with backprop.GradientTape() as tape:
|
||||
tape.watch(a)
|
||||
result = flexible_defun(a)
|
||||
grad = tape.gradient(result, a)
|
||||
r = flexible_fn(a)
|
||||
return r, result, grad
|
||||
|
||||
a = array_ops.constant(3.)
|
||||
with ops.device("cpu:0"):
|
||||
r, result, grad = run_defun_and_tape(a)
|
||||
self.assertEqual(6., self.evaluate(r))
|
||||
self.assertEqual(6., self.evaluate(result))
|
||||
self.assertEqual([2.], self.evaluate(grad))
|
||||
|
||||
if test_util.is_gpu_available():
|
||||
with ops.device("gpu:0"):
|
||||
r, result, grad = run_defun_and_tape(a)
|
||||
self.assertEqual(9., self.evaluate(r))
|
||||
self.assertEqual(9., self.evaluate(result))
|
||||
self.assertEqual([6.], self.evaluate(grad))
|
||||
|
||||
# no device annotation
|
||||
r, result, grad = run_defun_and_tape(a)
|
||||
if test_util.is_gpu_available():
|
||||
self.assertEqual(9., self.evaluate(r))
|
||||
self.assertEqual(9., self.evaluate(result))
|
||||
self.assertEqual([6.], self.evaluate(grad))
|
||||
else:
|
||||
self.assertEqual(6., self.evaluate(r))
|
||||
self.assertEqual(6., self.evaluate(result))
|
||||
self.assertEqual([2.], self.evaluate(grad))
|
||||
|
||||
def testFallBack(self):
|
||||
|
||||
def default_fn(x):
|
||||
return x
|
||||
|
||||
def tpu_fn(x):
|
||||
return x * x * x
|
||||
|
||||
def flexible_fn(a):
|
||||
branches = {"TPU": lambda: tpu_fn(a)}
|
||||
return control_flow_ops.execute_fn_for_device(
|
||||
branches, default_fn=lambda: default_fn(a))
|
||||
|
||||
@def_function.function
|
||||
def flexible_defun(a):
|
||||
return flexible_fn(a)
|
||||
|
||||
a = array_ops.constant(3.)
|
||||
with ops.device("cpu:0"):
|
||||
result_defun = flexible_defun(a)
|
||||
result_defun = flexible_fn(a)
|
||||
self.assertEqual(3., self.evaluate(result_defun))
|
||||
# execute_fn_for_device is not inside defun_function.
|
||||
result = flexible_fn(a)
|
||||
self.assertEqual(3., self.evaluate(result))
|
||||
|
||||
if test_util.is_gpu_available():
|
||||
with ops.device("gpu:0"):
|
||||
result_defun = flexible_defun(a)
|
||||
self.assertEqual(3., self.evaluate(result_defun))
|
||||
# execute_fn_for_device is not inside defun_function.
|
||||
result = flexible_fn(a)
|
||||
self.assertEqual(3., self.evaluate(result))
|
||||
|
||||
|
||||
class CaseTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
|
|
|
@ -92,7 +92,7 @@ def unique_grad_fn_name(forward_name):
|
|||
return "%s_grad_%s" % (forward_name, ops.uid())
|
||||
|
||||
|
||||
def maybe_set_lowering_attr(op):
|
||||
def maybe_set_lowering_attr(op, lower_using_switch_merge=None):
|
||||
"""Sets the flag to enable lowering on `op` if necessary.
|
||||
|
||||
Lowering allows cond_v2 and while_v2 to avoid some of the limitations of
|
||||
|
@ -108,14 +108,21 @@ def maybe_set_lowering_attr(op):
|
|||
- When the eager execution context specifies the executor of functions to
|
||||
be the single threaded executor (see context.function_executor_type()).
|
||||
Because the single threaded executor does not support v1 control flow ops.
|
||||
- When 'lower_using_switch_merge' is explicitly set to False.
|
||||
|
||||
Args:
|
||||
op: An `If` or `While` Operation.
|
||||
lower_using_switch_merge: Explicit value to lower or not (optional).
|
||||
"""
|
||||
if (not _DISABLE_LOWER_USING_SWITCH_MERGE and
|
||||
not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
|
||||
context.context().function_call_options.executor_type !=
|
||||
"SINGLE_THREADED_EXECUTOR"):
|
||||
if lower_using_switch_merge is not None:
|
||||
# pylint: disable=protected-access
|
||||
op._set_attr("_lower_using_switch_merge",
|
||||
attr_value_pb2.AttrValue(b=lower_using_switch_merge))
|
||||
# pylint: enable=protected-access
|
||||
elif (not _DISABLE_LOWER_USING_SWITCH_MERGE and
|
||||
not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
|
||||
context.context().function_call_options.executor_type !=
|
||||
"SINGLE_THREADED_EXECUTOR"):
|
||||
# pylint: disable=protected-access
|
||||
op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
|
||||
# pylint: enable=protected-access
|
||||
|
|
|
@ -1140,6 +1140,10 @@ tf_module {
|
|||
name: "DestroyTemporaryVariable"
|
||||
argspec: "args=[\'ref\', \'var_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DeviceIndex"
|
||||
argspec: "args=[\'device_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Diag"
|
||||
argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
|
|
@ -1140,6 +1140,10 @@ tf_module {
|
|||
name: "DestroyTemporaryVariable"
|
||||
argspec: "args=[\'ref\', \'var_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DeviceIndex"
|
||||
argspec: "args=[\'device_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Diag"
|
||||
argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
|
Loading…
Reference in New Issue