[Grappler] Do not copy device set when optimizing function library
PiperOrigin-RevId: 242042808
This commit is contained in:
parent
3f56ae58ef
commit
51a3acd6a0
@ -1500,9 +1500,16 @@ Status PlaceInlinedFunctionBody(
|
||||
const GraphView::OutputPort output_port =
|
||||
ctx->graph_view().GetRegularFanin({&func_node, input_idx});
|
||||
|
||||
const string& input_device = output_port.node->device();
|
||||
|
||||
if (!input_device.empty()) {
|
||||
VLOG(3) << "Pin inlined function input node '" << func_body_node->name()
|
||||
<< "' to the '" << output_port.node->device() << "' device.";
|
||||
func_body_node->set_requested_device(output_port.node->device());
|
||||
} else {
|
||||
VLOG(3) << "Inlined function input node '" << func_body_node->name()
|
||||
<< "' device is undefined.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1513,15 +1520,14 @@ Status PlaceInlinedFunctionBody(
|
||||
const DeviceSet* devices = ctx->devices();
|
||||
|
||||
if (devices->devices().empty()) {
|
||||
// If there are no devices available for placer, we just put all nodes to
|
||||
// the same device as a function caller node. This can happen if Grappler is
|
||||
// running "offline", without active runtime session, for example as a part
|
||||
// of a batch job for graph analysis/optimization.
|
||||
VLOG(3) << "Assign function call node device to all function body nodes. "
|
||||
<< "Device: " << func_node.device();
|
||||
for (Node* func_body_node : func_body_graph->nodes()) {
|
||||
func_body_node->set_requested_device(func_node.device());
|
||||
}
|
||||
// If there are no devices available for placer, we do not place function
|
||||
// body nodes. This happens when Grappler optimizing function library, or
|
||||
// when graph optimized "offline", without active runtime session, for
|
||||
// example as a part of batch job for graph analysis/optimization.
|
||||
// GrapplerItem instantiated from a function library doesn't have to be
|
||||
// fully placed after all optimization, it will be placed by the function
|
||||
// library runtime before execution.
|
||||
VLOG(3) << "Do not place instantiated function body.";
|
||||
} else {
|
||||
// If we are running in an active runtime session, Grappler will get the
|
||||
// graph after initial placing is done, and we should have devices for the
|
||||
|
@ -30,7 +30,7 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
namespace {
|
||||
constexpr char kDevice[] = "/device:CPU:0";
|
||||
constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
} // namespace
|
||||
|
||||
class FunctionOptimizerTest : public GrapplerTest {
|
||||
@ -731,9 +731,14 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionSimpleFunction) {
|
||||
{"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
|
||||
kDevice),
|
||||
NDef("d", "Identity", {"c"}, {{"T", DT_FLOAT}}, kDevice)},
|
||||
// Function library.
|
||||
{mul_func} /* Function library */);
|
||||
|
||||
Tensor pi = test::AsScalar<float>(3.14f);
|
||||
item.feed.emplace_back("a", pi);
|
||||
item.feed.emplace_back("b", pi);
|
||||
|
||||
// If device set is empty, inlined function body must not be placed.
|
||||
{
|
||||
GraphDef optimized_graph;
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));
|
||||
|
||||
@ -741,7 +746,39 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionSimpleFunction) {
|
||||
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
|
||||
NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
|
||||
|
||||
// Function must be inlined and all nodes placed on a valid device.
|
||||
// Function body nodes are not placed, however function input nodes
|
||||
// must copy device assignment from input arguments.
|
||||
NDef("c/inputs_ready", "NoOp", {"^a", "^b"}, {}),
|
||||
NDef("c/x", "Identity", {"a:0", "^c/inputs_ready"}, {{"T", DT_FLOAT}},
|
||||
kDevice),
|
||||
NDef("c/y", "Identity", {"b:0", "^c/inputs_ready"}, {{"T", DT_FLOAT}},
|
||||
kDevice),
|
||||
NDef("c/mul", "Mul", {"c/x", "c/y"}, {{"T", DT_FLOAT}}),
|
||||
|
||||
NDef("d", "Identity", {"c/mul:0"}, {{"T", DT_FLOAT}}, kDevice)},
|
||||
// Function library.
|
||||
{mul_func});
|
||||
|
||||
CompareGraphs(expected, optimized_graph);
|
||||
|
||||
GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
|
||||
auto tensors_expected = EvaluateFetchNodes(item);
|
||||
auto tensors = EvaluateFetchNodes(optimized);
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
ASSERT_EQ(tensors.size(), tensors_expected.size());
|
||||
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
|
||||
}
|
||||
|
||||
// If device set is not empty, inlined function body must be placed.
|
||||
{
|
||||
GraphDef optimized_graph;
|
||||
TF_EXPECT_OK(item.AddDevice(kDevice));
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));
|
||||
|
||||
GraphDef expected = test::function::GDef(
|
||||
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
|
||||
NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
|
||||
|
||||
NDef("c/inputs_ready", "NoOp", {"^a", "^b"}, {}, kDevice),
|
||||
NDef("c/x", "Identity", {"a:0", "^c/inputs_ready"}, {{"T", DT_FLOAT}},
|
||||
kDevice),
|
||||
@ -755,16 +792,13 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionSimpleFunction) {
|
||||
|
||||
CompareGraphs(expected, optimized_graph);
|
||||
|
||||
Tensor pi = test::AsScalar<float>(3.14f);
|
||||
item.feed.emplace_back("a", pi);
|
||||
item.feed.emplace_back("b", pi);
|
||||
|
||||
GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
|
||||
auto tensors_expected = EvaluateFetchNodes(item);
|
||||
auto tensors = EvaluateFetchNodes(optimized);
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
ASSERT_EQ(tensors.size(), tensors_expected.size());
|
||||
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithControlDependencies) {
|
||||
@ -799,6 +833,7 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithControlDependencies) {
|
||||
// f2 = MyMul(f1, f1, v)
|
||||
// return [f2, v]
|
||||
GrapplerItem item;
|
||||
TF_EXPECT_OK(item.AddDevice(kDevice)); // device for placing inlined function
|
||||
item.fetch = {"out_1", "out_2"};
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
|
||||
@ -992,6 +1027,7 @@ TEST_F(FunctionOptimizerTest,
|
||||
// f2 = MyMul(a, b, ^f1) <-- control dependency on inlined function!
|
||||
// return f2
|
||||
GrapplerItem item;
|
||||
TF_EXPECT_OK(item.AddDevice(kDevice)); // device for placing inlined function
|
||||
item.fetch = {"out"};
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
|
||||
@ -1167,6 +1203,7 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithMergedDeadTensors) {
|
||||
// return out
|
||||
//
|
||||
GrapplerItem item;
|
||||
TF_EXPECT_OK(item.AddDevice(kDevice)); // device for placing inlined function
|
||||
item.fetch = {"out"};
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
|
||||
@ -1255,6 +1292,7 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithNestedFunctionCall) {
|
||||
// c = Identity(b)
|
||||
// return c
|
||||
GrapplerItem item;
|
||||
TF_EXPECT_OK(item.AddDevice(kDevice)); // device for placing inlined function
|
||||
item.fetch = {"c"};
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
|
||||
|
@ -604,10 +604,12 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
false;
|
||||
}
|
||||
|
||||
// Function item is allowed to use all devices from the main graph.
|
||||
Status added_devices = func_item.AddDevices(item);
|
||||
if (!added_devices.ok()) {
|
||||
VLOG(3) << added_devices.error_message();
|
||||
// Device set available to the function is defined only by the runtime,
|
||||
// when we instantiate and execute the function. We can't use all devices
|
||||
// available to the main graph, because after partitioning the function
|
||||
// call node might execute on a remote worker.
|
||||
if (!func_item.devices().empty()) {
|
||||
return errors::Internal("GrapplerFunctionItem devices must be empty.");
|
||||
}
|
||||
|
||||
// We are not allowed to prune certain types of ops from the graph
|
||||
|
Loading…
x
Reference in New Issue
Block a user