From aebcab97dfe3e3b5f0764d23489bf20cc8d42e6d Mon Sep 17 00:00:00 2001 From: Kuangyuan Chen Date: Thu, 14 Jan 2021 17:52:15 -0800 Subject: [PATCH] Handle grappler failure in savedmodel importer properly PiperOrigin-RevId: 351914138 Change-Id: I94db8b279385ea91483ef9f47bb43b7fe23180ed --- .../mlir/tensorflow/translate/import_model.cc | 11 ++++++----- .../mlir/tensorflow/translate/upgrade_graph.cc | 4 ++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 19d558ad19b..6735449a2b8 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -3368,7 +3368,7 @@ class SavedModelSignatureDefImporterLite { const GraphDebugInfo& debug_info() const { return debug_info_; } private: - const MetaGraphDef& meta_graph_def_; + MetaGraphDef meta_graph_def_; const GraphDebugInfo& debug_info_; std::unique_ptr graph_; absl::Span exported_names_; @@ -3379,10 +3379,11 @@ class SavedModelSignatureDefImporterLite { Status SavedModelSignatureDefImporterLite::InitializeGraph( MLIRImportOptions import_options) { - // TODO(jpienaar): Remove need to const_cast. - if (import_options.enable_grappler) - TF_RETURN_IF_ERROR( - RunGrappler(const_cast(&meta_graph_def_))); + if (import_options.enable_grappler) { + // Grappler is best-effort. + auto status = RunGrappler(&meta_graph_def_); + if (!status.ok()) LOG(WARNING) << status; + } GraphDef graph_def = meta_graph_def_.graph_def(); if (import_options.upgrade_legacy) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc index 7a944960c58..ec2c4d3b5b8 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc @@ -130,6 +130,10 @@ Status RunGrappler(MetaGraphDef* meta_graph_def) { std::unique_ptr item = grappler::GrapplerItemFromMetaGraphDef("graph", *meta_graph_def, grappler::ItemConfig()); + if (!item) { + return tensorflow::errors::Internal( + "Failed to create grappler item from MetaGraphDef."); + } grappler::VirtualCluster cluster(&dev_set); return grappler::RunMetaOptimizer(std::move(*item), config_proto, cpu_device,