Fix requested_device of NextIteration node.
PiperOrigin-RevId: 310629600 Change-Id: Ic097448918a59b7ae42683a5f8d2014f97e22447
This commit is contained in:
parent
0c5b1b8ab2
commit
f596266023
tensorflow/core/common_runtime
@ -2511,10 +2511,12 @@ tf_cc_test(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -444,15 +444,14 @@ Status LowerWhileHelper::CreateNextIterationNodes() {
|
||||
if (IsResource(i)) {
|
||||
continue;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
NodeBuilder(NewName("next_iteration"), "NextIteration",
|
||||
graph_->op_registry(), &debug_info_)
|
||||
.Input(NodeOut(body_call_node_, i))
|
||||
.ControlInput(body_call_node_)
|
||||
.Device(while_op_->requested_device())
|
||||
.AssignedDevice(merge_nodes_[op_input_output_to_lowered_node_[i]]
|
||||
->assigned_device_name())
|
||||
.Finalize(graph_, &next_iteration));
|
||||
Node* merge_node = merge_nodes_[op_input_output_to_lowered_node_[i]];
|
||||
TF_RETURN_IF_ERROR(NodeBuilder(NewName("next_iteration"), "NextIteration",
|
||||
graph_->op_registry(), &debug_info_)
|
||||
.Input(NodeOut(body_call_node_, i))
|
||||
.ControlInput(body_call_node_)
|
||||
.Device(merge_node->requested_device())
|
||||
.AssignedDevice(merge_node->assigned_device_name())
|
||||
.Finalize(graph_, &next_iteration));
|
||||
next_iterations_nodes_.emplace_back(next_iteration);
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/cc/client/client_session.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
@ -25,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
@ -262,6 +264,145 @@ TEST(LowerWhileOpTest, ForwardAssignedInputDevice) {
|
||||
ASSERT_EQ(exit_consumers, 1);
|
||||
}
|
||||
|
||||
// Verifies that lowering a While op propagates each node's *requested* device:
// frame nodes created for a loop variable (Enter/Merge/NextIteration/Switch/
// Exit) should inherit the loop variable's requested device, while the
// control-flow bookkeeping nodes (LoopControlInputs/LoopExecuted) should
// inherit the While op's own requested device.
TEST(LowerWhileOpTest, ForwardRequestedInputDevice) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // Register the cond/body functions the While op will reference by name.
  FunctionDefLibrary f_lib_proto;
  *f_lib_proto.add_function() = test::function::XTimesTwo();
  *f_lib_proto.add_function() = test::function::LessThanOrEqualToN(8);
  TF_ASSERT_OK(graph->AddFunctionLibrary(f_lib_proto));

  auto loop_var_type = DT_FLOAT;
  // Loop variable is requested on gpu:0.
  const string gpu_0_device = "/job:localhost/replica:0/task:0/gpu:0";
  // The loop's control input is requested on gpu:1.
  const string gpu_1_device = "/job:localhost/replica:0/task:0/gpu:1";
  // The While op itself is requested on gpu:2.
  const string gpu_2_device = "/job:localhost/replica:0/task:0/gpu:2";

  Node* loop_var_ph;
  TF_CHECK_OK(NodeBuilder("placed_node", "Placeholder")
                  .Attr("dtype", loop_var_type)
                  .Device(gpu_0_device)
                  .Finalize(graph.get(), &loop_var_ph));

  // A control input on the While op forces creation of a LoopExecuted node.
  Node* control_input;
  TF_CHECK_OK(NodeBuilder("control_in", "Placeholder")
                  .Attr("dtype", loop_var_type)
                  .Device(gpu_1_device)
                  .Finalize(graph.get(), &control_input));

  std::vector<NodeBuilder::NodeOut> while_inputs;
  while_inputs.emplace_back(loop_var_ph);
  AttrValue cond_func;
  cond_func.mutable_func()->set_name("LessThanOrEqualToN");
  AttrValue body_func;
  body_func.mutable_func()->set_name("XTimesTwo");
  Node* while_op;
  TF_ASSERT_OK(
      NodeBuilder("while", "While", &graph->flib_def())
          .Input(while_inputs)
          .ControlInput(control_input)
          .Device(gpu_2_device)
          .Attr("T", {loop_var_type})
          .Attr("cond", cond_func)
          .Attr("body", body_func)
          .Attr("parallel_iterations", 100)
          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
          .Finalize(graph.get(), &while_op));

  // An empty Const with a control dep from the While op; this triggers the
  // creation of a LoopExecuted node during lowering.
  TensorProto proto;
  proto.set_dtype(DT_FLOAT);
  TensorShape empty_shape({0});
  empty_shape.AsProto(proto.mutable_tensor_shape());
  Node* control_output;
  TF_ASSERT_OK(NodeBuilder("control_out", "Const")
                   .ControlInput(while_op)
                   .Attr("dtype", DT_FLOAT)
                   .Attr("value", proto)
                   .Finalize(graph.get(), &control_output));

  TF_ASSERT_OK(Rewrite(&graph));

  // Locate the lowered placeholder so we can walk the frame nodes from it.
  const Node* placeholder_node = nullptr;
  for (const Node* node : graph->op_nodes()) {
    if (node->name() == "placed_node") {
      placeholder_node = node;
    }
  }
  ASSERT_NE(placeholder_node, nullptr);

  // Enter: exactly one, and it must carry the loop var's requested device.
  int num_enter = 0;
  const Node* enter_node = nullptr;
  for (const Node* node : placeholder_node->out_nodes()) {
    if (node->type_string() == "Enter") {
      ++num_enter;
      enter_node = node;
      ASSERT_EQ(node->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(num_enter, 1);

  // Merge: exactly one downstream of Enter, same requested device.
  int num_merge = 0;
  const Node* merge_node = nullptr;
  for (const Node* node : enter_node->out_nodes()) {
    if (node->type_string() == "Merge") {
      ++num_merge;
      merge_node = node;
      ASSERT_EQ(node->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(num_merge, 1);

  // NextIteration: exactly one feeding the Merge, same requested device.
  // (This is the node whose requested device the fix under test corrects.)
  int num_next_iteration = 0;
  for (const Node* node : merge_node->in_nodes()) {
    if (node->type_string() == "NextIteration") {
      ++num_next_iteration;
      ASSERT_EQ(node->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(num_next_iteration, 1);

  // Switch: exactly one downstream of Merge, same requested device.
  int num_switch = 0;
  const Node* switch_node = nullptr;
  for (const Node* node : merge_node->out_nodes()) {
    if (node->type_string() == "Switch") {
      ++num_switch;
      switch_node = node;
      ASSERT_EQ(node->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(num_switch, 1);

  // Exit: exactly one downstream of Switch, same requested device.
  int num_exit = 0;
  for (const Node* node : switch_node->out_nodes()) {
    if (node->type_string() == "Exit") {
      ++num_exit;
      ASSERT_EQ(node->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(num_exit, 1);

  // LoopControlInputs must carry the While op's requested device.
  const Node* loop_control_inputs_node = nullptr;
  for (const Node* node : graph->op_nodes()) {
    if (absl::StrContains(node->name(), "LoopControlInputs")) {
      loop_control_inputs_node = node;
    }
  }
  ASSERT_NE(loop_control_inputs_node, nullptr);
  ASSERT_EQ(loop_control_inputs_node->requested_device(), gpu_2_device);

  // LoopExecuted must carry the While op's requested device as well.
  const Node* loop_executed_node = nullptr;
  for (const Node* node : graph->op_nodes()) {
    if (absl::StrContains(node->name(), "LoopExecuted")) {
      loop_executed_node = node;
    }
  }
  ASSERT_NE(loop_executed_node, nullptr);
  ASSERT_EQ(loop_executed_node->requested_device(), gpu_2_device);
}
|
||||
|
||||
TEST(LowerWhileOpTest, MultipleInputs) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user