diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index b4b5b869e74..5e958960d07 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -141,9 +141,12 @@ mlir::OwningModuleRef SavedModelV1ToMlirImport( absl::string_view saved_model_dir, const std::unordered_set<std::string>& tags, mlir::MLIRContext* context) { tensorflow::SavedModelBundle bundle; - auto load_status = tensorflow::LoadSavedModel( - /* session_options = */ {}, /* run_options = */ {}, - std::string(saved_model_dir), tags, &bundle); + tensorflow::SessionOptions session_options; + // Force saved model states to be restored to CPU. + (*session_options.config.mutable_device_count())["GPU"] = 0; + auto load_status = + tensorflow::LoadSavedModel(session_options, /* run_options = */ {}, + std::string(saved_model_dir), tags, &bundle); if (!load_status.ok()) { LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir << "': " << load_status;