Run placer for the inlined function body.
PiperOrigin-RevId: 223878376
This commit is contained in:
parent
296e4a05e4
commit
50204e4205
@ -595,6 +595,13 @@ Status GraphExecutionState::OptimizeGraph(
|
|||||||
grappler::GrapplerItem item;
|
grappler::GrapplerItem item;
|
||||||
item.id = "tf_graph";
|
item.id = "tf_graph";
|
||||||
graph_->ToGraphDef(&item.graph);
|
graph_->ToGraphDef(&item.graph);
|
||||||
|
|
||||||
|
// It's ok to skip invalid device annotations in Grappler.
|
||||||
|
Status inferred_devices = item.InferDevicesFromGraph();
|
||||||
|
if (!inferred_devices.ok()) {
|
||||||
|
VLOG(3) << inferred_devices.error_message();
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(b/114748242): Add a unit test to test this bug fix.
|
// TODO(b/114748242): Add a unit test to test this bug fix.
|
||||||
if (flib_def_) {
|
if (flib_def_) {
|
||||||
*item.graph.mutable_library() = flib_def_->ToProto();
|
*item.graph.mutable_library() = flib_def_->ToProto();
|
||||||
|
@ -107,6 +107,8 @@ cc_library(
|
|||||||
":utils",
|
":utils",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -141,6 +143,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -19,10 +19,13 @@ limitations under the License.
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
|
#include "absl/strings/str_join.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/grappler/op_types.h"
|
#include "tensorflow/core/grappler/op_types.h"
|
||||||
#include "tensorflow/core/grappler/utils.h"
|
#include "tensorflow/core/grappler/utils.h"
|
||||||
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
@ -38,7 +41,8 @@ GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef* graph_def) {
|
|||||||
restore_op = other.restore_op;
|
restore_op = other.restore_op;
|
||||||
save_restore_loc_tensor = other.save_restore_loc_tensor;
|
save_restore_loc_tensor = other.save_restore_loc_tensor;
|
||||||
queue_runners = other.queue_runners;
|
queue_runners = other.queue_runners;
|
||||||
allowed_optimizations = other.allowed_optimizations;
|
devices_ = other.devices_;
|
||||||
|
allowed_optimizations_ = other.allowed_optimizations_;
|
||||||
graph.Swap(graph_def);
|
graph.Swap(graph_def);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,6 +115,64 @@ std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const std::unordered_set<string>& GrapplerItem::devices() const {
|
||||||
|
return devices_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adds `device` to the set of available devices.
//
// The name must parse as a device name and must specify every component (job,
// replica, task, device type and device id). If either check fails, nothing is
// inserted and an InvalidArgument status is returned; otherwise returns
// Status::OK().
Status GrapplerItem::AddDevice(const string& device) {
  DeviceNameUtils::ParsedName parsed;

  const bool parse_ok = DeviceNameUtils::ParseFullName(device, &parsed);
  if (!parse_ok) {
    return errors::InvalidArgument("Invalid device name: device=", device);
  }

  const bool fully_defined = parsed.has_job && parsed.has_replica &&
                             parsed.has_task && parsed.has_type &&
                             parsed.has_id;
  if (!fully_defined) {
    return errors::InvalidArgument("Not a fully defined device name: device=",
                                   device);
  }

  // Store the name as re-serialized from its parsed form, not the raw input.
  devices_.insert(DeviceNameUtils::ParsedNameToString(parsed));
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Copies every device of `other` that AddDevice() accepts into this item.
//
// Rejected names are skipped (the remaining ones are still added) and are
// reported together in a single InvalidArgument status. Returns Status::OK()
// when nothing was rejected.
Status GrapplerItem::AddDevices(const GrapplerItem& other) {
  // Views into the strings owned by `other`; only used to build the message.
  std::vector<absl::string_view> rejected;

  for (const string& name : other.devices()) {
    if (!AddDevice(name).ok()) rejected.emplace_back(name);
  }

  if (rejected.empty()) return Status::OK();

  return errors::InvalidArgument("Skipped invalid devices: [",
                                 absl::StrJoin(rejected, ", "), "]");
}
|
||||||
|
|
||||||
|
// Populates the device set from the `device` field of every node in `graph`.
//
// Each node's device string is passed to AddDevice(). Strings that are rejected
// are de-duplicated and reported in one InvalidArgument status; accepted ones
// are still added. Returns Status::OK() when no string was rejected.
//
// NOTE(review): nodes with an empty device string are passed to AddDevice() as
// well; presumably they are rejected as not fully defined -- confirm.
Status GrapplerItem::InferDevicesFromGraph() {
  // Views into strings owned by the graph nodes; a set, so that a bad
  // annotation shared by many nodes is reported once.
  absl::flat_hash_set<absl::string_view> rejected;

  for (const NodeDef& node : graph.node()) {
    const string& annotation = node.device();
    if (!AddDevice(annotation).ok()) rejected.insert(annotation);
  }

  VLOG(2) << "Inferred device set: [" << absl::StrJoin(devices_, ", ") << "]";

  if (rejected.empty()) return Status::OK();

  return errors::InvalidArgument("Skipped invalid devices: [",
                                 absl::StrJoin(rejected, ", "), "]");
}
|
||||||
|
|
||||||
|
// Removes every device name from the set of available devices.
void GrapplerItem::ClearDevices() {
  devices_.clear();
}
|
||||||
|
|
||||||
|
// Read-only access to the optimization restrictions attached to this item.
const GrapplerItem::AllowedOptimizations&
GrapplerItem::allowed_optimizations() const {
  return allowed_optimizations_;
}
|
||||||
|
|
||||||
|
// Mutable access to the optimization restrictions attached to this item, so
// that callers can toggle individual flags in place.
GrapplerItem::AllowedOptimizations& GrapplerItem::allowed_optimizations() {
  return allowed_optimizations_;
}
|
||||||
|
|
||||||
std::vector<const NodeDef*> ComputeTransitiveFanin(
|
std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||||
const GraphDef& graph, const std::vector<string>& terminal_nodes) {
|
const GraphDef& graph, const std::vector<string>& terminal_nodes) {
|
||||||
bool ill_formed = false;
|
bool ill_formed = false;
|
||||||
|
@ -85,7 +85,33 @@ struct GrapplerItem {
|
|||||||
bool non_differentiable_rewrites = true;
|
bool non_differentiable_rewrites = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
AllowedOptimizations allowed_optimizations;
|
const std::unordered_set<string>& devices() const;
|
||||||
|
// Adds a device to a set of available devices, only if it's a valid fully
|
||||||
|
// defined device name. Returns `Status::OK()` if successfully added a device,
|
||||||
|
// and an error otherwise.
|
||||||
|
Status AddDevice(const string& device);
|
||||||
|
// Adds all valid devices from the other Grappler item to the device set.
|
||||||
|
Status AddDevices(const GrapplerItem& other);
|
||||||
|
// Adds all valid devices from the nodes of the graph to the device set.
|
||||||
|
// Returns `Status::OK()` if all device annotations found in a graph are valid
|
||||||
|
// fully defined device names, and an error otherwise.
|
||||||
|
Status InferDevicesFromGraph();
|
||||||
|
// Clears a set of available devices.
|
||||||
|
void ClearDevices();
|
||||||
|
|
||||||
|
const AllowedOptimizations& allowed_optimizations() const;
|
||||||
|
AllowedOptimizations& allowed_optimizations();
|
||||||
|
|
||||||
|
private:
|
||||||
|
// TODO(ezhulenev) Make GrapplerItem a class and hide all public data members.
|
||||||
|
// TODO(ezhulenev): Migrate all unordered collections to absl.
|
||||||
|
|
||||||
|
// A set of fully defined device names that can be used to place the nodes of
|
||||||
|
// the `graph`.
|
||||||
|
// Example of a fully defined name: "/job:work/replica:1/task:1/device:CPU:0"
|
||||||
|
std::unordered_set<string> devices_;
|
||||||
|
|
||||||
|
AllowedOptimizations allowed_optimizations_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Return the transitive fanin of a set of terminal nodes.
|
// Return the transitive fanin of a set of terminal nodes.
|
||||||
|
@ -14,7 +14,9 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/grappler/grappler_item.h"
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
|
#include "tensorflow/core/framework/function_testlib.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
@ -44,6 +46,32 @@ TEST_F(GrapplerItemTest, Basic) {
|
|||||||
EXPECT_EQ(main_ops, graph_nodes);
|
EXPECT_EQ(main_ops, graph_nodes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks that InferDevicesFromGraph() keeps the fully defined device names,
// reports an error when a node carries a partially defined one, and that
// ClearDevices() empties the set again.
TEST_F(GrapplerItemTest, InferDevices) {
  using test::function::NDef;

  // Two fully defined names and one that only specifies the device type/id.
  const string full_cpu0 = "/job:work/replica:1/task:1/device:CPU:0";
  const string full_cpu1 = "/job:work/replica:1/task:1/device:CPU:1";
  const string partial_cpu2 = "/device:CPU:2";

  GrapplerItem item;
  item.graph = test::function::GDef(
      {
          NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, full_cpu0),
          NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, full_cpu1),
          NDef("c", "Placeholder", {}, {{"dtype", DT_FLOAT}}, partial_cpu2),
      },
      {} /* Empty function library */);

  // The partially defined name is expected to make inference report an error.
  const Status inferred = item.InferDevicesFromGraph();
  ASSERT_FALSE(inferred.ok());

  // Only the two fully defined devices must have been collected.
  const std::unordered_set<string>& collected = item.devices();
  EXPECT_EQ(collected.size(), 2);
  EXPECT_NE(collected.find(full_cpu0), collected.end());
  EXPECT_NE(collected.find(full_cpu1), collected.end());

  item.ClearDevices();
  EXPECT_EQ(item.devices().size(), 0);
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace grappler
|
} // namespace grappler
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -141,6 +141,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":graph_optimizer",
|
":graph_optimizer",
|
||||||
"//tensorflow/core:core_cpu_base",
|
"//tensorflow/core:core_cpu_base",
|
||||||
|
"//tensorflow/core:core_cpu_lib",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
@ -3572,7 +3572,7 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
|
|||||||
|
|
||||||
// Disable restricted graph rewrites.
|
// Disable restricted graph rewrites.
|
||||||
options_.unary_ops_composition &=
|
options_.unary_ops_composition &=
|
||||||
item.allowed_optimizations.non_differentiable_rewrites;
|
item.allowed_optimizations().non_differentiable_rewrites;
|
||||||
|
|
||||||
if (options_.dedup_computations) {
|
if (options_.dedup_computations) {
|
||||||
DedupComputations();
|
DedupComputations();
|
||||||
|
@ -22,8 +22,11 @@ limitations under the License.
|
|||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/str_replace.h"
|
#include "absl/strings/str_replace.h"
|
||||||
#include "absl/strings/substitute.h"
|
#include "absl/strings/substitute.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device_set.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
|
#include "tensorflow/core/common_runtime/placer.h"
|
||||||
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
||||||
#include "tensorflow/core/framework/attr_value_util.h"
|
#include "tensorflow/core/framework/attr_value_util.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
@ -111,10 +114,26 @@ AttrSlice FunctionInstantiationAttributes(const FunctionDef& func,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class FakeCPUDevice : public Device {
|
// This is a fake device that should not be used for any op kernel execution,
|
||||||
|
// the only purpose of this device is to be passed as a part of DeviceSet to the
|
||||||
|
// Placer.
|
||||||
|
class FakeDevice : public Device {
|
||||||
public:
|
public:
|
||||||
FakeCPUDevice(Env* env, const DeviceAttributes& attr) : Device(env, attr) {}
|
FakeDevice(Env* env, const string& device) : Device(env, attr(device)) {}
|
||||||
|
explicit FakeDevice(const string& device) : FakeDevice(nullptr, device) {}
|
||||||
Status Sync() override { return Status::OK(); }
|
Status Sync() override { return Status::OK(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
static DeviceAttributes attr(const string& device) {
|
||||||
|
DeviceNameUtils::ParsedName parsed_name;
|
||||||
|
bool parsed = DeviceNameUtils::ParseFullName(device, &parsed_name);
|
||||||
|
DCHECK(parsed) << "Failed to parse full device name: " << device;
|
||||||
|
|
||||||
|
DeviceAttributes attr;
|
||||||
|
attr.set_name(device);
|
||||||
|
attr.set_device_type(parsed_name.type);
|
||||||
|
return attr;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// -------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------- //
|
||||||
@ -240,6 +259,7 @@ class FunctionOptimizerContext {
|
|||||||
graph_version_(item.graph.versions().producer()),
|
graph_version_(item.graph.versions().producer()),
|
||||||
opt_level_(opt_level),
|
opt_level_(opt_level),
|
||||||
function_library_(OpRegistry::Global(), item.graph.library()),
|
function_library_(OpRegistry::Global(), item.graph.library()),
|
||||||
|
available_device_names_(item.devices().begin(), item.devices().end()),
|
||||||
graph_view_(&item.graph) {
|
graph_view_(&item.graph) {
|
||||||
InitializeTrulyConstNodes(item);
|
InitializeTrulyConstNodes(item);
|
||||||
InitializeFetchNodes(item);
|
InitializeFetchNodes(item);
|
||||||
@ -275,6 +295,18 @@ class FunctionOptimizerContext {
|
|||||||
|
|
||||||
const gtl::FlatSet<string>& fetch_tensors() const { return fetch_tensors_; }
|
const gtl::FlatSet<string>& fetch_tensors() const { return fetch_tensors_; }
|
||||||
|
|
||||||
|
// Returns the DeviceSet holding one `FakeDevice` per available device name.
//
// The fake devices are created on the first call that finds device names but
// no devices yet (the backing members are `mutable`), so contexts that never
// ask for a DeviceSet never construct them. The returned set is empty when no
// device names are available.
const DeviceSet* devices() const {
  const bool needs_init =
      available_devices_.empty() && !available_device_names_.empty();

  if (needs_init) {
    for (const string& device_name : available_device_names_) {
      // `available_devices_` owns the device; the DeviceSet only gets a
      // non-owning pointer to it.
      available_devices_.push_back(absl::make_unique<FakeDevice>(device_name));
      available_device_set_.AddDevice(available_devices_.back().get());
    }
  }

  return &available_device_set_;
}
|
||||||
|
|
||||||
bool IsFetchNode(const string& node_name) const {
|
bool IsFetchNode(const string& node_name) const {
|
||||||
return fetch_nodes_.find(node_name) != fetch_nodes_.end();
|
return fetch_nodes_.find(node_name) != fetch_nodes_.end();
|
||||||
}
|
}
|
||||||
@ -350,11 +382,8 @@ class FunctionOptimizerContext {
|
|||||||
void InitializeFunctionLibraryRuntime() {
|
void InitializeFunctionLibraryRuntime() {
|
||||||
if (!flr_) {
|
if (!flr_) {
|
||||||
Env* env = Env::Default();
|
Env* env = Env::Default();
|
||||||
DeviceAttributes attr;
|
|
||||||
attr.set_name("/device:CPU:0");
|
|
||||||
attr.set_device_type("CPU");
|
|
||||||
std::vector<std::unique_ptr<Device>> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
devices.push_back(absl::make_unique<FakeCPUDevice>(env, attr));
|
devices.push_back(absl::make_unique<FakeDevice>(env, "/device:CPU:0"));
|
||||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
||||||
OptimizerOptions optimizer_opts;
|
OptimizerOptions optimizer_opts;
|
||||||
optimizer_opts.set_do_function_inlining(true);
|
optimizer_opts.set_do_function_inlining(true);
|
||||||
@ -375,6 +404,16 @@ class FunctionOptimizerContext {
|
|||||||
std::unique_ptr<ProcessFunctionLibraryRuntime> process_flr_;
|
std::unique_ptr<ProcessFunctionLibraryRuntime> process_flr_;
|
||||||
FunctionLibraryRuntime* flr_ = nullptr;
|
FunctionLibraryRuntime* flr_ = nullptr;
|
||||||
|
|
||||||
|
// Fully defined names of the devices available to the GrapplerItem.
|
||||||
|
const gtl::FlatSet<string> available_device_names_;
|
||||||
|
|
||||||
|
// List of available fake devices (`FakeDevice`), lazily initialized, see devices().
|
||||||
|
mutable std::vector<std::unique_ptr<Device>> available_devices_;
|
||||||
|
|
||||||
|
// DeviceSet of fake devices (`FakeDevice`) constructed from
|
||||||
|
// available_devices_ (lazily initialized).
|
||||||
|
mutable DeviceSet available_device_set_;
|
||||||
|
|
||||||
// Nodes that are Const and not in feed.
|
// Nodes that are Const and not in feed.
|
||||||
std::unordered_map<string, const NodeDef*> truly_const_nodes_;
|
std::unordered_map<string, const NodeDef*> truly_const_nodes_;
|
||||||
// Specialized functions.
|
// Specialized functions.
|
||||||
@ -1305,7 +1344,58 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
|
|||||||
// edges from inlined side-effectful ops.
|
// edges from inlined side-effectful ops.
|
||||||
std::vector<string> side_effectful_nodes;
|
std::vector<string> side_effectful_nodes;
|
||||||
|
|
||||||
for (NodeDef& func_body_node : *item.mutable_function_body().mutable_node()) {
|
// ------------------------------------------------------------------------ //
|
||||||
|
// First we need to assign device placements to all function body nodes.
|
||||||
|
|
||||||
|
GraphDef placed_graph_def;
|
||||||
|
|
||||||
|
const DeviceSet* devices = ctx->devices();
|
||||||
|
|
||||||
|
if (devices->devices().empty()) {
|
||||||
|
// If there are no devices available for placer, we just put all nodes to
|
||||||
|
// the same device as a function caller node. This can happen if Grappler is
|
||||||
|
// running "offline", without active runtime session, for example as a part
|
||||||
|
// of a batch job for graph analysis/optimization.
|
||||||
|
VLOG(3) << "Assign function call node device to all function body nodes. "
|
||||||
|
<< "Device: " << func_node.device();
|
||||||
|
placed_graph_def = item.mutable_function_body();
|
||||||
|
for (NodeDef& node : *placed_graph_def.mutable_node()) {
|
||||||
|
node.set_device(func_node.device());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If we are running in an active runtime session, Grappler will get the
|
||||||
|
// graph after initial placing is done, and we should have devices for the
|
||||||
|
// placer.
|
||||||
|
VLOG(3) << "Run placer for instantiated function body. Devices: ["
|
||||||
|
<< absl::StrJoin(
|
||||||
|
devices->devices(), ", ",
|
||||||
|
[](string* out, const Device* d) { out->append(d->name()); })
|
||||||
|
<< "]";
|
||||||
|
|
||||||
|
// Construct a Graph object from the instantiated function body.
|
||||||
|
GraphConstructorOptions opts;
|
||||||
|
Graph graph(ctx->function_library());
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
ConvertGraphDefToGraph(opts, item.function_body(), &graph));
|
||||||
|
|
||||||
|
// Use function caller node device as a default for placer.
|
||||||
|
const Device* default_device =
|
||||||
|
devices->FindDeviceByName(func_node.device());
|
||||||
|
|
||||||
|
Placer placer(&graph, devices, nullptr, /* No session options */
|
||||||
|
default_device);
|
||||||
|
TF_RETURN_IF_ERROR(placer.Run());
|
||||||
|
|
||||||
|
// Convert Graph back to the GraphDef.
|
||||||
|
graph.ToGraphDef(&placed_graph_def);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------------ //
|
||||||
|
// After all nodes placed we need to prepare them for inlining into the
|
||||||
|
// optimized graph: turn placeholders into identities, update nodes
|
||||||
|
// connectivity, etc...
|
||||||
|
|
||||||
|
for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
|
||||||
if (item.IsInputPlaceholder(func_body_node.name())) {
|
if (item.IsInputPlaceholder(func_body_node.name())) {
|
||||||
// Turn input placeholders into identity node.
|
// Turn input placeholders into identity node.
|
||||||
DCHECK_EQ(0, func_body_node.input_size());
|
DCHECK_EQ(0, func_body_node.input_size());
|
||||||
@ -1337,10 +1427,6 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
|
|||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &func_body_node));
|
AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &func_body_node));
|
||||||
|
|
||||||
// TODO(ezhulenev): Call PartitionedCallOp to get placement for all function
|
|
||||||
// body nodes, for now just place it on function caller node device.
|
|
||||||
func_body_node.set_device(func_node.device());
|
|
||||||
|
|
||||||
// If the function body has a side-effectful op, we double check that the
|
// If the function body has a side-effectful op, we double check that the
|
||||||
// function call node has an output control edge, otherwise we can't safely
|
// function call node has an output control edge, otherwise we can't safely
|
||||||
// do inlining and guarantee that node will be executed.
|
// do inlining and guarantee that node will be executed.
|
||||||
@ -1362,7 +1448,7 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
|
|||||||
|
|
||||||
// TODO(ezhulenev): Inline nested indirect function calls.
|
// TODO(ezhulenev): Inline nested indirect function calls.
|
||||||
|
|
||||||
// Move the node to the main graph.
|
// Move the node to the optimized graph.
|
||||||
optimized_graph->add_node()->Swap(&func_body_node);
|
optimized_graph->add_node()->Swap(&func_body_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -881,6 +881,59 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithControlDependencies) {
|
|||||||
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
|
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Inlines a PartitionedCall whose function body node requests "/device:CPU:1"
// while the call node itself lives on CPU:0, and checks the device assigned to
// every inlined node in the optimized graph.
TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithDevicePlacement) {
  using test::function::NDef;
  using FDH = FunctionDefHelper;

  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE);

  // MyMul(x, y) = x * y, with the `mul` node pinned to a partially specified
  // device inside the function body.
  FunctionDef mul_func = FDH::Create(
      "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
      {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "mul:z:0"}});
  mul_func.mutable_node_def(0)->set_device("/device:CPU:1");

  // The main graph uses fully defined device names; these are what
  // InferDevicesFromGraph() collects for the placer.
  const string cpu0 = "/job:work/replica:1/task:1/device:CPU:0";
  const string cpu1 = "/job:work/replica:1/task:1/device:CPU:1";

  // Graph: c = MyMul(a, b); d = Identity(c). `d` is the fetch node.
  GrapplerItem item;
  item.fetch = {"d"};
  item.graph = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
       NDef("c", "PartitionedCall", {"a", "b"},
            {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
             {"Tout", DataTypeSlice{DT_FLOAT}},
             {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
            cpu0),
       NDef("d", "Identity", {"c"}, {{"T", DT_FLOAT}}, cpu0)},
      // Function library.
      {mul_func});
  ASSERT_TRUE(item.InferDevicesFromGraph().ok());

  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  GraphDef want = test::function::GDef(
      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
       NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),

       // The call is replaced by the function body; the `mul` node and both
       // input identities are expected on the device requested in the body.
       NDef("c/x", "Identity", {"a:0"}, {{"T", DT_FLOAT}}, cpu1),
       NDef("c/y", "Identity", {"b:0"}, {{"T", DT_FLOAT}}, cpu1),
       NDef("c/mul", "Mul", {"c/x", "c/y"}, {{"T", DT_FLOAT}}, cpu1),

       NDef("d", "Identity", {"c/mul:0"}, {{"T", DT_FLOAT}}, cpu0)},
      // Function library.
      {mul_func});

  CompareGraphs(want, output);
}
|
||||||
|
|
||||||
TEST_F(FunctionOptimizerTest, SpecializeFunctionXTimesTwo) {
|
TEST_F(FunctionOptimizerTest, SpecializeFunctionXTimesTwo) {
|
||||||
using test::function::NDef;
|
using test::function::NDef;
|
||||||
|
|
||||||
|
@ -524,7 +524,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
// can't perform non-differentiable rewrites.
|
// can't perform non-differentiable rewrites.
|
||||||
if (differentiable_functions.find(func_name) !=
|
if (differentiable_functions.find(func_name) !=
|
||||||
differentiable_functions.end()) {
|
differentiable_functions.end()) {
|
||||||
func_item.allowed_optimizations.non_differentiable_rewrites = false;
|
func_item.allowed_optimizations().non_differentiable_rewrites = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function item is allowed to use all devices from the main graph.
|
||||||
|
Status added_devices = func_item.AddDevices(item);
|
||||||
|
if (!added_devices.ok()) {
|
||||||
|
VLOG(3) << added_devices.error_message();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optimize function body graph.
|
// Optimize function body graph.
|
||||||
|
@ -108,7 +108,7 @@ class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer {
|
|||||||
GraphDef* optimized_graph) override {
|
GraphDef* optimized_graph) override {
|
||||||
*optimized_graph = item.graph;
|
*optimized_graph = item.graph;
|
||||||
if (allowed_optimizations_) {
|
if (allowed_optimizations_) {
|
||||||
allowed_optimizations_->insert({item.id, item.allowed_optimizations});
|
allowed_optimizations_->insert({item.id, item.allowed_optimizations()});
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -114,6 +114,8 @@ void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) const {
|
|||||||
for (int i = 0; i < want.node_size(); ++i) {
|
for (int i = 0; i < want.node_size(); ++i) {
|
||||||
EXPECT_EQ(want.node(i).op(), got.node(i).op());
|
EXPECT_EQ(want.node(i).op(), got.node(i).op());
|
||||||
EXPECT_EQ(want.node(i).name(), got.node(i).name());
|
EXPECT_EQ(want.node(i).name(), got.node(i).name());
|
||||||
|
EXPECT_EQ(want.node(i).device(), got.node(i).device());
|
||||||
|
|
||||||
ASSERT_EQ(want.node(i).input_size(), got.node(i).input_size());
|
ASSERT_EQ(want.node(i).input_size(), got.node(i).input_size());
|
||||||
for (int j = 0; j < want.node(i).input_size(); ++j) {
|
for (int j = 0; j < want.node(i).input_size(); ++j) {
|
||||||
const TensorId want_tensor = ParseTensorName(want.node(i).input(j));
|
const TensorId want_tensor = ParseTensorName(want.node(i).input(j));
|
||||||
|
Loading…
x
Reference in New Issue
Block a user