Make device manager in RunGrappler static instead of creating it on every invocation.

This avoids creating new devices every time RunGrappler() is called. In addition, the optimized graph may contain tensor protos that are only valid while the corresponding devices are alive, so the device manager must outlive any use of the optimized graph. A standalone sketch of the pattern appears below, after the change summary.

PiperOrigin-RevId: 353924393
Change-Id: Ibc3b2868f409690e31ccfae7e5f1d0626d34afff
Kuangyuan Chen 2021-01-26 12:28:29 -08:00 committed by TensorFlower Gardener
parent 0f90066f9e
commit 8f595c9558
4 changed files with 59 additions and 24 deletions
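
The fix uses the standard C++ idiom of a function-local static initialized by an immediately-invoked lambda: the expensive object is built once on first use and intentionally never destroyed, so pointers into it stay valid for the life of the process. Below is a minimal, standalone sketch of that idiom under the same assumptions; ExpensiveResource and GetStaticResource() are illustrative stand-ins, not names from this change.

#include <iostream>
#include <memory>
#include <vector>

// Stand-in for an object that is expensive to build and must outlive every
// consumer (analogous to the DeviceMgr that owns the CPU devices here).
struct ExpensiveResource {
  std::vector<int> state{1, 2, 3};
};

// Built once on first use and never destroyed, so the returned pointer (and
// anything referring to the resource) stays valid until process exit.
const ExpensiveResource* GetStaticResource() {
  static const auto* const resource = []() -> const ExpensiveResource* {
    auto owned = std::make_unique<ExpensiveResource>();
    // Report setup failure as nullptr instead of throwing, mirroring the
    // LOG(ERROR) + return nullptr handling in the change below.
    if (owned->state.empty()) return nullptr;
    return owned.release();  // intentionally leaked singleton
  }();
  return resource;
}

int main() {
  const ExpensiveResource* a = GetStaticResource();
  const ExpensiveResource* b = GetStaticResource();
  std::cout << std::boolalpha << (a == b) << "\n";  // true: one shared instance
  return 0;
}

GetStaticDeviceMgr() in the diff below follows the same shape, with the CPU-only StaticDeviceMgr in place of ExpensiveResource.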


@@ -1094,6 +1094,7 @@ cc_library(
         "//tensorflow/core/grappler/clusters:virtual_cluster",
         "//tensorflow/core/grappler/optimizers:meta_optimizer",
         "//tensorflow/core/protobuf:for_core_protos_cc",
+        "//tensorflow/stream_executor/lib",
         "@llvm-project//llvm:Support",
     ],
 )


@@ -3426,13 +3426,22 @@ class SavedModelSignatureDefImporterLite {
 Status SavedModelSignatureDefImporterLite::InitializeGraph(
     MLIRImportOptions import_options) {
+  GraphDef graph_def;
   if (import_options.enable_grappler) {
     // Grappler is best-effort.
-    auto status = RunGrappler(&meta_graph_def_);
-    if (!status.ok()) LOG(WARNING) << status;
+    auto statusor = RunGrappler(meta_graph_def_);
+    if (statusor.ok()) {
+      graph_def = std::move(statusor).ValueOrDie();
+    } else {
+      // If the grappler fails, use the original graph def.
+      LOG(WARNING) << "SavedModelSignatureDefImporterLite: grappler failed: "
+                   << statusor.status();
+      graph_def = meta_graph_def_.graph_def();
+    }
+  } else {
+    graph_def = meta_graph_def_.graph_def();
   }
 
-  GraphDef graph_def = meta_graph_def_.graph_def();
   if (import_options.upgrade_legacy) {
     TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
         graph_def, graph_->flib_def().default_registry()));


@@ -100,22 +100,33 @@ Status GenerateResourceSharedNameIfEmpty(
   return tensorflow::Status::OK();
 }
 
-Status RunGrappler(MetaGraphDef* meta_graph_def) {
-  std::vector<std::unique_ptr<Device>> devices;
-  // Only CPU device is used so instead of calling DeviceFactory::AddDevices()
-  // with dummy session config, which will conflict with user defined options
-  // and create unwanted devices, call cpu_factory->CreateDevices() to get CPU
-  // only devices.
-  DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
-  SessionOptions options;
-  TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
-      options, "/job:localhost/replica:0/task:0", &devices));
-  Device* cpu_device = devices[0].get();
-  auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
-
-  DeviceSet dev_set;
-  for (auto d : device_mgr->ListDevices()) dev_set.AddDevice(d);
+// The static device manager is used to avoid creating the new device every time
+// RunGrappler() is called. In addition, the optimized graph may contain tensor
+// protos that are only valid when the corresponding device is alive.
+static const DeviceMgr* GetStaticDeviceMgr() {
+  static const auto* const device_mgr = []() -> const DeviceMgr* {
+    std::vector<std::unique_ptr<Device>> devices;
+    // Only CPU device is used so instead of calling DeviceFactory::AddDevices()
+    // with dummy session config, which will conflict with user defined options
+    // and create unwanted devices, call cpu_factory->CreateDevices() to get CPU
+    // only devices.
+    DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
+    SessionOptions options;
+    auto status = cpu_factory->CreateDevices(
+        options, "/job:localhost/replica:0/task:0", &devices);
+    if (!status.ok()) {
+      LOG(ERROR) << "Failed to create devices for Grappler: " << status;
+      return nullptr;
+    }
 
+    return new StaticDeviceMgr(std::move(devices));
+  }();
+
+  return device_mgr;
+}
+
+stream_executor::port::StatusOr<GraphDef> RunGrappler(
+    const MetaGraphDef& meta_graph_def) {
   ConfigProto config_proto;
   // Avoid grappler logic that lowers to v1 control flow.
   config_proto.mutable_experimental()->set_use_tfrt(true);
@@ -135,17 +146,29 @@ Status RunGrappler(MetaGraphDef* meta_graph_def) {
   grappler::ItemConfig item_config;
   item_config.ignore_user_placement = false;
   std::unique_ptr<grappler::GrapplerItem> item =
-      grappler::GrapplerItemFromMetaGraphDef("graph", *meta_graph_def,
+      grappler::GrapplerItemFromMetaGraphDef("graph", meta_graph_def,
                                              item_config);
   if (!item) {
     return tensorflow::errors::Internal(
         "Failed to create grappler item from MetaGraphDef.");
   }
 
+  const auto* device_mgr = GetStaticDeviceMgr();
+  if (!device_mgr) {
+    return tensorflow::errors::Internal(
+        "Failed to get devices in RunGrappler().");
+  }
+
+  DeviceSet dev_set;
+  for (auto* d : device_mgr->ListDevices()) dev_set.AddDevice(d);
+
   grappler::VirtualCluster cluster(&dev_set);
-  return grappler::RunMetaOptimizer(std::move(*item), config_proto, cpu_device,
-                                    &cluster,
-                                    meta_graph_def->mutable_graph_def());
+  Device* cpu_device = device_mgr->HostCPU();
+
+  GraphDef output_graph_def;
+  TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
+      std::move(*item), config_proto, cpu_device, &cluster, &output_graph_def));
+
+  return output_graph_def;
 }
 
 }  // namespace tensorflow


@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/graph/graph.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace tensorflow {
 
@@ -30,9 +31,10 @@ class MetaGraphDef;
 Status GenerateResourceSharedNameIfEmpty(
     GraphDef& gdef, const OpRegistryInterface* default_registry);
 
-// Run grapler passes over `meta_graph_def`.graph_def(), and optimize it in
-// place.
-Status RunGrappler(MetaGraphDef* meta_graph_def);
+// Run grapler passes over `meta_graph_def`.graph_def() and returns the
+// optimized graphdef.
+stream_executor::port::StatusOr<GraphDef> RunGrappler(
+    const MetaGraphDef& meta_graph_def);
 
 }  // namespace tensorflow
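
For reference, a call site of the new StatusOr-returning RunGrappler() would follow the same best-effort shape as the importer change above. This is a hedged fragment, not code from the change; ImportOptimizedGraphDef is an illustrative name, and the snippet assumes the TensorFlow types declared in the header above.

// Illustrative helper, not part of the change: optimize with grappler when
// possible, otherwise fall back to the original graph def.
GraphDef ImportOptimizedGraphDef(const MetaGraphDef& meta_graph_def) {
  auto statusor = RunGrappler(meta_graph_def);
  if (statusor.ok()) {
    return std::move(statusor).ValueOrDie();
  }
  // Grappler is best-effort; keep going with the unoptimized graph.
  LOG(WARNING) << "RunGrappler failed: " << statusor.status();
  return meta_graph_def.graph_def();
}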