Make device manager in RunGrappler static instead of creating it on every invocation.

This is to avoid creating new devices every time RunGrappler() is called. And the optimized graph may contain tensor protos that are only valid when the corresponding devices are alive.

PiperOrigin-RevId: 353924393
Change-Id: Ibc3b2868f409690e31ccfae7e5f1d0626d34afff
This commit is contained in:
Kuangyuan Chen 2021-01-26 12:28:29 -08:00 committed by TensorFlower Gardener
parent 0f90066f9e
commit 8f595c9558
4 changed files with 59 additions and 24 deletions

View File

@ -1094,6 +1094,7 @@ cc_library(
"//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler/optimizers:meta_optimizer",
"//tensorflow/core/protobuf:for_core_protos_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:Support",
],
)

View File

@ -3426,13 +3426,22 @@ class SavedModelSignatureDefImporterLite {
Status SavedModelSignatureDefImporterLite::InitializeGraph(
MLIRImportOptions import_options) {
GraphDef graph_def;
if (import_options.enable_grappler) {
// Grappler is best-effort.
auto status = RunGrappler(&meta_graph_def_);
if (!status.ok()) LOG(WARNING) << status;
auto statusor = RunGrappler(meta_graph_def_);
if (statusor.ok()) {
graph_def = std::move(statusor).ValueOrDie();
} else {
// If the grappler fails, use the original graph def.
LOG(WARNING) << "SavedModelSignatureDefImporterLite: grappler failed: "
<< statusor.status();
graph_def = meta_graph_def_.graph_def();
}
} else {
graph_def = meta_graph_def_.graph_def();
}
GraphDef graph_def = meta_graph_def_.graph_def();
if (import_options.upgrade_legacy) {
TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
graph_def, graph_->flib_def().default_registry()));

View File

@ -100,22 +100,33 @@ Status GenerateResourceSharedNameIfEmpty(
return tensorflow::Status::OK();
}
Status RunGrappler(MetaGraphDef* meta_graph_def) {
std::vector<std::unique_ptr<Device>> devices;
// Only CPU device is used so instead of calling DeviceFactory::AddDevices()
// with dummy session config, which will conflict with user defined options
// and create unwanted devices, call cpu_factory->CreateDevices() to get CPU
// only devices.
DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
SessionOptions options;
TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
options, "/job:localhost/replica:0/task:0", &devices));
Device* cpu_device = devices[0].get();
auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
// The static device manager is used to avoid creating the new device every time
// RunGrappler() is called. In addition, the optimized graph may contain tensor
// protos that are only valid when the corresponding device is alive.
static const DeviceMgr* GetStaticDeviceMgr() {
static const auto* const device_mgr = []() -> const DeviceMgr* {
std::vector<std::unique_ptr<Device>> devices;
// Only CPU device is used so instead of calling DeviceFactory::AddDevices()
// with dummy session config, which will conflict with user defined options
// and create unwanted devices, call cpu_factory->CreateDevices() to get CPU
// only devices.
DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
SessionOptions options;
auto status = cpu_factory->CreateDevices(
options, "/job:localhost/replica:0/task:0", &devices);
if (!status.ok()) {
LOG(ERROR) << "Failed to create devices for Grappler: " << status;
return nullptr;
}
DeviceSet dev_set;
for (auto d : device_mgr->ListDevices()) dev_set.AddDevice(d);
return new StaticDeviceMgr(std::move(devices));
}();
return device_mgr;
}
stream_executor::port::StatusOr<GraphDef> RunGrappler(
const MetaGraphDef& meta_graph_def) {
ConfigProto config_proto;
// Avoid grappler logic that lowers to v1 control flow.
config_proto.mutable_experimental()->set_use_tfrt(true);
@ -135,17 +146,29 @@ Status RunGrappler(MetaGraphDef* meta_graph_def) {
grappler::ItemConfig item_config;
item_config.ignore_user_placement = false;
std::unique_ptr<grappler::GrapplerItem> item =
grappler::GrapplerItemFromMetaGraphDef("graph", *meta_graph_def,
grappler::GrapplerItemFromMetaGraphDef("graph", meta_graph_def,
item_config);
if (!item) {
return tensorflow::errors::Internal(
"Failed to create grappler item from MetaGraphDef.");
}
const auto* device_mgr = GetStaticDeviceMgr();
if (!device_mgr) {
return tensorflow::errors::Internal(
"Failed to get devices in RunGrappler().");
}
DeviceSet dev_set;
for (auto* d : device_mgr->ListDevices()) dev_set.AddDevice(d);
grappler::VirtualCluster cluster(&dev_set);
return grappler::RunMetaOptimizer(std::move(*item), config_proto, cpu_device,
&cluster,
meta_graph_def->mutable_graph_def());
Device* cpu_device = device_mgr->HostCPU();
GraphDef output_graph_def;
TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
std::move(*item), config_proto, cpu_device, &cluster, &output_graph_def));
return output_graph_def;
}
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
@ -30,9 +31,10 @@ class MetaGraphDef;
Status GenerateResourceSharedNameIfEmpty(
GraphDef& gdef, const OpRegistryInterface* default_registry);
// Run grapler passes over `meta_graph_def`.graph_def(), and optimize it in
// place.
Status RunGrappler(MetaGraphDef* meta_graph_def);
// Run grapler passes over `meta_graph_def`.graph_def() and returns the
// optimized graphdef.
stream_executor::port::StatusOr<GraphDef> RunGrappler(
const MetaGraphDef& meta_graph_def);
} // namespace tensorflow