Use StatusOr to propagate detailed lowering error.
PiperOrigin-RevId: 320677957 Change-Id: Ic95be20f14f960b006d1d47972de07decfca3b77
This commit is contained in:
parent
d49636eeea
commit
9060253512
@ -188,20 +188,16 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
|
||||
if (saved_model_version == 2) {
|
||||
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||
auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||
input_filename, tags, exported_names, context);
|
||||
if (!module)
|
||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||
|
||||
return module;
|
||||
if (!module_or.status().ok()) return module_or.status();
|
||||
return module_or.ConsumeValueOrDie();
|
||||
} else if (saved_model_version == 1) {
|
||||
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
input_filename, tags, exported_names, context);
|
||||
|
||||
if (!module)
|
||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||
|
||||
return module;
|
||||
if (!module_or.status().ok()) return module_or.status();
|
||||
return module_or.ConsumeValueOrDie();
|
||||
} else {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Should be either saved model v1 or v2");
|
||||
|
@ -1194,6 +1194,7 @@ cc_library(
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler/utils:transitive_fanin",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
|
@ -40,9 +40,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using stream_executor::port::Status;
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||
@ -98,7 +95,7 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
|
||||
context);
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
|
||||
StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||
@ -112,13 +109,11 @@ mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
|
||||
enable_shape_inference, context);
|
||||
if (!module_or.status().ok()) {
|
||||
LOG(ERROR) << "Graph import failed: " << module_or.status();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return module_or.ConsumeValueOrDie();
|
||||
return module_or;
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
|
||||
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphToMlirImport(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
|
||||
@ -128,18 +123,17 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
|
||||
if (!load_status.ok()) {
|
||||
LOG(ERROR) << "Failed to load saved model '" << saved_model_dir
|
||||
<< "': " << load_status;
|
||||
return nullptr;
|
||||
return load_status;
|
||||
}
|
||||
|
||||
auto module_or = ConvertSavedModelToMlir(&bundle, context, exported_names);
|
||||
if (!module_or.status().ok()) {
|
||||
LOG(ERROR) << "SavedModel import failed: " << module_or.status();
|
||||
return nullptr;
|
||||
}
|
||||
return module_or.ConsumeValueOrDie();
|
||||
return module_or;
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
||||
StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context,
|
||||
@ -154,19 +148,18 @@ mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
||||
if (!load_status.ok()) {
|
||||
LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
|
||||
<< "': " << load_status;
|
||||
return nullptr;
|
||||
return load_status;
|
||||
}
|
||||
|
||||
auto module_or = ConvertSavedModelV1ToMlir(bundle, exported_names, context,
|
||||
upgrade_legacy);
|
||||
if (!module_or.status().ok()) {
|
||||
LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
|
||||
return nullptr;
|
||||
}
|
||||
return module_or.ConsumeValueOrDie();
|
||||
return module_or;
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||
StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||
@ -180,7 +173,7 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||
enable_shape_inference, context);
|
||||
if (!module_or.status().ok()) {
|
||||
LOG(ERROR) << "Graph import failed: " << module_or.status();
|
||||
return nullptr;
|
||||
return module_or.status();
|
||||
}
|
||||
auto& module = module_or.ValueOrDie();
|
||||
std::srand(0);
|
||||
@ -215,7 +208,7 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||
}
|
||||
}
|
||||
}
|
||||
return module_or.ConsumeValueOrDie();
|
||||
return module_or;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -23,15 +23,20 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using stream_executor::port::Status;
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
// TODO(antiagainst): Directly manipulating files in library functions is not
|
||||
// a good idea. We should pass in a string/stream here.
|
||||
|
||||
// Converts a TensorFlow GraphDef stored in the file with the given
|
||||
// `input_filename` into a MLIR module. Creates MLIR entities into the
|
||||
// given MLIR `context`.
|
||||
mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
|
||||
StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||
@ -42,7 +47,7 @@ mlir::OwningModuleRef GraphdefToMlirTranslateFunction(
|
||||
|
||||
// Similar as the above function, but replaces all constant tensors
|
||||
// with randomly generated splat values.
|
||||
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||
StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||
absl::string_view input_shapes, absl::string_view output_arrays,
|
||||
@ -54,7 +59,7 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||
// Converts a TensorFlow SavedModel stored in the directory with the given
|
||||
// `saved_model_dir` into a MLIR module. Creates MLIR entities into the
|
||||
// given MLIR `context`.
|
||||
mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
|
||||
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphToMlirImport(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
|
||||
@ -62,7 +67,7 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
|
||||
// Converts a TensorFlow V1 SavedModel stored in the directory with the given
|
||||
// `saved_model_dir` into a MLIR module. Creates MLIR entities into the
|
||||
// given MLIR `context`.
|
||||
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
||||
StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefsToMlirImport(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context,
|
||||
|
@ -42,11 +42,13 @@ inline absl::string_view StringRefToView(llvm::StringRef ref) {
|
||||
|
||||
static OwningModuleRef GraphdefToMlirTranslateFunction(llvm::StringRef input,
|
||||
MLIRContext* context) {
|
||||
return tensorflow::GraphdefToMlirTranslateFunction(
|
||||
auto module_or = tensorflow::GraphdefToMlirTranslateFunction(
|
||||
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||
output_arrays, control_output_arrays, prune_unused_nodes,
|
||||
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
|
||||
enable_shape_inference, context);
|
||||
if (!module_or.status().ok()) return nullptr;
|
||||
return module_or.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
static TranslateToMLIRRegistration GraphdefToMlirTranslate(
|
||||
@ -54,11 +56,13 @@ static TranslateToMLIRRegistration GraphdefToMlirTranslate(
|
||||
|
||||
static OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||
llvm::StringRef input, MLIRContext* context) {
|
||||
return tensorflow::GraphdefToSplattedMlirTranslateFunction(
|
||||
auto module_or = tensorflow::GraphdefToSplattedMlirTranslateFunction(
|
||||
input, debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||
output_arrays, control_output_arrays, prune_unused_nodes,
|
||||
convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
|
||||
enable_shape_inference, context);
|
||||
if (!module_or.status().ok()) return nullptr;
|
||||
return module_or.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
static TranslateToMLIRRegistration GraphdefToSplattedMlirTranslate(
|
||||
|
@ -112,19 +112,19 @@ int main(int argc, char** argv) {
|
||||
if (import_saved_model_object_graph) {
|
||||
mlir::MLIRContext context;
|
||||
|
||||
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||
auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||
input_filename, tags, exported_names, &context);
|
||||
if (!module) return 1;
|
||||
if (!module_or.status().ok()) return 1;
|
||||
|
||||
module->print(output->os());
|
||||
module_or.ConsumeValueOrDie()->print(output->os());
|
||||
} else if (import_saved_model_signature_defs) {
|
||||
mlir::MLIRContext context;
|
||||
|
||||
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
input_filename, tags, exported_names, &context);
|
||||
if (!module) return 1;
|
||||
if (!module_or.status().ok()) return 1;
|
||||
|
||||
module->print(output->os());
|
||||
module_or.ConsumeValueOrDie()->print(output->os());
|
||||
} else {
|
||||
auto input = mlir::openInputFile(input_filename, &error_message);
|
||||
|
||||
|
@ -129,20 +129,18 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
absl::Span<std::string> exported_names(exported_names_in_vector);
|
||||
if (import_saved_model) {
|
||||
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||
auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||
input_filename, tags, absl::Span<std::string>(exported_names), context);
|
||||
if (!module)
|
||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||
if (!module_or.status().ok()) return module_or.status();
|
||||
TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
|
||||
return module;
|
||||
return module_or.ConsumeValueOrDie();
|
||||
} else if (import_saved_model_v1) {
|
||||
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
input_filename, tags, exported_names, context);
|
||||
|
||||
if (!module)
|
||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||
if (!module_or.status().ok()) return module_or.status();
|
||||
TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
|
||||
return module;
|
||||
return module_or.ConsumeValueOrDie();
|
||||
} else {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Should be either saved model v1 or v2");
|
||||
|
Loading…
x
Reference in New Issue
Block a user