Run placer for the inlined function body.

PiperOrigin-RevId: 223878376
Authored by Eugene Zhulenev on 2018-12-03 15:51:06 -08:00, committed by TensorFlower Gardener
parent 296e4a05e4
commit 50204e4205
12 changed files with 291 additions and 17 deletions
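
For orientation, here is a minimal sketch (not part of the commit) of how a caller is expected to use the device API that this change adds to GrapplerItem. The helper name PrepareItemForOptimization is hypothetical; the calls mirror the GraphExecutionState and GrapplerItem hunks below.

// Hypothetical helper, assembled from the API added in this commit.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace grappler {

void PrepareItemForOptimization(const GraphDef& graph_def, GrapplerItem* item) {
  item->id = "tf_graph";
  item->graph = graph_def;

  // Collect fully defined device names from the node annotations. Invalid or
  // partial names are skipped, and the returned error only lists what was
  // dropped, so it is safe to log it and continue.
  Status inferred_devices = item->InferDevicesFromGraph();
  if (!inferred_devices.ok()) {
    VLOG(3) << inferred_devices.error_message();
  }
}

}  // namespace grappler
}  // namespace tensorflow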

View File

@@ -595,6 +595,13 @@ Status GraphExecutionState::OptimizeGraph(
   grappler::GrapplerItem item;
   item.id = "tf_graph";
   graph_->ToGraphDef(&item.graph);
+
+  // It's ok to skip invalid device annotations in Grappler.
+  Status inferred_devices = item.InferDevicesFromGraph();
+  if (!inferred_devices.ok()) {
+    VLOG(3) << inferred_devices.error_message();
+  }
+
   // TODO(b/114748242): Add a unit test to test this bug fix.
   if (flib_def_) {
     *item.graph.mutable_library() = flib_def_->ToProto();

View File

@@ -107,6 +107,8 @@ cc_library(
         ":utils",
         "//tensorflow/core:framework",
        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/strings",
     ],
 )
@@ -141,6 +143,7 @@ tf_cc_test(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
         "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
     ],
 )

View File

@@ -19,10 +19,13 @@ limitations under the License.
 #include <unordered_set>
 #include <vector>

+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_join.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/util/device_name_utils.h"

 namespace tensorflow {
 namespace grappler {
@@ -38,7 +41,8 @@ GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef* graph_def) {
   restore_op = other.restore_op;
   save_restore_loc_tensor = other.save_restore_loc_tensor;
   queue_runners = other.queue_runners;
-  allowed_optimizations = other.allowed_optimizations;
+  devices_ = other.devices_;
+  allowed_optimizations_ = other.allowed_optimizations_;

   graph.Swap(graph_def);
 }
@@ -111,6 +115,64 @@ std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
   return result;
 }

+const std::unordered_set<string>& GrapplerItem::devices() const {
+  return devices_;
+}
+
+Status GrapplerItem::AddDevice(const string& device) {
+  DeviceNameUtils::ParsedName name;
+  if (!DeviceNameUtils::ParseFullName(device, &name)) {
+    return errors::InvalidArgument("Invalid device name: device=", device);
+  } else if (!name.has_job || !name.has_replica || !name.has_task ||
+             !name.has_type || !name.has_id) {
+    return errors::InvalidArgument("Not a fully defined device name: device=",
+                                   device);
+  }
+  devices_.insert(DeviceNameUtils::ParsedNameToString(name));
+  return Status::OK();
+}
+
+Status GrapplerItem::AddDevices(const GrapplerItem& other) {
+  std::vector<absl::string_view> invalid_devices;
+  for (const string& device : other.devices()) {
+    Status added = AddDevice(device);
+    if (!added.ok()) invalid_devices.emplace_back(device);
+  }
+  return invalid_devices.empty()
+             ? Status::OK()
+             : errors::InvalidArgument("Skipped invalid devices: [",
+                                       absl::StrJoin(invalid_devices, ", "),
+                                       "]");
+}
+
+Status GrapplerItem::InferDevicesFromGraph() {
+  absl::flat_hash_set<absl::string_view> invalid_devices;
+  for (const NodeDef& node : graph.node()) {
+    Status added = AddDevice(node.device());
+    if (!added.ok()) invalid_devices.insert(node.device());
+  }
+  VLOG(2) << "Inferred device set: [" << absl::StrJoin(devices_, ", ") << "]";
+  return invalid_devices.empty()
+             ? Status::OK()
+             : errors::InvalidArgument("Skipped invalid devices: [",
+                                       absl::StrJoin(invalid_devices, ", "),
+                                       "]");
+}
+
+void GrapplerItem::ClearDevices() { devices_.clear(); }
+
+const GrapplerItem::AllowedOptimizations& GrapplerItem::allowed_optimizations()
+    const {
+  return allowed_optimizations_;
+}
+
+GrapplerItem::AllowedOptimizations& GrapplerItem::allowed_optimizations() {
+  return allowed_optimizations_;
+}
+
 std::vector<const NodeDef*> ComputeTransitiveFanin(
     const GraphDef& graph, const std::vector<string>& terminal_nodes) {
   bool ill_formed = false;
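
As an aside, the AddDevice checks above mean only fully specified names make it into the device set. The following hedged sketch (the helper name AddDeviceExamples is invented; the two device strings come from the unit test later in this commit) shows the accepted and rejected forms.

// Illustrative only: which names GrapplerItem::AddDevice accepts.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace grappler {

void AddDeviceExamples(GrapplerItem* item) {
  // Fully defined: job, replica, task, device type and id are all present,
  // so ParseFullName succeeds and every has_* field is set.
  TF_CHECK_OK(item->AddDevice("/job:work/replica:1/task:1/device:CPU:0"));

  // Parseable but not fully defined: only type and id are given, so the
  // has_job/has_replica/has_task checks fail and InvalidArgument is returned.
  Status partial = item->AddDevice("/device:CPU:2");
  CHECK(!partial.ok());
}

}  // namespace grappler
}  // namespace tensorflow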

View File

@@ -85,7 +85,33 @@ struct GrapplerItem {
     bool non_differentiable_rewrites = true;
   };

-  AllowedOptimizations allowed_optimizations;
+  const std::unordered_set<string>& devices() const;
+
+  // Adds a device to a set of available devices, only if it's a valid fully
+  // defined device name. Returns `Status::OK()` if successfully added a device,
+  // and an error otherwise.
+  Status AddDevice(const string& device);
+
+  // Adds all valid devices from the other Grappler item to the device set.
+  Status AddDevices(const GrapplerItem& other);
+
+  // Adds all valid devices from the nodes of the graph to the device set.
+  // Returns `Status::OK()` if all device annotations found in a graph are valid
+  // fully defined device names, and an error otherwise.
+  Status InferDevicesFromGraph();
+
+  // Clears a set of available devices.
+  void ClearDevices();
+
+  const AllowedOptimizations& allowed_optimizations() const;
+  AllowedOptimizations& allowed_optimizations();
+
+ private:
+  // TODO(ezhulenev) Make GrapplerItem a class and hide all public data members.
+  // TODO(ezhulenev): Migrate all unordered collections to absl.
+
+  // A set of fully defined device names that can be used to place the nodes of
+  // the `graph`.
+  // Example of a fully defined name: "/job:work/replica:1/task:1/device:CPU:0"
+  std::unordered_set<string> devices_;
+
+  AllowedOptimizations allowed_optimizations_;
 };

 // Return the transitive fanin of a set of terminal nodes.

View File

@@ -14,7 +14,9 @@ limitations under the License.
 ==============================================================================*/

 #include "tensorflow/core/grappler/grappler_item.h"

+#include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
 #include "tensorflow/core/platform/test.h"
@@ -44,6 +46,32 @@ TEST_F(GrapplerItemTest, Basic) {
   EXPECT_EQ(main_ops, graph_nodes);
 }

+TEST_F(GrapplerItemTest, InferDevices) {
+  using test::function::NDef;
+
+  const string cpu0 = "/job:work/replica:1/task:1/device:CPU:0";
+  const string cpu1 = "/job:work/replica:1/task:1/device:CPU:1";
+  const string cpu2 = "/device:CPU:2";
+
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {
+          NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
+          NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
+          NDef("c", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu2),
+      },
+      {} /* Empty function library */);
+
+  ASSERT_FALSE(item.InferDevicesFromGraph().ok());
+
+  EXPECT_EQ(item.devices().size(), 2);
+  EXPECT_NE(item.devices().find(cpu0), item.devices().end());
+  EXPECT_NE(item.devices().find(cpu1), item.devices().end());
+
+  item.ClearDevices();
+  EXPECT_EQ(item.devices().size(), 0);
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow

View File

@@ -141,6 +141,7 @@ cc_library(
     deps = [
         ":graph_optimizer",
         "//tensorflow/core:core_cpu_base",
+        "//tensorflow/core:core_cpu_lib",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",

View File

@@ -3572,7 +3572,7 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
   // Disable restricted graph rewrites.
   options_.unary_ops_composition &=
-      item.allowed_optimizations.non_differentiable_rewrites;
+      item.allowed_optimizations().non_differentiable_rewrites;

   if (options_.dedup_computations) {
     DedupComputations();

View File

@@ -22,8 +22,11 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/strings/str_replace.h"
 #include "absl/strings/substitute.h"
+#include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_set.h"
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/placer.h"
 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/function.h"
@@ -111,10 +114,26 @@ AttrSlice FunctionInstantiationAttributes(const FunctionDef& func,
   }
 }

-class FakeCPUDevice : public Device {
+// This is a fake device that should not be used for any op kernel execution,
+// the only purpose of this device is to be passed as a part of DeviceSet to the
+// Placer.
+class FakeDevice : public Device {
  public:
-  FakeCPUDevice(Env* env, const DeviceAttributes& attr) : Device(env, attr) {}
+  FakeDevice(Env* env, const string& device) : Device(env, attr(device)) {}
+  explicit FakeDevice(const string& device) : FakeDevice(nullptr, device) {}
   Status Sync() override { return Status::OK(); }
+
+ private:
+  static DeviceAttributes attr(const string& device) {
+    DeviceNameUtils::ParsedName parsed_name;
+    bool parsed = DeviceNameUtils::ParseFullName(device, &parsed_name);
+    DCHECK(parsed) << "Failed to parse full device name: " << device;
+
+    DeviceAttributes attr;
+    attr.set_name(device);
+    attr.set_device_type(parsed_name.type);
+    return attr;
+  }
 };

 // -------------------------------------------------------------------------- //
@@ -240,6 +259,7 @@ class FunctionOptimizerContext {
         graph_version_(item.graph.versions().producer()),
         opt_level_(opt_level),
         function_library_(OpRegistry::Global(), item.graph.library()),
+        available_device_names_(item.devices().begin(), item.devices().end()),
         graph_view_(&item.graph) {
     InitializeTrulyConstNodes(item);
     InitializeFetchNodes(item);
@@ -275,6 +295,18 @@ class FunctionOptimizerContext {
   const gtl::FlatSet<string>& fetch_tensors() const { return fetch_tensors_; }

+  const DeviceSet* devices() const {
+    // Create fake devices lazily only if we need a DeviceSet.
+    if (available_devices_.empty() && !available_device_names_.empty()) {
+      for (const string& name : available_device_names_) {
+        auto device = absl::make_unique<FakeDevice>(name);
+        available_device_set_.AddDevice(device.get());
+        available_devices_.push_back(std::move(device));
+      }
+    }
+    return &available_device_set_;
+  }
+
   bool IsFetchNode(const string& node_name) const {
     return fetch_nodes_.find(node_name) != fetch_nodes_.end();
   }
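
As an aside, the lazy DeviceSet construction above boils down to the following hedged sketch; BuildFakeDeviceSet is an invented name, and FakeDevice is the class introduced earlier in this file (so the sketch assumes it lives in the same translation unit).

// Illustrative only: turn fully defined device names into a DeviceSet of
// FakeDevice objects that the Placer can consume. The DeviceSet stores raw
// pointers, so the owning vector must outlive it.
#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"

namespace tensorflow {
namespace grappler {

void BuildFakeDeviceSet(const std::vector<string>& device_names,
                        std::vector<std::unique_ptr<Device>>* owned_devices,
                        DeviceSet* device_set) {
  for (const string& name : device_names) {
    // FakeDevice carries only a name and a type; it never runs kernels and
    // exists so the Placer has devices to assign nodes to.
    auto device = absl::make_unique<FakeDevice>(name);
    device_set->AddDevice(device.get());
    owned_devices->push_back(std::move(device));
  }
}

}  // namespace grappler
}  // namespace tensorflow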
@@ -350,11 +382,8 @@ class FunctionOptimizerContext {
   void InitializeFunctionLibraryRuntime() {
     if (!flr_) {
       Env* env = Env::Default();
-      DeviceAttributes attr;
-      attr.set_name("/device:CPU:0");
-      attr.set_device_type("CPU");
       std::vector<std::unique_ptr<Device>> devices;
-      devices.push_back(absl::make_unique<FakeCPUDevice>(env, attr));
+      devices.push_back(absl::make_unique<FakeDevice>(env, "/device:CPU:0"));
       device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
       OptimizerOptions optimizer_opts;
       optimizer_opts.set_do_function_inlining(true);
@@ -375,6 +404,16 @@ class FunctionOptimizerContext {
   std::unique_ptr<ProcessFunctionLibraryRuntime> process_flr_;
   FunctionLibraryRuntime* flr_ = nullptr;

+  // Fully defined names of the devices available to the GrapplerItem.
+  const gtl::FlatSet<string> available_device_names_;
+
+  // List of available `FakedDevices` (lazily initialized, see devices()).
+  mutable std::vector<std::unique_ptr<Device>> available_devices_;
+
+  // DeviceSet of fake devices (`FakeDevice`) constructed from
+  // available_devices_ (lazily initialized).
+  mutable DeviceSet available_device_set_;
+
   // Nodes that are Const and not in feed.
   std::unordered_map<string, const NodeDef*> truly_const_nodes_;
   // Specialized functions.
@@ -1305,7 +1344,58 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
   // edges from inlined side-effectful ops.
   std::vector<string> side_effectful_nodes;

-  for (NodeDef& func_body_node : *item.mutable_function_body().mutable_node()) {
+  // ------------------------------------------------------------------------ //
+  // First we need to assign device placements to all function body nodes.
+
+  GraphDef placed_graph_def;
+
+  const DeviceSet* devices = ctx->devices();
+
+  if (devices->devices().empty()) {
+    // If there are no devices available for placer, we just put all nodes to
+    // the same device as a function caller node. This can happen if Grappler is
+    // running "offline", without active runtime session, for example as a part
+    // of a batch job for graph analysis/optimization.
+    VLOG(3) << "Assign function call node device to all function body nodes. "
+            << "Device: " << func_node.device();
+
+    placed_graph_def = item.mutable_function_body();
+    for (NodeDef& node : *placed_graph_def.mutable_node()) {
+      node.set_device(func_node.device());
+    }
+  } else {
+    // If we are running in an active runtime session, Grappler will get the
+    // graph after initial placing is done, and we should have devices for the
+    // placer.
+    VLOG(3) << "Run placer for instantiated function body. Devices: ["
+            << absl::StrJoin(
+                   devices->devices(), ", ",
+                   [](string* out, const Device* d) { out->append(d->name()); })
+            << "]";
+
+    // Construct a Graph object from the instantiated function body.
+    GraphConstructorOptions opts;
+    Graph graph(ctx->function_library());
+    TF_RETURN_IF_ERROR(
+        ConvertGraphDefToGraph(opts, item.function_body(), &graph));
+
+    // Use function caller node device as a default for placer.
+    const Device* default_device =
+        devices->FindDeviceByName(func_node.device());
+
+    Placer placer(&graph, devices, nullptr, /* No session options */
+                  default_device);
+    TF_RETURN_IF_ERROR(placer.Run());
+
+    // Convert Graph back to the GraphDef.
+    graph.ToGraphDef(&placed_graph_def);
+  }
+
+  // ------------------------------------------------------------------------ //
+  // After all nodes placed we need to prepare them for inlining into the
+  // optimized graph: turn placeholders into identities, update nodes
+  // connectivity, etc...
+
+  for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
     if (item.IsInputPlaceholder(func_body_node.name())) {
       // Turn input placeholders into identity node.
       DCHECK_EQ(0, func_body_node.input_size());
@@ -1337,10 +1427,6 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
     TF_RETURN_IF_ERROR(
         AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &func_body_node));

-    // TODO(ezhulenev): Call PartitionedCallOp to get placement for all function
-    // body nodes, for now just place it on function caller node device.
-    func_body_node.set_device(func_node.device());
-
     // If the function body has a side-effectful op, we double check that the
     // function call node has an output control edge, otherwise we can't safely
     // do inlining and guarantee that node will be executed.
@@ -1362,7 +1448,7 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
     // TODO(ezhulenev): Inline nested indirect function calls.

-    // Move the node to the main graph.
+    // Move the node to the optimized graph.
     optimized_graph->add_node()->Swap(&func_body_node);
   }
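
To make the new inlining path easier to follow, here is a hedged, standalone sketch of the placement step from the hunks above. PlaceFunctionBody is an invented name; the Placer, graph-construction, and DeviceSet calls are the ones the diff uses, but the packaging into a free function is only an illustration.

// Hypothetical standalone helper; mirrors the else-branch of the hunk above.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace grappler {

Status PlaceFunctionBody(const GraphDef& function_body,
                         const FunctionLibraryDefinition& flib,
                         const DeviceSet* devices, const string& caller_device,
                         GraphDef* placed_graph_def) {
  // Build a Graph object from the instantiated function body so that the
  // Placer can annotate its nodes with assigned devices.
  GraphConstructorOptions opts;
  Graph graph(flib);
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, function_body, &graph));

  // Nodes without a device constraint default to the caller node's device
  // (this may be nullptr if the caller device is not in the set).
  const Device* default_device = devices->FindDeviceByName(caller_device);

  Placer placer(&graph, devices, /*session_options=*/nullptr, default_device);
  TF_RETURN_IF_ERROR(placer.Run());

  // Convert the placed Graph back to a GraphDef for inlining.
  graph.ToGraphDef(placed_graph_def);
  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow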

View File

@@ -881,6 +881,59 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithControlDependencies) {
   test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
 }

+TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithDevicePlacement) {
+  using test::function::NDef;
+  using FDH = FunctionDefHelper;
+
+  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+
+  FunctionDef mul_func = FunctionDefHelper::Create(
+      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
+      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+      /* Mapping between function returns and function node outputs. */
+      {{"z", "mul:z:0"}});
+  // Add device placement spec to the function body node.
+  (*mul_func.mutable_node_def())[0].set_device("/device:CPU:1");
+
+  // We need fully defined device names to run the placer for inlined function.
+  const string cpu0 = "/job:work/replica:1/task:1/device:CPU:0";
+  const string cpu1 = "/job:work/replica:1/task:1/device:CPU:1";
+
+  // Build a graph to compute c = MyMul(a, b)
+  GrapplerItem item;
+  item.fetch = {"d"};
+  item.graph = test::function::GDef(
+      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
+       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
+       NDef("c", "PartitionedCall", {"a", "b"},
+            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
+             {"Tout", DataTypeSlice{DT_FLOAT}},
+             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
+            cpu0),
+       NDef("d", "Identity", {"c"}, {{"T", DT_FLOAT}}, cpu0)},
+      // Function library.
+      {mul_func});
+  ASSERT_TRUE(item.InferDevicesFromGraph().ok());
+
+  GraphDef optimized_graph;
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));
+
+  GraphDef expected = test::function::GDef(
+      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
+       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
+
+       // Function must be inlined and `mul` node placed on a requested device.
+       NDef("c/x", "Identity", {"a:0"}, {{"T", DT_FLOAT}}, cpu1),
+       NDef("c/y", "Identity", {"b:0"}, {{"T", DT_FLOAT}}, cpu1),
+       NDef("c/mul", "Mul", {"c/x", "c/y"}, {{"T", DT_FLOAT}}, cpu1),
+
+       NDef("d", "Identity", {"c/mul:0"}, {{"T", DT_FLOAT}}, cpu0)},
+      // Function library.
+      {mul_func});
+
+  CompareGraphs(expected, optimized_graph);
+}
+
 TEST_F(FunctionOptimizerTest, SpecializeFunctionXTimesTwo) {
   using test::function::NDef;

View File

@@ -524,7 +524,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
     // can't perform non-differentiable rewrites.
     if (differentiable_functions.find(func_name) !=
         differentiable_functions.end()) {
-      func_item.allowed_optimizations.non_differentiable_rewrites = false;
+      func_item.allowed_optimizations().non_differentiable_rewrites = false;
+    }
+
+    // Function item is allowed to use all devices from the main graph.
+    Status added_devices = func_item.AddDevices(item);
+    if (!added_devices.ok()) {
+      VLOG(3) << added_devices.error_message();
     }

     // Optimize function body graph.

View File

@@ -108,7 +108,7 @@ class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer {
                   GraphDef* optimized_graph) override {
     *optimized_graph = item.graph;
     if (allowed_optimizations_) {
-      allowed_optimizations_->insert({item.id, item.allowed_optimizations});
+      allowed_optimizations_->insert({item.id, item.allowed_optimizations()});
     }
     return Status::OK();
   }

View File

@@ -114,6 +114,8 @@ void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) const {
   for (int i = 0; i < want.node_size(); ++i) {
     EXPECT_EQ(want.node(i).op(), got.node(i).op());
     EXPECT_EQ(want.node(i).name(), got.node(i).name());
+    EXPECT_EQ(want.node(i).device(), got.node(i).device());
+
     ASSERT_EQ(want.node(i).input_size(), got.node(i).input_size());
     for (int j = 0; j < want.node(i).input_size(); ++j) {
       const TensorId want_tensor = ParseTensorName(want.node(i).input(j));