Use StatusOr to propagate detailed lowering errors.

PiperOrigin-RevId: 320677957
Change-Id: Ic95be20f14f960b006d1d47972de07decfca3b77
Author: Chuanhao Zhuge (committed by TensorFlower Gardener)
Date: 2020-07-10 14:35:38 -07:00
parent d49636eeea
commit 9060253512
7 changed files with 45 additions and 48 deletions
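
For context, a minimal sketch of the pattern this change adopts (not the TensorFlow code itself): the importers now return StatusOr<T> instead of logging and handing back a nullptr, so callers can forward the detailed status. The LoadModule/UseModule names below are hypothetical, illustration-only; the StatusOr type, status()/ConsumeValueOrDie() accessors, and tensorflow::errors helpers are the same ones used in the diff.

// Minimal sketch of the StatusOr error-propagation pattern this change
// adopts. LoadModule/UseModule are hypothetical names for illustration only.
#include <string>

#include "tensorflow/core/lib/core/errors.h"          // tensorflow::errors::InvalidArgument
#include "tensorflow/stream_executor/lib/statusor.h"  // stream_executor::port::StatusOr

using stream_executor::port::StatusOr;

StatusOr<std::string> LoadModule(bool ok) {
  // On failure, return a Status carrying the detailed reason instead of a
  // null sentinel; a non-OK Status converts implicitly into the StatusOr.
  if (!ok) return tensorflow::errors::InvalidArgument("detailed failure reason");
  return std::string("module payload");
}

tensorflow::Status UseModule() {
  auto module_or = LoadModule(/*ok=*/false);
  // Propagate the original error, with its message, to the caller.
  if (!module_or.status().ok()) return module_or.status();
  // Success: take ownership of the value.
  std::string module = module_or.ConsumeValueOrDie();
  return module.empty() ? tensorflow::errors::Internal("empty module")
                        : tensorflow::Status::OK();
}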


@@ -188,20 +188,16 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
     const std::unordered_set<std::string>& tags,
     absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
   if (saved_model_version == 2) {
-    auto module = tensorflow::SavedModelObjectGraphToMlirImport(
+    auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
         input_filename, tags, exported_names, context);
-    if (!module)
-      return tensorflow::errors::InvalidArgument("fail to open input file");
-    return module;
+    if (!module_or.status().ok()) return module_or.status();
+    return module_or.ConsumeValueOrDie();
   } else if (saved_model_version == 1) {
-    auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
+    auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
         input_filename, tags, exported_names, context);
-    if (!module)
-      return tensorflow::errors::InvalidArgument("fail to open input file");
-    return module;
+    if (!module_or.status().ok()) return module_or.status();
+    return module_or.ConsumeValueOrDie();
   } else {
     return tensorflow::errors::InvalidArgument(
         "Should be either saved model v1 or v2");


@@ -1194,6 +1194,7 @@ cc_library(
         "//tensorflow/core:ops",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler/utils:transitive_fanin",
+        "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",


@@ -40,9 +40,6 @@ limitations under the License.
 namespace tensorflow {
-using stream_executor::port::Status;
-using stream_executor::port::StatusOr;
 static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
     llvm::StringRef input, absl::string_view debug_info_file,
     absl::string_view input_arrays, absl::string_view input_dtypes,
@@ -98,7 +95,7 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
       context);
 }
-mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
+StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
     llvm::StringRef input, absl::string_view debug_info_file,
     absl::string_view input_arrays, absl::string_view input_dtypes,
     absl::string_view input_shapes, absl::string_view output_arrays,
@@ -112,13 +109,11 @@ mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
       enable_shape_inference, context);
   if (!module_or.status().ok()) {
     LOG(ERROR) << "Graph import failed: " << module_or.status();
-    return nullptr;
   }
-  return module_or.ConsumeValueOrDie();
+  return module_or;
 }
-mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
+StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphToMlirImport(
     absl::string_view saved_model_dir,
     const std::unordered_set<std::string>& tags,
     absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
@@ -128,18 +123,17 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
   if (!load_status.ok()) {
     LOG(ERROR) << "Failed to load saved model '" << saved_model_dir
                << "': " << load_status;
-    return nullptr;
+    return load_status;
   }
   auto module_or = ConvertSavedModelToMlir(&bundle, context, exported_names);
   if (!module_or.status().ok()) {
     LOG(ERROR) << "SavedModel import failed: " << module_or.status();
-    return nullptr;
   }
-  return module_or.ConsumeValueOrDie();
+  return module_or;
 }
-mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
+StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
     absl::string_view saved_model_dir,
     const std::unordered_set<std::string>& tags,
     absl::Span<std::string> exported_names, mlir::MLIRContext* context,
@@ -154,19 +148,18 @@ mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
   if (!load_status.ok()) {
     LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
                << "': " << load_status;
-    return nullptr;
+    return load_status;
   }
   auto module_or = ConvertSavedModelV1ToMlir(bundle, exported_names, context,
                                              upgrade_legacy);
   if (!module_or.status().ok()) {
     LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
-    return nullptr;
   }
-  return module_or.ConsumeValueOrDie();
+  return module_or;
 }
-mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
+StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
     llvm::StringRef input, absl::string_view debug_info_file,
     absl::string_view input_arrays, absl::string_view input_dtypes,
     absl::string_view input_shapes, absl::string_view output_arrays,
@@ -180,7 +173,7 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
       enable_shape_inference, context);
   if (!module_or.status().ok()) {
     LOG(ERROR) << "Graph import failed: " << module_or.status();
-    return nullptr;
+    return module_or.status();
   }
   auto& module = module_or.ValueOrDie();
   std::srand(0);
@@ -215,7 +208,7 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
       }
     }
   }
-  return module_or.ConsumeValueOrDie();
+  return module_or;
 }
 }  // namespace tensorflow


@@ -23,15 +23,20 @@ limitations under the License.
 #include "absl/types/span.h"
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Module.h"  // from @llvm-project
+#include "tensorflow/stream_executor/lib/statusor.h"
 namespace tensorflow {
+using stream_executor::port::Status;
+using stream_executor::port::StatusOr;
 // TODO(antiagainst): Directly manipulating files in library functions is not
 // a good idea. We should pass in a string/stream here.
 // Converts a TensorFlow GraphDef stored in the file with the given
 // `input_filename` into a MLIR module. Creates MLIR entities into the
 // given MLIR `context`.
-mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
+StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
     llvm::StringRef input, absl::string_view debug_info_file,
     absl::string_view input_arrays, absl::string_view input_dtypes,
     absl::string_view input_shapes, absl::string_view output_arrays,
@@ -42,7 +47,7 @@ mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
 // Similar as the above function, but replaces all constant tensors
 // with randomly generated splat values.
-mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
+StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
     llvm::StringRef input, absl::string_view debug_info_file,
     absl::string_view input_arrays, absl::string_view input_dtypes,
     absl::string_view input_shapes, absl::string_view output_arrays,
@@ -54,7 +59,7 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
 // Converts a TensorFlow SavedModel stored in the directory with the given
 // `saved_model_dir` into a MLIR module. Creates MLIR entities into the
 // given MLIR `context`.
-mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
+StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphToMlirImport(
     absl::string_view saved_model_dir,
     const std::unordered_set<std::string>& tags,
     absl::Span<std::string> exported_names, mlir::MLIRContext* context);
@@ -62,7 +67,7 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
 // Converts a TensorFlow V1 SavedModel stored in the directory with the given
 // `saved_model_dir` into a MLIR module. Creates MLIR entities into the
 // given MLIR `context`.
-mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
+StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
     absl::string_view saved_model_dir,
     const std::unordered_set<std::string>& tags,
     absl::Span<std::string> exported_names, mlir::MLIRContext* context,


@@ -42,11 +42,13 @@ inline absl::string_view StringRefToView(llvm::StringRef ref) {
 static OwningModuleRef GraphdefToMlirTranslateFunction(llvm::StringRef input,
                                                        MLIRContext* context) {
-  return tensorflow::GraphdefToMlirTranslateFunction(
+  auto module_or = tensorflow::GraphdefToMlirTranslateFunction(
       input, debug_info_file, input_arrays, input_dtypes, input_shapes,
       output_arrays, control_output_arrays, prune_unused_nodes,
       convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
       enable_shape_inference, context);
+  if (!module_or.status().ok()) return nullptr;
+  return module_or.ConsumeValueOrDie();
 }
 static TranslateToMLIRRegistration GraphdefToMlirTranslate(
@@ -54,11 +56,13 @@ static TranslateToMLIRRegistration GraphdefToMlirTranslate(
 static OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
     llvm::StringRef input, MLIRContext* context) {
-  return tensorflow::GraphdefToSplattedMlirTranslateFunction(
+  auto module_or = tensorflow::GraphdefToSplattedMlirTranslateFunction(
      input, debug_info_file, input_arrays, input_dtypes, input_shapes,
      output_arrays, control_output_arrays, prune_unused_nodes,
      convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
      enable_shape_inference, context);
+  if (!module_or.status().ok()) return nullptr;
+  return module_or.ConsumeValueOrDie();
 }
 static TranslateToMLIRRegistration GraphdefToSplattedMlirTranslate(


@@ -112,19 +112,19 @@ int main(int argc, char** argv) {
   if (import_saved_model_object_graph) {
     mlir::MLIRContext context;
-    auto module = tensorflow::SavedModelObjectGraphToMlirImport(
+    auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
         input_filename, tags, exported_names, &context);
-    if (!module) return 1;
-    module->print(output->os());
+    if (!module_or.status().ok()) return 1;
+    module_or.ConsumeValueOrDie()->print(output->os());
   } else if (import_saved_model_signature_defs) {
     mlir::MLIRContext context;
-    auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
+    auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
         input_filename, tags, exported_names, &context);
-    if (!module) return 1;
-    module->print(output->os());
+    if (!module_or.status().ok()) return 1;
+    module_or.ConsumeValueOrDie()->print(output->os());
   } else {
     auto input = mlir::openInputFile(input_filename, &error_message);


@@ -129,20 +129,18 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
       absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
   absl::Span<std::string> exported_names(exported_names_in_vector);
   if (import_saved_model) {
-    auto module = tensorflow::SavedModelObjectGraphToMlirImport(
+    auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
         input_filename, tags, absl::Span<std::string>(exported_names), context);
-    if (!module)
-      return tensorflow::errors::InvalidArgument("fail to open input file");
+    if (!module_or.status().ok()) return module_or.status();
     TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
-    return module;
+    return module_or.ConsumeValueOrDie();
   } else if (import_saved_model_v1) {
-    auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
+    auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
        input_filename, tags, exported_names, context);
-    if (!module)
-      return tensorflow::errors::InvalidArgument("fail to open input file");
+    if (!module_or.status().ok()) return module_or.status();
     TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
-    return module;
+    return module_or.ConsumeValueOrDie();
   } else {
     return tensorflow::errors::InvalidArgument(
         "Should be either saved model v1 or v2");