Add a function to dynamically choose and execute the proper implementation based on the underlying device placement
PiperOrigin-RevId: 313605766 Change-Id: I877b684dcef782b375df0504c0250acd9e808ce9
parent 4594b8cfb5
commit 2de551ba87
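For orientation, a minimal sketch of how the new device-based dispatch is meant to be used from Python, based on the `execute_fn_for_device` docstring added below; the `plus_one_*` functions are hypothetical user code, not part of this commit:

```python
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops

# Hypothetical per-device implementations of the same high-level function.
def plus_one_cpu(x): return x + 1.0
def plus_one_gpu(x): return x + 1.0

x = constant_op.constant(2.0)
# Runs the branch whose key matches the device type this op is placed on,
# falling back to default_fn for any other device type.
y = control_flow_ops.execute_fn_for_device(
    {"CPU": lambda: plus_one_cpu(x), "GPU": lambda: plus_one_gpu(x)},
    default_fn=lambda: plus_one_cpu(x))
```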
@@ -0,0 +1,5 @@
op {
  graph_op_name: "DeviceIndex"
  visibility: HIDDEN
  summary: "Returns the index of the device the op runs on."
}
@@ -1062,6 +1062,7 @@ cc_library(
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:op_types",
        "//tensorflow/core/grappler:utils",
@@ -17,9 +17,12 @@ limitations under the License.

#include <string>

#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
@@ -36,6 +39,11 @@ limitations under the License.
namespace tensorflow {
namespace grappler {

constexpr char kConstOp[] = "Const";
constexpr char kCaseOp[] = "Case";
constexpr char kDeviceIndexOp[] = "DeviceIndex";

// TODO(b/157615690): clean up function implementation swap code.
// The overall idea of the function swap is as follows:
//          -----------                             -----------
// inp_1 ->|   P_C     | -> out_1        g_inp_1 ->|   P_C     | -> g_out_1
@@ -292,6 +300,74 @@ Status ImplementationSelector::MaybeOptimizeFunctionCall(
  return Status::OK();
}

// Finds the index of the device from the device name list.
Status FindDeviceIndex(const utils::MutableNodeView* device_index_node,
                       const string& device, int* index) {
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(device, &parsed_name) ||
      !parsed_name.has_type) {
    return errors::Internal("Could not parse device name:", device);
  }
  const auto& device_list =
      device_index_node->GetAttr("device_names")->list().s();
  auto it = absl::c_find(device_list, parsed_name.type);
  if (it != device_list.end()) {
    *index = it - device_list.begin();
  } else {
    // Sets *index to device_list.size() because the default_fn is guaranteed
    // to be the final item in the Case op's branch list.
    *index = device_list.size();
  }
  return Status::OK();
}

// Rewrites the DeviceIndex op to a Const op with the value of the index.
void RewriteDeviceIndexOp(utils::MutableNodeView* device_index_node,
                          int index) {
  // Modifies the DeviceIndex node to be a Const op with the correct device
  // index.
  auto node = device_index_node->node();
  node->set_op(kConstOp);
  node->clear_attr();
  (*node->mutable_attr())["dtype"].set_type(DT_INT32);
  auto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
  tensor->set_dtype(DT_INT32);
  tensor->add_int_val(index);
  VLOG(2) << "Node after rewriting:" << node->DebugString();
}

Status ImplementationSelector::SelectDeviceIndex(GraphDef* graph) const {
  Status status;
  VLOG(2) << "graph before rewriting device index:" << graph->DebugString();
  utils::MutableGraphView graph_view(graph, &status);
  TF_RETURN_IF_ERROR(status);
  const int num_nodes = graph_view.NumNodes();
  for (int k = 0; k < num_nodes; ++k) {
    auto* node_view = graph_view.GetNode(k);
    if (node_view->GetOp() != kDeviceIndexOp) {
      continue;
    }
    VLOG(2) << "Found a node to rewrite the device index";

    // Find the Case node that takes the DeviceIndex node as input, and
    // rewrite the DeviceIndex node to hold the index of the Case node's
    // device type in the device list.
    for (const auto& fanouts : node_view->GetRegularFanouts()) {
      for (const auto& fanout : fanouts) {
        if (fanout.node_view()->GetOp() != kCaseOp) continue;
        int index;
        // If any error occurs during device parsing, we simply skip this
        // node and do not modify the DeviceIndex node.
        Status status =
            FindDeviceIndex(node_view, fanout.node_view()->GetDevice(), &index);
        if (status.ok()) {
          RewriteDeviceIndexOp(node_view, index);
        }
      }
    }
  }
  return Status::OK();
}

Status ImplementationSelector::SelectImplementation(GraphDef* graph) const {
  if (!graph->has_library()) {
    VLOG(2) << "Skipping graph since it does not have function def";
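A small Python restatement of the lookup and fallback behavior of `FindDeviceIndex` above may help; this is illustrative only and not part of the diff:

```python
def find_device_index(device_names, case_device_type):
    """Mirrors FindDeviceIndex: returns the position of the Case op's device
    type in the DeviceIndex op's device_names list, or len(device_names)
    when there is no match (the default_fn branch is always appended last)."""
    try:
        return device_names.index(case_device_type)
    except ValueError:
        return len(device_names)

assert find_device_index(["CPU", "GPU"], "GPU") == 1
assert find_device_index(["CPU", "GPU"], "TPU") == 2  # falls through to default_fn
```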
@@ -307,8 +383,9 @@ Status ImplementationSelector::SelectImplementation(GraphDef* graph) const {
  TF_RETURN_IF_ERROR(status);

  const int num_nodes = graph_view.NumNodes();
-  for (int k = 0; k < num_nodes; ++k)
+  for (int k = 0; k < num_nodes; ++k) {
    TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph_view.GetNode(k)));
+  }

  return Status::OK();
}
@@ -326,7 +403,13 @@ Status ImplementationSelector::Optimize(Cluster* cluster,
             << "libraries: " << status;
    return errors::Aborted("Skipped Optimization");
  }

  *optimized_graph = item.graph;
  status = SelectDeviceIndex(optimized_graph);
  if (!status.ok()) {
    *optimized_graph = item.graph;
    VLOG(2) << "Could not rewrite device index due to error:" << status;
  }
  return SelectImplementation(optimized_graph);
}
@@ -34,6 +34,28 @@ limitations under the License.
namespace tensorflow {
namespace grappler {

// Motivation: To achieve the same high-level functionality, the underlying
// implementations sometimes differ for the various devices on which the
// function runs. In order to achieve the correct result and best performance,
// the proper implementation needs to be picked dynamically.
//
// Currently there are two approaches to do this.
// (1) Utilize the Case op and dynamically change the branch index.
// (2) Swap the function implementation; this approach will be deprecated.
//
// Idea for approach 1.
// This transformation rewrites the DeviceIndex op with a Const op whose value
// is the index of the device on which the associated Case op runs.
// Example:
// def plus_one_gpu(x): return x + 1.0
// def plus_one_reference_implementation(x): return x + 1.0
// input = tf.constant(2.0, dtype=tf.float32)
// cpu_fn = lambda: plus_one_reference_implementation(input)
// gpu_fn = lambda: plus_one_gpu(input)
// control_flow_ops.execute_fn_for_device(
//     {"CPU": cpu_fn, "GPU": gpu_fn}, default_fn=cpu_fn)
//
// Idea for approach 2.
// This transformation replaces function calls by the appropriate function
// definition based on properties of the runtime system. For instance,
// we may choose one implementation over another if we have a GPU with
@@ -58,7 +80,8 @@ namespace grappler {
// z = plus_one_gpu(input)
// print(sess.run(z))
//
-// At runtime, we will trim either `plus_one_gpu` or
+// At runtime, we will select either `plus_one_gpu` or
// `plus_one_reference_implementation` based on the availability of the GPU.
//
// Available annotations:
@@ -106,6 +129,68 @@ class ImplementationSelector : public CustomGraphOptimizer {
  // gradients.
  Status SelectImplementation(GraphDef* graph) const;

  // Rewrites the DeviceIndex op with a Const op whose value is the index of
  // the device on which the associated Case op runs.
  //
  // This function first looks up all the DeviceIndex ops.
  // Then, for each of these ops, it finds the device of the
  // associated Case op that takes the DeviceIndex op as input, and
  // calculates the index of that device in the device list of the DeviceIndex
  // op. Lastly, it rewrites the DeviceIndex op with a Const op and sets the
  // value to be the index.
  //
  // Example input nodes:
  // node {
  //   name: "x"
  //   op: "DeviceIndex"
  //   device: "/device:CPU:0"
  //   attr {
  //     key: "device_names"
  //     value {
  //       list {
  //         s: "CPU"
  //         s: "TPU_REPLICATED_CORE"
  //         s: "GPU"
  //       }
  //     }
  //   }
  // }
  // node {
  //   name: "case"
  //   op: "Case"
  //   input: "x"
  //   device: "/device:GPU:0"
  //   ...
  // }
  // Example output nodes:
  //
  // name: "x"
  // op: "Const"
  // device: "/device:CPU:0"
  // attr {
  //   key: "dtype"
  //   value {
  //     type: DT_INT32
  //   }
  // }
  // attr {
  //   key: "value"
  //   value {
  //     tensor {
  //       dtype: DT_INT32
  //       int_val: 2
  //     }
  //   }
  // }
  // node {
  //   name: "case"
  //   op: "Case"
  //   input: "x"
  //   device: "/device:GPU:0"
  //   ...
  // }
  Status SelectDeviceIndex(GraphDef* graph) const;

  std::unique_ptr<FunctionLibraryApiInfo> lib_info_;

  TF_DISALLOW_COPY_AND_ASSIGN(ImplementationSelector);
@@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -58,6 +59,167 @@ TEST_F(ImplementationSelectorTest, NoUpdate) {
  EXPECT_EQ(item.graph.node_size(), output.node_size());
}

TEST_F(ImplementationSelectorTest, SelectDeviceIndex) {
  using test::function::NDef;
  ImplementationSelector optimizer;
  GraphDef output;
  GrapplerItem item;
  AttrValue device_names;
  device_names.mutable_list()->add_s("CPU");
  device_names.mutable_list()->add_s("GPU");
  item.graph = test::function::GDef(
      {NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
            CpuDevice),
       NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice)});

  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  for (const NodeDef& node : output.node()) {
    if (node.name() == "x") {
      // Rewrite DeviceIndex op to a Const op with the value of GPU index 1.
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
    }
  }
}

TEST_F(ImplementationSelectorTest, SelectDeviceIndexMultiOps) {
  using test::function::NDef;
  ImplementationSelector optimizer;
  GraphDef output;
  GrapplerItem item;
  AttrValue device_names;
  device_names.mutable_list()->add_s("CPU");
  device_names.mutable_list()->add_s("TPU_REPLICATED_CORE");
  device_names.mutable_list()->add_s("GPU");
  item.graph = test::function::GDef(
      {NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
            CpuDevice),
       NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
       NDef("y", "DeviceIndex", {}, {{"device_names", device_names}},
            GpuDevice),
       NDef("case_y", "Case", {"y"}, {{"T", DT_FLOAT}}, TpuDevice)});

  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
  for (const NodeDef& node : output.node()) {
    if (node.name() == "x") {
      // Rewrite DeviceIndex op to a Const op with the value of GPU index 2.
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
    }
    if (node.name() == "y") {
      // Rewrite DeviceIndex op to a Const op with the value of
      // TPU_REPLICATED_CORE index 1.
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
    }
  }
}

TEST_F(ImplementationSelectorTest, SelectDeviceIndexNotFound) {
  using test::function::NDef;
  ImplementationSelector optimizer;
  GraphDef output;
  GrapplerItem item;
  AttrValue device_names;
  device_names.mutable_list()->add_s("CPU");
  device_names.mutable_list()->add_s("GPU");
  item.graph = test::function::GDef(
      {NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
            CpuDevice),
       NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, TpuDevice)});

  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  for (const NodeDef& node : output.node()) {
    if (node.name() == "x") {
      // Rewrite DeviceIndex op to a Const op with the value of the device
      // names list length (the default_fn branch).
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
    }
  }
}

TEST_F(ImplementationSelectorTest, SelectDeviceIndexError) {
  using test::function::NDef;
  ImplementationSelector optimizer;
  GraphDef output;
  GrapplerItem item;
  AttrValue device_names;
  device_names.mutable_list()->add_s("CPU");
  device_names.mutable_list()->add_s("GPU");
  item.graph = test::function::GDef(
      {NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
            CpuDevice),
       NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, "")});

  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  for (const NodeDef& node : output.node()) {
    if (node.name() == "x") {
      // Device parsing failed, so the DeviceIndex node is not rewritten.
      EXPECT_EQ("DeviceIndex", node.op());
    }
  }
}

TEST_F(ImplementationSelectorTest, TwoTypesOfSwapImplementation) {
  using test::function::NDef;
  ImplementationSelector optimizer;
  GraphDef output;
  GrapplerItem item;
  // DeviceIndex op based implementation selector.
  AttrValue device_names;
  device_names.mutable_list()->add_s("CPU");
  device_names.mutable_list()->add_s("TPU_REPLICATED_CORE");
  device_names.mutable_list()->add_s("GPU");

  // Function swap based implementation selector.
  auto cpu_def = test::function::XTimesTwo();
  auto* func_attr = cpu_def.mutable_attr();
  (*func_attr)["api_implements"].set_s("times_two");
  (*func_attr)["api_preferred_device"].set_s("CPU");

  auto gpu_def = test::function::XAddX();
  auto* func2_attr = gpu_def.mutable_attr();
  (*func2_attr)["api_implements"].set_s("times_two");
  (*func2_attr)["api_preferred_device"].set_s("GPU");

  item.graph = test::function::GDef(
      {NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
            CpuDevice),
       NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
       NDef("y", "DeviceIndex", {}, {{"device_names", device_names}},
            GpuDevice),
       NDef("case_y", "Case", {"y"}, {{"T", DT_FLOAT}}, TpuDevice),
       NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
       NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice),
       NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
       NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)},
      // FunctionLib
      {cpu_def, gpu_def});

  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
  for (const NodeDef& node : output.node()) {
    if (node.name() == "x") {
      // Rewrite DeviceIndex op to a Const op with the value of GPU index 2.
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
    }
    if (node.name() == "y") {
      // Rewrite DeviceIndex op to a Const op with the value of
      // TPU_REPLICATED_CORE index 1.
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
    }
    if (node.name() == "y1") {
      // Make sure the implementation has been swapped to use the GPU version.
      EXPECT_EQ("XAddX", node.op());
    } else if (node.name() == "y2") {
      // Make sure the implementation is not changed.
      EXPECT_EQ("XTimesTwo", node.op());
    }
  }
}

TEST_F(ImplementationSelectorTest, SwapImplementation) {
  using test::function::NDef;
  auto cpu_def = test::function::XTimesTwo();
@@ -924,5 +924,37 @@ class FakeParamOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp);
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_GPU), FakeParamOp);

// DeviceIndexOp returns the current device index.
class DeviceIndexOp : public OpKernel {
 public:
  explicit DeviceIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("device_names", &device_names_));
  }

  void Compute(OpKernelContext* ctx) override {
    Tensor* device_name_t;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({}), &device_name_t));
    DeviceNameUtils::ParsedName parsed_name;
    int index = device_names_.size();
    if (DeviceNameUtils::ParseFullName(ctx->device()->name(), &parsed_name) &&
        parsed_name.has_type) {
      auto it = absl::c_find(device_names_, parsed_name.type);
      if (it != device_names_.end()) {
        index = it - device_names_.begin();
      }
    }
    device_name_t->scalar<int32>()() = index;
  }

 private:
  PersistentTensor value_handle_;
  std::vector<string> device_names_;
};

REGISTER_KERNEL_BUILDER(Name("DeviceIndex").Device(DEVICE_CPU), DeviceIndexOp);
REGISTER_KERNEL_BUILDER(
    Name("DeviceIndex").Device(DEVICE_GPU).HostMemory("index"), DeviceIndexOp);

}  // namespace
}  // namespace tensorflow
@@ -299,4 +299,10 @@ REGISTER_OP("FakeParam")
      return Status::OK();
    });

// Returns the device index.
REGISTER_OP("DeviceIndex")
    .Output("index: int32")
    .Attr("device_names: list(string)")
    .SetShapeFn(shape_inference::ScalarShape);

}  // end namespace tensorflow
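For a quick sanity check of the op's runtime behavior, the generated Python wrapper can be called directly. This sketch assumes eager execution and the wrapper generated from the registration above:

```python
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_functional_ops

# Placed on CPU, the op returns the position of "CPU" in device_names (0).
# On a device type not in the list it would return len(device_names),
# matching the kernel's fallback.
with ops.device("cpu:0"):
  index = gen_functional_ops.device_index(device_names=["CPU", "GPU"])
print(index)  # => tf.Tensor(0, shape=(), dtype=int32)
```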
@@ -942,7 +942,10 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
    return captured_tensor


-def indexed_case(branch_index, branch_fns, name="indexed_case"):
+def indexed_case(branch_index,
+                 branch_fns,
+                 name="indexed_case",
+                 lower_using_switch_merge=None):
  """Like cond_v2, except emits a Case op instead of an If."""
  if isinstance(branch_index, int):
    raise TypeError("branch_index must not be a Python int", branch_index)
@@ -976,7 +979,8 @@ def indexed_case(branch_index, branch_fns, name="indexed_case"):
  return _build_case(
      branch_index,
      branch_graphs, [g.external_captures for g in branch_graphs],
-      name=scope)
+      name=scope,
+      lower_using_switch_merge=lower_using_switch_merge)


@ops.RegisterGradient("Case")
@@ -1064,7 +1068,11 @@ def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
  return [None] + outputs


-def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
+def _build_case(branch_index,
+                branch_graphs,
+                branch_inputs,
+                name=None,
+                lower_using_switch_merge=None):
  """Creates a `Case` op from `branch_index`, branch graphs and inputs.

  Note that this modifies `branch_graphs` to make the inputs match, and to
@@ -1080,6 +1088,7 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
    branch_inputs: List of lists of Tensors to be passed to corresponding
      branch_graph as input.
    name: the name for the Case op.
    lower_using_switch_merge: Lower this op using switch merge ops (optional).

  Returns:
    A list of Tensors which are the outputs of the Case op. Does not include
@@ -1105,7 +1114,7 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
  case_op, tensors = _get_op_and_outputs(tensors)

  if case_op is not None:
-    util.maybe_set_lowering_attr(case_op)
+    util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
    util.maybe_propagate_compile_time_consts_in_xla(case_op)
    _set_read_only_resource_inputs_attr(case_op, branch_graphs)
    # Prevent fetching since the variant outputs can't be fetched directly.
@@ -43,6 +43,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
@@ -3283,7 +3284,11 @@ def _indexed_case_verify_and_canonicalize_args(branch_fns, default,
  return actions


-def _indexed_case_helper(branch_fns, default, branch_index, name):
+def _indexed_case_helper(branch_fns,
+                         default,
+                         branch_index,
+                         name,
+                         lower_using_switch_merge=None):
  """Implementation of case that emits the n-way indexed Case op.

  Args:
@@ -3293,6 +3298,7 @@ def _indexed_case_helper(branch_fns, default, branch_index, name):
    branch_index: Optional int `Tensor`, which selects for the corresponding
      pred_fn_pair.
    name: A name for this operation (optional).
    lower_using_switch_merge: Lower this op using switch merge ops (optional).

  Returns:
    The tensors returned by the pair whose key matched branch_index, or
@@ -3314,7 +3320,10 @@ def _indexed_case_helper(branch_fns, default, branch_index, name):
        | math_ops.greater_equal(branch_index, len(branch_fns)),
        len(branch_fns) - 1, branch_index)
    return branch_fns[int(branch_index)]()
-  return cond_v2.indexed_case(branch_index, branch_fns)
+  return cond_v2.indexed_case(
+      branch_index,
+      branch_fns,
+      lower_using_switch_merge=lower_using_switch_merge)


@tf_export("case", v1=[])
@@ -3607,6 +3616,50 @@ def switch_case(branch_index,
  return _indexed_case_helper(branch_fns, default, branch_index, name)


def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"):
  """Executes one of the provided callables based on the device placement.

  This API is used when the implementation of a high level function depends on
  the underlying device placement. It takes a dictionary mapping device types
  ("CPU", "GPU", "TPU", etc.) to callables. When the type of the device on
  which this op runs matches a key in 'device_branch_fns', the corresponding
  callable is executed; otherwise 'default_fn' is executed.

  **Example:**
  ```python
  def f1(): return tf.constant(1)
  def f2(): return tf.constant(2)
  r = tf.execute_fn_for_device({"CPU": f1, "GPU": f2}, default_fn=f1)
  ```
  'r' evaluates to 1 when running on a CPU, 2 when running on a GPU, and 1
  when running on any other device type.

  Args:
    device_branch_fns: a dictionary of device types to the callables. Each
      callable must return a matching structure of tensors.
    default_fn: fallback callable when the underlying device does not match any
      key in 'device_branch_fns'.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the callable identified by device type during
    execution, or those returned by 'default_fn' if no key matches.
  """
  device_branch_fns_upper = {k.upper(): v for k, v in device_branch_fns.items()}
  branch_fns = list(device_branch_fns_upper.values())
  devices = list(device_branch_fns_upper.keys())
  device_index = gen_functional_ops.device_index(device_names=devices)
  return _indexed_case_helper(
      branch_fns,
      default_fn,
      device_index,
      name,
      lower_using_switch_merge=False)


class XLAControlFlowContext(ControlFlowContext):
  """Base class for XLA and TPU control flow contexts."""
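The `lower_using_switch_merge=False` passed here appears intended to keep the emitted Case op intact (not lowered to Switch/Merge nodes), so that the grappler `ImplementationSelector` pass can still find the DeviceIndex input and rewrite it. A usage sketch under `tf.function`, mirroring the tests added later in this commit (function names are illustrative):

```python
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops

def cpu_fn(x): return x + x
def gpu_fn(x): return x * x

@def_function.function
def flexible(a):
  branches = {"CPU": lambda: cpu_fn(a), "GPU": lambda: gpu_fn(a)}
  return control_flow_ops.execute_fn_for_device(branches, lambda: cpu_fn(a))

a = constant_op.constant(3.0)
with ops.device("cpu:0"):
  print(flexible(a))  # 6.0 on CPU; 9.0 if the function runs on a GPU
```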
@@ -1211,6 +1211,92 @@ class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase):
      control_flow_ops.switch_case(array_ops.constant(1), branches)


class ExecuteFnForDeviceTest(test_util.TensorFlowTestCase):

  def testCommonCases(self):

    def cpu_fn(x):
      return x + x

    def gpu_fn(x):
      return x * x

    def flexible_fn(a):
      branches = {"CPU": lambda: cpu_fn(a), "GPU": lambda: gpu_fn(a)}
      return control_flow_ops.execute_fn_for_device(branches, lambda: cpu_fn(a))

    @def_function.function
    def flexible_defun(a):
      return flexible_fn(a)

    def run_defun_and_tape(a):
      with backprop.GradientTape() as tape:
        tape.watch(a)
        result = flexible_defun(a)
      grad = tape.gradient(result, a)
      r = flexible_fn(a)
      return r, result, grad

    a = array_ops.constant(3.)
    with ops.device("cpu:0"):
      r, result, grad = run_defun_and_tape(a)
      self.assertEqual(6., self.evaluate(r))
      self.assertEqual(6., self.evaluate(result))
      self.assertEqual([2.], self.evaluate(grad))

    if test_util.is_gpu_available():
      with ops.device("gpu:0"):
        r, result, grad = run_defun_and_tape(a)
        self.assertEqual(9., self.evaluate(r))
        self.assertEqual(9., self.evaluate(result))
        self.assertEqual([6.], self.evaluate(grad))

    # no device annotation
    r, result, grad = run_defun_and_tape(a)
    if test_util.is_gpu_available():
      self.assertEqual(9., self.evaluate(r))
      self.assertEqual(9., self.evaluate(result))
      self.assertEqual([6.], self.evaluate(grad))
    else:
      self.assertEqual(6., self.evaluate(r))
      self.assertEqual(6., self.evaluate(result))
      self.assertEqual([2.], self.evaluate(grad))

  def testFallBack(self):

    def default_fn(x):
      return x

    def tpu_fn(x):
      return x * x * x

    def flexible_fn(a):
      branches = {"TPU": lambda: tpu_fn(a)}
      return control_flow_ops.execute_fn_for_device(
          branches, default_fn=lambda: default_fn(a))

    @def_function.function
    def flexible_defun(a):
      return flexible_fn(a)

    a = array_ops.constant(3.)
    with ops.device("cpu:0"):
      result_defun = flexible_defun(a)
      self.assertEqual(3., self.evaluate(result_defun))
      # execute_fn_for_device is not inside defun_function.
      result = flexible_fn(a)
      self.assertEqual(3., self.evaluate(result))

    if test_util.is_gpu_available():
      with ops.device("gpu:0"):
        result_defun = flexible_defun(a)
        self.assertEqual(3., self.evaluate(result_defun))
        # execute_fn_for_device is not inside defun_function.
        result = flexible_fn(a)
        self.assertEqual(3., self.evaluate(result))


class CaseTest(test_util.TensorFlowTestCase):

  @test_util.run_deprecated_v1
@@ -92,7 +92,7 @@ def unique_grad_fn_name(forward_name):
  return "%s_grad_%s" % (forward_name, ops.uid())


-def maybe_set_lowering_attr(op):
+def maybe_set_lowering_attr(op, lower_using_switch_merge=None):
  """Sets the flag to enable lowering on `op` if necessary.

  Lowering allows cond_v2 and while_v2 to avoid some of the limitations of
@@ -108,11 +108,18 @@ def maybe_set_lowering_attr(op):
  - When the eager execution context specifies the executor of functions to
    be the single threaded executor (see context.function_executor_type()).
    Because the single threaded executor does not support v1 control flow ops.
  - When 'lower_using_switch_merge' is explicitly set to False.

  Args:
    op: An `If` or `While` Operation.
    lower_using_switch_merge: Explicit value to lower or not (optional).
  """
-  if (not _DISABLE_LOWER_USING_SWITCH_MERGE and
+  if lower_using_switch_merge is not None:
+    # pylint: disable=protected-access
+    op._set_attr("_lower_using_switch_merge",
+                 attr_value_pb2.AttrValue(b=lower_using_switch_merge))
+    # pylint: enable=protected-access
+  elif (not _DISABLE_LOWER_USING_SWITCH_MERGE and
      not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
      context.context().function_call_options.executor_type !=
      "SINGLE_THREADED_EXECUTOR"):
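Restating the new precedence in `maybe_set_lowering_attr`: an explicit `lower_using_switch_merge` always wins, otherwise the existing heuristics apply. A small illustrative sketch, not part of the diff:

```python
def should_lower(explicit, globally_disabled, in_xla_context, single_threaded):
    """Mirrors the branch structure of maybe_set_lowering_attr."""
    if explicit is not None:
        return explicit
    return not (globally_disabled or in_xla_context or single_threaded)

# execute_fn_for_device passes an explicit False, so lowering is skipped
# regardless of the heuristics.
assert should_lower(False, False, False, False) is False
assert should_lower(None, False, False, False) is True
```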
@@ -1140,6 +1140,10 @@ tf_module {
    name: "DestroyTemporaryVariable"
    argspec: "args=[\'ref\', \'var_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "DeviceIndex"
    argspec: "args=[\'device_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "Diag"
    argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -1140,6 +1140,10 @@ tf_module {
    name: "DestroyTemporaryVariable"
    argspec: "args=[\'ref\', \'var_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "DeviceIndex"
    argspec: "args=[\'device_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "Diag"
    argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "