Make device manager in RunGrappler static instead of creating it on every invocation.
This avoids creating new devices every time RunGrappler() is called. In addition, the optimized graph may contain tensor protos that are only valid while the corresponding devices are alive, so the devices must outlive any graph produced from them.

PiperOrigin-RevId: 353924393
Change-Id: Ibc3b2868f409690e31ccfae7e5f1d0626d34afff
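For context, the caching pattern this commit adopts is a function-local static that is initialized once by a lambda and deliberately never destroyed, so anything that holds a pointer into it remains valid for the life of the process. Below is a minimal, self-contained sketch of that pattern; ExpensiveResource and CreateResource() are hypothetical stand-ins for StaticDeviceMgr and the CPU-device setup in the diff, not TensorFlow code.

#include <iostream>

// Illustrative sketch only: hypothetical stand-in for an expensive,
// process-lifetime resource such as a device manager.
struct ExpensiveResource {
  ExpensiveResource() { std::cout << "created once\n"; }
};

ExpensiveResource* CreateResource() { return new ExpensiveResource(); }

// Returns a process-lifetime resource. The lambda runs exactly once, on the
// first call; the object is deliberately never destroyed, so pointers into it
// stay valid until process exit.
const ExpensiveResource* GetStaticResource() {
  static const ExpensiveResource* const resource =
      []() -> const ExpensiveResource* {
    // The real code returns nullptr here if device creation fails.
    return CreateResource();
  }();
  return resource;
}

int main() {
  const auto* a = GetStaticResource();
  const auto* b = GetStaticResource();
  // Prints "created once" a single time, then "same instance".
  std::cout << (a == b ? "same instance\n" : "different instances\n");
}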
parent 0f90066f9e
commit 8f595c9558
@@ -1094,6 +1094,7 @@ cc_library(
         "//tensorflow/core/grappler/clusters:virtual_cluster",
         "//tensorflow/core/grappler/optimizers:meta_optimizer",
         "//tensorflow/core/protobuf:for_core_protos_cc",
+        "//tensorflow/stream_executor/lib",
         "@llvm-project//llvm:Support",
     ],
 )
@@ -3426,13 +3426,22 @@ class SavedModelSignatureDefImporterLite {
 
 Status SavedModelSignatureDefImporterLite::InitializeGraph(
     MLIRImportOptions import_options) {
+  GraphDef graph_def;
   if (import_options.enable_grappler) {
     // Grappler is best-effort.
-    auto status = RunGrappler(&meta_graph_def_);
-    if (!status.ok()) LOG(WARNING) << status;
+    auto statusor = RunGrappler(meta_graph_def_);
+    if (statusor.ok()) {
+      graph_def = std::move(statusor).ValueOrDie();
+    } else {
+      // If the grappler fails, use the original graph def.
+      LOG(WARNING) << "SavedModelSignatureDefImporterLite: grappler failed: "
+                   << statusor.status();
+      graph_def = meta_graph_def_.graph_def();
+    }
+  } else {
+    graph_def = meta_graph_def_.graph_def();
   }
 
-  GraphDef graph_def = meta_graph_def_.graph_def();
   if (import_options.upgrade_legacy) {
     TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
         graph_def, graph_->flib_def().default_registry()));
@@ -100,22 +100,33 @@ Status GenerateResourceSharedNameIfEmpty(
   return tensorflow::Status::OK();
 }
 
-Status RunGrappler(MetaGraphDef* meta_graph_def) {
-  std::vector<std::unique_ptr<Device>> devices;
-  // Only CPU device is used so instead of calling DeviceFactory::AddDevices()
-  // with dummy session config, which will conflict with user defined options
-  // and create unwanted devices, call cpu_factory->CreateDevices() to get CPU
-  // only devices.
-  DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
-  SessionOptions options;
-  TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
-      options, "/job:localhost/replica:0/task:0", &devices));
-  Device* cpu_device = devices[0].get();
-  auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
-
-  DeviceSet dev_set;
-  for (auto d : device_mgr->ListDevices()) dev_set.AddDevice(d);
+// The static device manager is used to avoid creating the new device every time
+// RunGrappler() is called. In addition, the optimized graph may contain tensor
+// protos that are only valid when the corresponding device is alive.
+static const DeviceMgr* GetStaticDeviceMgr() {
+  static const auto* const device_mgr = []() -> const DeviceMgr* {
+    std::vector<std::unique_ptr<Device>> devices;
+    // Only CPU device is used so instead of calling DeviceFactory::AddDevices()
+    // with dummy session config, which will conflict with user defined options
+    // and create unwanted devices, call cpu_factory->CreateDevices() to get CPU
+    // only devices.
+    DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
+    SessionOptions options;
+    auto status = cpu_factory->CreateDevices(
+        options, "/job:localhost/replica:0/task:0", &devices);
+    if (!status.ok()) {
+      LOG(ERROR) << "Failed to create devices for Grappler: " << status;
+      return nullptr;
+    }
+    return new StaticDeviceMgr(std::move(devices));
+  }();
+
+  return device_mgr;
+}
+
+stream_executor::port::StatusOr<GraphDef> RunGrappler(
+    const MetaGraphDef& meta_graph_def) {
   ConfigProto config_proto;
   // Avoid grappler logic that lowers to v1 control flow.
   config_proto.mutable_experimental()->set_use_tfrt(true);
@@ -135,17 +146,29 @@ Status RunGrappler(MetaGraphDef* meta_graph_def) {
   grappler::ItemConfig item_config;
   item_config.ignore_user_placement = false;
   std::unique_ptr<grappler::GrapplerItem> item =
-      grappler::GrapplerItemFromMetaGraphDef("graph", *meta_graph_def,
+      grappler::GrapplerItemFromMetaGraphDef("graph", meta_graph_def,
                                              item_config);
   if (!item) {
     return tensorflow::errors::Internal(
         "Failed to create grappler item from MetaGraphDef.");
   }
 
+  const auto* device_mgr = GetStaticDeviceMgr();
+  if (!device_mgr) {
+    return tensorflow::errors::Internal(
+        "Failed to get devices in RunGrappler().");
+  }
+
+  DeviceSet dev_set;
+  for (auto* d : device_mgr->ListDevices()) dev_set.AddDevice(d);
   grappler::VirtualCluster cluster(&dev_set);
-  return grappler::RunMetaOptimizer(std::move(*item), config_proto, cpu_device,
-                                    &cluster,
-                                    meta_graph_def->mutable_graph_def());
+  Device* cpu_device = device_mgr->HostCPU();
+
+  GraphDef output_graph_def;
+  TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
+      std::move(*item), config_proto, cpu_device, &cluster, &output_graph_def));
+
+  return output_graph_def;
 }
 
 }  // namespace tensorflow
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/graph/graph.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace tensorflow {
 
@@ -30,9 +31,10 @@ class MetaGraphDef;
 Status GenerateResourceSharedNameIfEmpty(
     GraphDef& gdef, const OpRegistryInterface* default_registry);
 
-// Run grapler passes over `meta_graph_def`.graph_def(), and optimize it in
-// place.
-Status RunGrappler(MetaGraphDef* meta_graph_def);
+// Runs grappler passes over `meta_graph_def`.graph_def() and returns the
+// optimized GraphDef.
+stream_executor::port::StatusOr<GraphDef> RunGrappler(
+    const MetaGraphDef& meta_graph_def);
 
 }  // namespace tensorflow
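For reference, a caller of the new StatusOr-returning overload can handle failure the same way the InitializeGraph change above does, falling back to the unoptimized graph. A minimal sketch, assuming it is compiled inside the tensorflow namespace with the header above included; the helper name is hypothetical and not part of the commit.

// Hypothetical helper (illustration only): prefer the Grappler-optimized
// graph, but keep the original graph_def when optimization fails, since
// Grappler is treated as best-effort.
GraphDef GetOptimizedOrOriginalGraphDef(const MetaGraphDef& meta_graph_def) {
  auto statusor = RunGrappler(meta_graph_def);
  if (statusor.ok()) {
    return std::move(statusor).ValueOrDie();
  }
  LOG(WARNING) << "Grappler failed, using the unoptimized graph: "
               << statusor.status();
  return meta_graph_def.graph_def();
}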