Fix requested_device of NextIteration node: forward the corresponding Merge node's requested device instead of the While op's requested device.

PiperOrigin-RevId: 310629600
Change-Id: Ic097448918a59b7ae42683a5f8d2014f97e22447
This commit is contained in:
Saurabh Saxena 2020-05-08 14:29:27 -07:00 committed by TensorFlower Gardener
parent 0c5b1b8ab2
commit f596266023
3 changed files with 151 additions and 9 deletions
tensorflow/core/common_runtime

View File

@ -2511,10 +2511,12 @@ tf_cc_test(
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
],
)

View File

@ -444,15 +444,14 @@ Status LowerWhileHelper::CreateNextIterationNodes() {
if (IsResource(i)) {
continue;
}
TF_RETURN_IF_ERROR(
NodeBuilder(NewName("next_iteration"), "NextIteration",
graph_->op_registry(), &debug_info_)
.Input(NodeOut(body_call_node_, i))
.ControlInput(body_call_node_)
.Device(while_op_->requested_device())
.AssignedDevice(merge_nodes_[op_input_output_to_lowered_node_[i]]
->assigned_device_name())
.Finalize(graph_, &next_iteration));
Node* merge_node = merge_nodes_[op_input_output_to_lowered_node_[i]];
TF_RETURN_IF_ERROR(NodeBuilder(NewName("next_iteration"), "NextIteration",
graph_->op_registry(), &debug_info_)
.Input(NodeOut(body_call_node_, i))
.ControlInput(body_call_node_)
.Device(merge_node->requested_device())
.AssignedDevice(merge_node->assigned_device_name())
.Finalize(graph_, &next_iteration));
next_iterations_nodes_.emplace_back(next_iteration);
}
return Status::OK();

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/strings/match.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
@ -25,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
@ -262,6 +264,145 @@ TEST(LowerWhileOpTest, ForwardAssignedInputDevice) {
ASSERT_EQ(exit_consumers, 1);
}
TEST(LowerWhileOpTest, ForwardRequestedInputDevice) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  // Register the cond (LessThanOrEqualToN) and body (XTimesTwo) functions
  // referenced by the While op below.
  FunctionDefLibrary fdef_lib;
  *fdef_lib.add_function() = test::function::XTimesTwo();
  *fdef_lib.add_function() = test::function::LessThanOrEqualToN(8);
  TF_ASSERT_OK(graph->AddFunctionLibrary(fdef_lib));
  const DataType dtype = DT_FLOAT;
  // Device placement for this test: the loop variable is requested on gpu:0,
  // the While op's control input on gpu:1, and the While op itself on gpu:2.
  const string gpu_0_device = "/job:localhost/replica:0/task:0/gpu:0";
  const string gpu_1_device = "/job:localhost/replica:0/task:0/gpu:1";
  const string gpu_2_device = "/job:localhost/replica:0/task:0/gpu:2";
  Node* loop_var = nullptr;
  TF_CHECK_OK(NodeBuilder("placed_node", "Placeholder")
                  .Attr("dtype", dtype)
                  .Device(gpu_0_device)
                  .Finalize(graph.get(), &loop_var));
  // A control input on the While op triggers the creation of a LoopExecuted
  // node during lowering.
  Node* ctrl_input = nullptr;
  TF_CHECK_OK(NodeBuilder("control_in", "Placeholder")
                  .Attr("dtype", dtype)
                  .Device(gpu_1_device)
                  .Finalize(graph.get(), &ctrl_input));
  AttrValue cond_fn;
  cond_fn.mutable_func()->set_name("LessThanOrEqualToN");
  AttrValue body_fn;
  body_fn.mutable_func()->set_name("XTimesTwo");
  std::vector<NodeBuilder::NodeOut> while_inputs(
      {NodeBuilder::NodeOut(loop_var)});
  Node* while_node = nullptr;
  TF_ASSERT_OK(
      NodeBuilder("while", "While", &graph->flib_def())
          .Input(while_inputs)
          .ControlInput(ctrl_input)
          .Device(gpu_2_device)
          .Attr("T", {dtype})
          .Attr("cond", cond_fn)
          .Attr("body", body_fn)
          .Attr("parallel_iterations", 100)
          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
          .Finalize(graph.get(), &while_node));
  // An empty Const with a control dependency on the While op also triggers
  // the creation of a LoopExecuted node.
  TensorProto empty_tensor;
  empty_tensor.set_dtype(DT_FLOAT);
  TensorShape empty_shape({0});
  empty_shape.AsProto(empty_tensor.mutable_tensor_shape());
  Node* ctrl_output = nullptr;
  TF_ASSERT_OK(NodeBuilder("control_out", "Const")
                   .ControlInput(while_node)
                   .Attr("dtype", DT_FLOAT)
                   .Attr("value", empty_tensor)
                   .Finalize(graph.get(), &ctrl_output));
  TF_ASSERT_OK(Rewrite(&graph));
  // Locate the placeholder holding the loop variable after lowering.
  const Node* placeholder_node = nullptr;
  for (const auto* node : graph->op_nodes()) {
    if (node->name() == "placed_node") placeholder_node = node;
  }
  ASSERT_NE(placeholder_node, nullptr);
  // The Enter node must inherit the loop variable's requested device.
  const Node* enter_node = nullptr;
  int num_enter = 0;
  for (const Node* out : placeholder_node->out_nodes()) {
    if (out->type_string() == "Enter") {
      num_enter += 1;
      enter_node = out;
      ASSERT_EQ(out->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(num_enter, 1);
  // The Merge node must inherit the loop variable's requested device.
  const Node* merge_node = nullptr;
  int num_merge = 0;
  for (const Node* out : enter_node->out_nodes()) {
    if (out->type_string() == "Merge") {
      num_merge += 1;
      merge_node = out;
      ASSERT_EQ(out->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(num_merge, 1);
  // The NextIteration node (an input of the Merge) must inherit the loop
  // variable's requested device rather than the While op's.
  int num_next_iteration = 0;
  for (const Node* in : merge_node->in_nodes()) {
    if (in->type_string() == "NextIteration") {
      num_next_iteration += 1;
      ASSERT_EQ(in->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(num_next_iteration, 1);
  // The Switch node must inherit the loop variable's requested device.
  const Node* switch_node = nullptr;
  int num_switch = 0;
  for (const Node* out : merge_node->out_nodes()) {
    if (out->type_string() == "Switch") {
      num_switch += 1;
      switch_node = out;
      ASSERT_EQ(out->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(num_switch, 1);
  // The Exit node must inherit the loop variable's requested device.
  int num_exit = 0;
  for (const Node* out : switch_node->out_nodes()) {
    if (out->type_string() == "Exit") {
      num_exit += 1;
      ASSERT_EQ(out->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(num_exit, 1);
  // The LoopControlInputs node must inherit the While op's requested device.
  const Node* loop_control_inputs_node = nullptr;
  for (const auto* node : graph->op_nodes()) {
    if (absl::StrContains(node->name(), "LoopControlInputs")) {
      loop_control_inputs_node = node;
    }
  }
  ASSERT_NE(loop_control_inputs_node, nullptr);
  ASSERT_EQ(loop_control_inputs_node->requested_device(), gpu_2_device);
  // The LoopExecuted node must inherit the While op's requested device.
  const Node* loop_executed_node = nullptr;
  for (const auto* node : graph->op_nodes()) {
    if (absl::StrContains(node->name(), "LoopExecuted")) {
      loop_executed_node = node;
    }
  }
  ASSERT_NE(loop_executed_node, nullptr);
  ASSERT_EQ(loop_executed_node->requested_device(), gpu_2_device);
}
TEST(LowerWhileOpTest, MultipleInputs) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));