From 50204e4205f3c613ddd6a3ce2ba14f60d6aaaf80 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Mon, 3 Dec 2018 15:51:06 -0800
Subject: [PATCH] Run placer for the inlined function body.

PiperOrigin-RevId: 223878376
---
 .../common_runtime/graph_execution_state.cc    |   7 ++
 tensorflow/core/grappler/BUILD                 |   3 +
 tensorflow/core/grappler/grappler_item.cc      |  64 +++++++++-
 tensorflow/core/grappler/grappler_item.h       |  28 ++++-
 .../core/grappler/grappler_item_test.cc        |  28 +++++
 tensorflow/core/grappler/optimizers/BUILD      |   1 +
 .../optimizers/arithmetic_optimizer.cc         |   2 +-
 .../grappler/optimizers/function_optimizer.cc  | 110 ++++++++++++++++--
 .../optimizers/function_optimizer_test.cc      |  53 +++++++++
 .../grappler/optimizers/meta_optimizer.cc      |   8 +-
 .../optimizers/meta_optimizer_test.cc          |   2 +-
 .../core/grappler/utils/grappler_test.cc       |   2 ++
 12 files changed, 291 insertions(+), 17 deletions(-)

diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index 9a56c671623..880806f120d 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -595,6 +595,13 @@ Status GraphExecutionState::OptimizeGraph(
   grappler::GrapplerItem item;
   item.id = "tf_graph";
   graph_->ToGraphDef(&item.graph);
+
+  // It's ok to skip invalid device annotations in Grappler.
+  Status inferred_devices = item.InferDevicesFromGraph();
+  if (!inferred_devices.ok()) {
+    VLOG(3) << inferred_devices.error_message();
+  }
+
   // TODO(b/114748242): Add a unit test to test this bug fix.
   if (flib_def_) {
     *item.graph.mutable_library() = flib_def_->ToProto();
diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD
index 7982b358538..f353d789d47 100644
--- a/tensorflow/core/grappler/BUILD
+++ b/tensorflow/core/grappler/BUILD
@@ -107,6 +107,8 @@ cc_library(
         ":utils",
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -141,6 +143,7 @@ tf_cc_test(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
         "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
     ],
 )
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index 2c490f3966c..f7cda35368e 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -19,10 +19,13 @@ limitations under the License.
 #include <unordered_map>
 #include <unordered_set>
 
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_join.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/util/device_name_utils.h"
 
 namespace tensorflow {
 namespace grappler {
@@ -38,7 +41,8 @@ GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef* graph_def) {
   restore_op = other.restore_op;
   save_restore_loc_tensor = other.save_restore_loc_tensor;
   queue_runners = other.queue_runners;
-  allowed_optimizations = other.allowed_optimizations;
+  devices_ = other.devices_;
+  allowed_optimizations_ = other.allowed_optimizations_;
   graph.Swap(graph_def);
 }
 
@@ -111,6 +115,64 @@ std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
   return result;
 }
 
+const std::unordered_set<string>& GrapplerItem::devices() const {
+  return devices_;
+}
+
+Status GrapplerItem::AddDevice(const string& device) {
+  DeviceNameUtils::ParsedName name;
+
+  if (!DeviceNameUtils::ParseFullName(device, &name)) {
+    return errors::InvalidArgument("Invalid device name: device=", device);
+
+  } else if (!name.has_job || !name.has_replica || !name.has_task ||
+             !name.has_type || !name.has_id) {
+    return errors::InvalidArgument("Not a fully defined device name: device=",
+                                   device);
+  }
+
+  devices_.insert(DeviceNameUtils::ParsedNameToString(name));
+  return Status::OK();
+}
+
+Status GrapplerItem::AddDevices(const GrapplerItem& other) {
+  std::vector<string> invalid_devices;
+  for (const string& device : other.devices()) {
+    Status added = AddDevice(device);
+    if (!added.ok()) invalid_devices.emplace_back(device);
+  }
+  return invalid_devices.empty()
+             ? Status::OK()
+             : errors::InvalidArgument("Skipped invalid devices: [",
+                                       absl::StrJoin(invalid_devices, ", "),
+                                       "]");
+}
+
+Status GrapplerItem::InferDevicesFromGraph() {
+  absl::flat_hash_set<string> invalid_devices;
+  for (const NodeDef& node : graph.node()) {
+    Status added = AddDevice(node.device());
+    if (!added.ok()) invalid_devices.insert(node.device());
+  }
+  VLOG(2) << "Inferred device set: [" << absl::StrJoin(devices_, ", ") << "]";
+  return invalid_devices.empty()
+             ? Status::OK()
+             : errors::InvalidArgument("Skipped invalid devices: [",
+                                       absl::StrJoin(invalid_devices, ", "),
+                                       "]");
+}
+
+void GrapplerItem::ClearDevices() { devices_.clear(); }
+
+const GrapplerItem::AllowedOptimizations& GrapplerItem::allowed_optimizations()
+    const {
+  return allowed_optimizations_;
+}
+
+GrapplerItem::AllowedOptimizations& GrapplerItem::allowed_optimizations() {
+  return allowed_optimizations_;
+}
+
 std::vector<const NodeDef*> ComputeTransitiveFanin(
     const GraphDef& graph, const std::vector<string>& terminal_nodes) {
   bool ill_formed = false;
diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h
index a0748abfe69..aea7d27792a 100644
--- a/tensorflow/core/grappler/grappler_item.h
+++ b/tensorflow/core/grappler/grappler_item.h
@@ -85,7 +85,33 @@ struct GrapplerItem {
     bool non_differentiable_rewrites = true;
   };
 
-  AllowedOptimizations allowed_optimizations;
+  const std::unordered_set<string>& devices() const;
+  // Adds a device to the set of available devices, only if it's a valid fully
+  // defined device name. Returns `Status::OK()` if the device was successfully
+  // added, and an error otherwise.
+  Status AddDevice(const string& device);
+  // Adds all valid devices from the other Grappler item to the device set.
+  Status AddDevices(const GrapplerItem& other);
+  // Adds all valid devices from the nodes of the graph to the device set.
+  // Returns `Status::OK()` if all device annotations found in the graph are
+  // valid fully defined device names, and an error otherwise.
+  Status InferDevicesFromGraph();
+  // Clears the set of available devices.
+  void ClearDevices();
+
+  const AllowedOptimizations& allowed_optimizations() const;
+  AllowedOptimizations& allowed_optimizations();
+
+ private:
+  // TODO(ezhulenev): Make GrapplerItem a class and hide all public data
+  // members.
+  // TODO(ezhulenev): Migrate all unordered collections to absl.
+
+  // A set of fully defined device names that can be used to place the nodes
+  // of the `graph`.
+  // Example of a fully defined name: "/job:work/replica:1/task:1/device:CPU:0"
+  std::unordered_set<string> devices_;
+
+  AllowedOptimizations allowed_optimizations_;
 };
 
 // Return the transitive fanin of a set of terminal nodes.
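A minimal usage sketch of the GrapplerItem device API added above (illustrative only, not part of the patch; the helper name PopulateDevicesExample is hypothetical). Fully defined names are accepted by AddDevice(), partially specified names are rejected with InvalidArgument, and InferDevicesFromGraph() collects every valid annotation from the graph while only reporting the skipped ones:

#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace grappler {

// Hypothetical helper, for illustration only.
Status PopulateDevicesExample(GrapplerItem* item) {
  // A fully defined device name parses cleanly and lands in item->devices().
  TF_RETURN_IF_ERROR(
      item->AddDevice("/job:work/replica:1/task:1/device:CPU:0"));

  // A partially specified name is rejected with InvalidArgument; callers
  // typically log it and continue.
  Status partial = item->AddDevice("/device:CPU:1");
  if (!partial.ok()) VLOG(3) << partial.error_message();

  // Alternatively, collect every valid device annotation already present in
  // item->graph; the returned status only reports the annotations skipped.
  Status inferred = item->InferDevicesFromGraph();
  if (!inferred.ok()) VLOG(3) << inferred.error_message();

  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow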
diff --git a/tensorflow/core/grappler/grappler_item_test.cc b/tensorflow/core/grappler/grappler_item_test.cc
index 72a9f481cab..a8fbe356829 100644
--- a/tensorflow/core/grappler/grappler_item_test.cc
+++ b/tensorflow/core/grappler/grappler_item_test.cc
@@ -14,7 +14,9 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -44,6 +46,32 @@ TEST_F(GrapplerItemTest, Basic) {
   EXPECT_EQ(main_ops, graph_nodes);
 }
 
+TEST_F(GrapplerItemTest, InferDevices) {
+  using test::function::NDef;
+
+  const string cpu0 = "/job:work/replica:1/task:1/device:CPU:0";
+  const string cpu1 = "/job:work/replica:1/task:1/device:CPU:1";
+  const string cpu2 = "/device:CPU:2";
+
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {
+          NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
+          NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
+          NDef("c", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu2),
+      },
+      {} /* Empty function library */);
+
+  ASSERT_FALSE(item.InferDevicesFromGraph().ok());
+
+  EXPECT_EQ(item.devices().size(), 2);
+  EXPECT_NE(item.devices().find(cpu0), item.devices().end());
+  EXPECT_NE(item.devices().find(cpu1), item.devices().end());
+
+  item.ClearDevices();
+  EXPECT_EQ(item.devices().size(), 0);
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 68f70bf0df3..40eab8b9f01 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -141,6 +141,7 @@ cc_library(
     deps = [
         ":graph_optimizer",
         "//tensorflow/core:core_cpu_base",
+        "//tensorflow/core:core_cpu_lib",
        "//tensorflow/core:framework",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index e3ac89b50db..e41b1cf6840 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -3572,7 +3572,7 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
 
   // Disable restricted graph rewrites.
   options_.unary_ops_composition &=
-      item.allowed_optimizations.non_differentiable_rewrites;
+      item.allowed_optimizations().non_differentiable_rewrites;
 
   if (options_.dedup_computations) {
     DedupComputations();
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index 3ffe8f1a0b0..2dd4ff10e43 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -22,8 +22,11 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/strings/str_replace.h"
 #include "absl/strings/substitute.h"
+#include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_set.h"
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/placer.h"
 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/function.h"
@@ -111,10 +114,26 @@ AttrSlice FunctionInstantiationAttributes(const FunctionDef& func,
   }
 }
 
-class FakeCPUDevice : public Device {
+// This is a fake device that should not be used for any op kernel execution;
+// the only purpose of this device is to be passed as a part of a DeviceSet to
+// the Placer.
+class FakeDevice : public Device {
  public:
-  FakeCPUDevice(Env* env, const DeviceAttributes& attr) : Device(env, attr) {}
+  FakeDevice(Env* env, const string& device) : Device(env, attr(device)) {}
+  explicit FakeDevice(const string& device) : FakeDevice(nullptr, device) {}
   Status Sync() override { return Status::OK(); }
+
+ private:
+  static DeviceAttributes attr(const string& device) {
+    DeviceNameUtils::ParsedName parsed_name;
+    bool parsed = DeviceNameUtils::ParseFullName(device, &parsed_name);
+    DCHECK(parsed) << "Failed to parse full device name: " << device;
+
+    DeviceAttributes attr;
+    attr.set_name(device);
+    attr.set_device_type(parsed_name.type);
+    return attr;
+  }
 };
 
 // -------------------------------------------------------------------------- //
@@ -240,6 +259,7 @@ class FunctionOptimizerContext {
         graph_version_(item.graph.versions().producer()),
         opt_level_(opt_level),
         function_library_(OpRegistry::Global(), item.graph.library()),
+        available_device_names_(item.devices().begin(), item.devices().end()),
         graph_view_(&item.graph) {
     InitializeTrulyConstNodes(item);
     InitializeFetchNodes(item);
@@ -275,6 +295,18 @@
   const gtl::FlatSet<string>& fetch_tensors() const { return fetch_tensors_; }
 
+  const DeviceSet* devices() const {
+    // Create fake devices lazily only if we need a DeviceSet.
+    if (available_devices_.empty() && !available_device_names_.empty()) {
+      for (const string& name : available_device_names_) {
+        auto device = absl::make_unique<FakeDevice>(name);
+        available_device_set_.AddDevice(device.get());
+        available_devices_.push_back(std::move(device));
+      }
+    }
+    return &available_device_set_;
+  }
+
   bool IsFetchNode(const string& node_name) const {
     return fetch_nodes_.find(node_name) != fetch_nodes_.end();
   }
@@ -350,11 +382,8 @@ class FunctionOptimizerContext {
   void InitializeFunctionLibraryRuntime() {
     if (!flr_) {
       Env* env = Env::Default();
-      DeviceAttributes attr;
-      attr.set_name("/device:CPU:0");
-      attr.set_device_type("CPU");
       std::vector<std::unique_ptr<Device>> devices;
-      devices.push_back(absl::make_unique<FakeCPUDevice>(env, attr));
+      devices.push_back(absl::make_unique<FakeDevice>(env, "/device:CPU:0"));
       device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
       OptimizerOptions optimizer_opts;
       optimizer_opts.set_do_function_inlining(true);
@@ -375,6 +404,16 @@
   std::unique_ptr<ProcessFunctionLibraryRuntime> process_flr_;
   FunctionLibraryRuntime* flr_ = nullptr;
 
+  // Fully defined names of the devices available to the GrapplerItem.
+  const gtl::FlatSet<string> available_device_names_;
+
+  // List of available `FakeDevice` objects (lazily initialized, see devices()).
+  mutable std::vector<std::unique_ptr<Device>> available_devices_;
+
+  // DeviceSet of fake devices (`FakeDevice`) constructed from
+  // available_devices_ (lazily initialized).
+  mutable DeviceSet available_device_set_;
+
   // Nodes that are Const and not in feed.
   std::unordered_map<string, const NodeDef*> truly_const_nodes_;
   // Specialized functions.
@@ -1305,7 +1344,58 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
   // edges from inlined side-effectful ops.
   std::vector<string> side_effectful_nodes;
 
-  for (NodeDef& func_body_node : *item.mutable_function_body().mutable_node()) {
+  // ------------------------------------------------------------------------ //
+  // First we need to assign device placements to all function body nodes.
+
+  GraphDef placed_graph_def;
+
+  const DeviceSet* devices = ctx->devices();
+
+  if (devices->devices().empty()) {
+    // If there are no devices available for the placer, we just put all nodes
+    // on the same device as the function caller node. This can happen if
+    // Grappler is running "offline", without an active runtime session, for
+    // example as a part of a batch job for graph analysis/optimization.
+    VLOG(3) << "Assign function call node device to all function body nodes. "
+            << "Device: " << func_node.device();
+    placed_graph_def = item.mutable_function_body();
+    for (NodeDef& node : *placed_graph_def.mutable_node()) {
+      node.set_device(func_node.device());
+    }
+  } else {
+    // If we are running in an active runtime session, Grappler will get the
+    // graph after initial placing is done, and we should have devices for the
+    // placer.
+    VLOG(3) << "Run placer for instantiated function body. Devices: ["
+            << absl::StrJoin(
+                   devices->devices(), ", ",
+                   [](string* out, const Device* d) { out->append(d->name()); })
+            << "]";
+
+    // Construct a Graph object from the instantiated function body.
+    GraphConstructorOptions opts;
+    Graph graph(ctx->function_library());
+    TF_RETURN_IF_ERROR(
+        ConvertGraphDefToGraph(opts, item.function_body(), &graph));
+
+    // Use the function caller node device as a default for the placer.
+    const Device* default_device =
+        devices->FindDeviceByName(func_node.device());
+
+    Placer placer(&graph, devices, nullptr /* No session options */,
+                  default_device);
+    TF_RETURN_IF_ERROR(placer.Run());
+
+    // Convert the Graph back to a GraphDef.
+    graph.ToGraphDef(&placed_graph_def);
+  }
+
+  // ------------------------------------------------------------------------ //
+  // After all nodes are placed we need to prepare them for inlining into the
+  // optimized graph: turn placeholders into identities, update node
+  // connectivity, etc...
+
+  for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
     if (item.IsInputPlaceholder(func_body_node.name())) {
       // Turn input placeholders into identity node.
       DCHECK_EQ(0, func_body_node.input_size());
@@ -1337,10 +1427,6 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
     TF_RETURN_IF_ERROR(
         AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &func_body_node));
 
-    // TODO(ezhulenev): Call PartitionedCallOp to get placement for all function
-    // body nodes, for now just place it on function caller node device.
-    func_body_node.set_device(func_node.device());
-
     // If the function body has a side-effectful op, we double check that the
     // function call node has an output control edge, otherwise we can't safely
     // do inlining and guarantee that node will be executed.
@@ -1362,7 +1448,7 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
 
     // TODO(ezhulenev): Inline nested indirect function calls.
 
-    // Move the node to the main graph.
+    // Move the node to the optimized graph.
     optimized_graph->add_node()->Swap(&func_body_node);
   }
 
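The placement logic added above can be read as a single helper. The sketch below is a condensed restatement under the same includes as function_optimizer.cc; the free-function name PlaceInlinedFunctionBody is hypothetical (the real code lives inline in InlineIndirectFunctionCall and uses ctx and item directly):

// Hypothetical helper that mirrors the two placement paths added above.
Status PlaceInlinedFunctionBody(const DeviceSet* devices,
                                const NodeDef& func_node,
                                const FunctionLibraryDefinition& flib,
                                const GraphDef& function_body,
                                GraphDef* placed_graph_def) {
  if (devices->devices().empty()) {
    // "Offline" mode: no devices are registered, so every function body node
    // is placed on the caller node's device.
    *placed_graph_def = function_body;
    for (NodeDef& node : *placed_graph_def->mutable_node()) {
      node.set_device(func_node.device());
    }
    return Status::OK();
  }

  // "Online" mode: run the Placer over the instantiated function body, with
  // the caller node's device as the default device.
  GraphConstructorOptions opts;
  Graph graph(flib);
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, function_body, &graph));

  const Device* default_device = devices->FindDeviceByName(func_node.device());
  Placer placer(&graph, devices, nullptr /* No session options */,
                default_device);
  TF_RETURN_IF_ERROR(placer.Run());

  graph.ToGraphDef(placed_graph_def);
  return Status::OK();
}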
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
index 98552e94494..93a2fcda7bf 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
@@ -881,6 +881,59 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithControlDependencies) {
   test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
 }
 
+TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithDevicePlacement) {
+  using test::function::NDef;
+  using FDH = FunctionDefHelper;
+
+  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+
+  FunctionDef mul_func = FunctionDefHelper::Create(
+      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
+      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "mul:z:0"}});
+  // Add a device placement spec to the function body node.
+  (*mul_func.mutable_node_def())[0].set_device("/device:CPU:1");
+
+  // We need fully defined device names to run the placer for the inlined
+  // function.
+  const string cpu0 = "/job:work/replica:1/task:1/device:CPU:0";
+  const string cpu1 = "/job:work/replica:1/task:1/device:CPU:1";
+
+  // Build a graph to compute c = MyMul(a, b).
+  GrapplerItem item;
+  item.fetch = {"d"};
+  item.graph = test::function::GDef(
+      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
+       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
+       NDef("c", "PartitionedCall", {"a", "b"},
+            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
+             {"Tout", DataTypeSlice{DT_FLOAT}},
+             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
+            cpu0),
+       NDef("d", "Identity", {"c"}, {{"T", DT_FLOAT}}, cpu0)},
+      // Function library.
+      {mul_func});
+
+  ASSERT_TRUE(item.InferDevicesFromGraph().ok());
+
+  GraphDef optimized_graph;
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));
+
+  GraphDef expected = test::function::GDef(
+      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
+       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
+
+       // Function must be inlined and the `mul` node placed on the requested
+       // device.
+       NDef("c/x", "Identity", {"a:0"}, {{"T", DT_FLOAT}}, cpu1),
+       NDef("c/y", "Identity", {"b:0"}, {{"T", DT_FLOAT}}, cpu1),
+       NDef("c/mul", "Mul", {"c/x", "c/y"}, {{"T", DT_FLOAT}}, cpu1),
+
+       NDef("d", "Identity", {"c/mul:0"}, {{"T", DT_FLOAT}}, cpu0)},
+      // Function library.
+      {mul_func});
+
+  CompareGraphs(expected, optimized_graph);
+}
+
 TEST_F(FunctionOptimizerTest, SpecializeFunctionXTimesTwo) {
   using test::function::NDef;
 
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 304ddc7710a..5560e8d55f3 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -524,7 +524,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
     // can't perform non-differentiable rewrites.
     if (differentiable_functions.find(func_name) !=
         differentiable_functions.end()) {
-      func_item.allowed_optimizations.non_differentiable_rewrites = false;
+      func_item.allowed_optimizations().non_differentiable_rewrites = false;
+    }
+
+    // Function item is allowed to use all devices from the main graph.
+    Status added_devices = func_item.AddDevices(item);
+    if (!added_devices.ok()) {
+      VLOG(3) << added_devices.error_message();
     }
 
     // Optimize function body graph.
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index b60aa256676..42b867b6ac1 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -108,7 +108,7 @@ class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer {
                  GraphDef* optimized_graph) override {
     *optimized_graph = item.graph;
     if (allowed_optimizations_) {
-      allowed_optimizations_->insert({item.id, item.allowed_optimizations});
+      allowed_optimizations_->insert({item.id, item.allowed_optimizations()});
     }
     return Status::OK();
   }
diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc
index 804fce5ce88..576494cad55 100644
--- a/tensorflow/core/grappler/utils/grappler_test.cc
+++ b/tensorflow/core/grappler/utils/grappler_test.cc
@@ -114,6 +114,8 @@ void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) const {
   for (int i = 0; i < want.node_size(); ++i) {
     EXPECT_EQ(want.node(i).op(), got.node(i).op());
     EXPECT_EQ(want.node(i).name(), got.node(i).name());
+    EXPECT_EQ(want.node(i).device(), got.node(i).device());
+    ASSERT_EQ(want.node(i).input_size(), got.node(i).input_size());
     for (int j = 0; j < want.node(i).input_size(); ++j) {
       const TensorId want_tensor = ParseTensorName(want.node(i).input(j));
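Taken together, the device flow this patch introduces is: GraphExecutionState::OptimizeGraph infers devices for the main item, MetaOptimizer::Optimize lends that device set to every function item, and FunctionOptimizer builds a DeviceSet of FakeDevices to run the Placer over each inlined function body. Below is a small illustrative sketch of the first two steps; the helper name is hypothetical and it assumes grappler_item.h and logging.h are available. Invalid device annotations are deliberately tolerated and only logged:

// Hypothetical helper, condensed from the graph_execution_state.cc and
// meta_optimizer.cc hunks above.
Status PrepareItemsForPlacementExample(GrapplerItem* main_item,
                                       GrapplerItem* func_item) {
  // Invalid device annotations in the main graph are skipped, not fatal.
  Status inferred = main_item->InferDevicesFromGraph();
  if (!inferred.ok()) VLOG(3) << inferred.error_message();

  // A function body item may use all devices known to the main item.
  Status added = func_item->AddDevices(*main_item);
  if (!added.ok()) VLOG(3) << added.error_message();

  return Status::OK();
}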