diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 31078e92ce5..e9b69501c8a 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -118,6 +118,39 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "bundle_v2", + srcs = ["bundle_v2.cc"], + hdrs = ["bundle_v2.h"], + deps = [ + ":constants", + "@com_google_absl//absl/container:flat_hash_set", + "//tensorflow/core/platform:env", + "//tensorflow/core/platform:strcat", + "//tensorflow/core/util/tensor_bundle", + ] + if_not_mobile([ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ]), +) + +tf_cc_test( + name = "bundle_v2_test", + srcs = ["bundle_v2_test.cc"], + data = [ + ":saved_model_half_plus_two", + ], + linkstatic = 1, + deps = [ + ":bundle_v2", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/platform:test", + ], +) + tf_cc_test( name = "saved_model_bundle_test", srcs = ["saved_model_bundle_test.cc"], @@ -160,6 +193,26 @@ tf_cc_test( ], ) +# A subset of the TF2 saved models can be generated with this tool. +py_binary( + name = "testdata/generate_saved_models", + srcs = ["testdata/generate_saved_models.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_spec", + "//tensorflow/python:variables", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/module", + "//tensorflow/python/saved_model", + "//tensorflow/python/saved_model:save_options", + "@absl_py//absl:app", + ], +) + # TODO(b/32673259): add a test to continuously validate these files. filegroup( name = "saved_model_half_plus_two", @@ -169,5 +222,7 @@ filegroup( "testdata/half_plus_two/**", "testdata/half_plus_two_v2/**", "testdata/x_plus_y_v2_debuginfo/**", + "testdata/CyclicModule/**", + "testdata/VarsAndArithmeticObjectGraph/**", ]), ) diff --git a/tensorflow/cc/saved_model/bundle_v2.cc b/tensorflow/cc/saved_model/bundle_v2.cc new file mode 100644 index 00000000000..b6daece84ab --- /dev/null +++ b/tensorflow/cc/saved_model/bundle_v2.cc @@ -0,0 +1,223 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/saved_model/bundle_v2.h" + +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/protobuf/saved_model.pb.h" +#include "tensorflow/core/protobuf/trackable_object_graph.pb.h" + +namespace tensorflow { + +namespace { + +Status ReadSavedModelProto(const string& export_dir, + SavedModel* saved_model_proto) { + LOG(INFO) << "Reading SavedModel from: " << export_dir; + + const string saved_model_pb_path = + io::JoinPath(export_dir, kSavedModelFilenamePb); + if (Env::Default()->FileExists(saved_model_pb_path).ok()) { + return ReadBinaryProto(Env::Default(), saved_model_pb_path, + saved_model_proto); + } + const string saved_model_pbtxt_path = + io::JoinPath(export_dir, kSavedModelFilenamePbTxt); + if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) { + return ReadTextProto(Env::Default(), saved_model_pbtxt_path, + saved_model_proto); + } + return Status(error::Code::NOT_FOUND, + "Could not find SavedModel .pb or .pbtxt at supplied export " + "directory path: " + + export_dir); +} + +Status ReadSavedModelDebugInfoIfPresent( + const string& export_dir, + std::unique_ptr<GraphDebugInfo>* debug_info_proto) { + LOG(INFO) << "Reading SavedModel debug info (if present) from: " + << export_dir; + + const string debug_info_pb_path = + io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb"); + if (Env::Default()->FileExists(debug_info_pb_path).ok()) { + GraphDebugInfo debug_info; + TF_RETURN_IF_ERROR( + ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info)); + *debug_info_proto = + absl::make_unique<GraphDebugInfo>(std::move(debug_info)); + } + return Status::OK(); +} + +Status ReadCheckpointObjectGraph(BundleReader* bundle_reader, + TrackableObjectGraph* object_graph) { + Tensor object_graph_tensor; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + bundle_reader->Lookup(kObjectGraphProtoKey, &object_graph_tensor), + "SavedModel checkpoint does not contain object graph."); + if (object_graph_tensor.dtype() != DT_STRING || + object_graph_tensor.dims() != 0 || + object_graph_tensor.NumElements() != 1) { + return Status( + error::Code::FAILED_PRECONDITION, + "SavedModel checkpoint object graph was not the correct type."); + } + + const tstring* object_graph_string = reinterpret_cast<const tstring*>( + object_graph_tensor.tensor_data().data()); + if (!object_graph->ParseFromString(*object_graph_string)) { + return Status( + error::Code::FAILED_PRECONDITION, + "SavedModel checkpoint object graph could not be deserialized."); + } + return Status::OK(); +} + +} // namespace + +Status SavedModelV2Bundle::Load(const std::string& export_dir, + SavedModelV2Bundle* const bundle) { + SavedModel saved_model_proto; + TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto)); + + // Load MetaGraphDef. + // In version 2 SavedModels, there is only one MetaGraphDef. + if (saved_model_proto.meta_graphs_size() != 1) { + return Status( + error::Code::INVALID_ARGUMENT, + strings::StrCat( + "SavedModelV2 should have exactly one MetaGraphDef but actually ", + "contains ", saved_model_proto.meta_graphs_size())); + } + bundle->meta_graph_def_ = + std::move(*saved_model_proto.mutable_meta_graphs(0)); + + // Load GraphDebugInfo. + TF_RETURN_IF_ERROR( + ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_)); + + // Load the variables checkpoint reader. + const std::string variables_prefix = io::JoinPath( + export_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename); + bundle->variable_reader_.reset( + new BundleReader(Env::Default(), variables_prefix)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + bundle->variable_reader_->status(), + "Unable to load SavedModel variables checkpoint from ", variables_prefix); + + // Deserialize the object graph proto from the tensor bundle. + TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph( + bundle->variable_reader_.get(), &bundle->trackable_object_graph_)); + return Status::OK(); +} + +Status SavedModelV2Bundle::VisitObjectsToRestore( + RestoreObjectsCallback callback) { + if (saved_object_graph().nodes_size() == 0 || + trackable_object_graph().nodes_size() == 0) { + return Status::OK(); + } + + // Start from root nodes of both the SavedObjectGraph and TrackableObjectGraph + // and descend to leaves. Note that the TrackableObjectGraph can have cycles + // (as can the SavedObjectGraph). + // This is detected and cycle edges are skipped. + const SavedObject* root_saved_object = &saved_object_graph().nodes(0); + const TrackableObjectGraph::TrackableObject* root_trackable_object = + &trackable_object_graph().nodes(0); + absl::flat_hash_set<int> trackable_node_ids; + return RecurseObjectsToRestore(root_saved_object, 0, root_trackable_object, + std::string(), &trackable_node_ids, + std::move(callback)); +} + +Status SavedModelV2Bundle::RecurseObjectsToRestore( + const SavedObject* saved_object, int saved_object_node_id, + const TrackableObjectGraph::TrackableObject* trackable_object, + std::string object_name, absl::flat_hash_set<int>* seen_trackable_node_ids, + RestoreObjectsCallback callback) { + // Callback if any attributes or slot variables. + // Note that the root is always excluded from the search (it can never + // be a restorable object). This matches some logic on the Python side. + if (saved_object_node_id != 0 && + (trackable_object->attributes_size() > 0 || + trackable_object->slot_variables_size() > 0)) { + TF_RETURN_WITH_CONTEXT_IF_ERROR( + callback(saved_object_node_id, *trackable_object), "Unable to restore ", + object_name); + } + + for (const auto& trackable_child_ref : trackable_object->children()) { + const auto& local_name = trackable_child_ref.local_name(); + + // Compute the full child name. + std::string child_name; + if (object_name.empty()) { + child_name = local_name; + } else { + child_name = strings::StrCat(object_name, ".", local_name); + } + + // Descend down the trackable graph. + int trackable_child_node_id = trackable_child_ref.node_id(); + if (!seen_trackable_node_ids->insert(trackable_child_node_id).second) { + // Cycle or duplicate detected - ignore this branch. + continue; + } + if (trackable_child_node_id < 0 || + trackable_child_node_id >= trackable_object_graph().nodes_size()) { + return Status( + errors::Code::FAILED_PRECONDITION, + strings::StrCat("Illegal trackable child node id for ", child_name)); + } + const auto* trackable_child = + &trackable_object_graph().nodes(trackable_child_node_id); + + // Descend down the saved object graph. + int saved_child_node_id = -1; + const SavedObject* saved_child = nullptr; + for (const auto& saved_child_ref : saved_object->children()) { + if (saved_child_ref.local_name() == local_name) { + // Found. + saved_child_node_id = saved_child_ref.node_id(); + if (saved_child_node_id >= 0 && + saved_child_node_id < saved_object_graph().nodes_size()) { + saved_child = &saved_object_graph().nodes(saved_child_node_id); + } + break; + } + } + + if (!saved_child) { + return Status( + errors::Code::FAILED_PRECONDITION, + strings::StrCat("Could not find saved object to restore for ", + child_name)); + } + + TF_RETURN_IF_ERROR(RecurseObjectsToRestore( + saved_child, saved_child_node_id, trackable_child, child_name, + seen_trackable_node_ids, callback)); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/cc/saved_model/bundle_v2.h b/tensorflow/cc/saved_model/bundle_v2.h new file mode 100644 index 00000000000..d376b7b4c88 --- /dev/null +++ b/tensorflow/cc/saved_model/bundle_v2.h @@ -0,0 +1,87 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helpers for loading the persistent representation of a SavedModelV2. +// Please note that this is depended on by code that does not make use of +// the full runtime and its dependencies should be restricted. + +#ifndef TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_ +#define TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_ + +#include <string> + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/graph_debug_info.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/trackable_object_graph.pb.h" +#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" + +namespace tensorflow { + +/// Represents a version 2 SavedModel that is loaded from storage (but not yet +/// loaded into an executable in-memory representation). +class SavedModelV2Bundle { + public: + using RestoreObjectsCallback = + std::function<Status(int, const TrackableObjectGraph::TrackableObject&)>; + + /// Loads persistent representations for a SavedModelV2 from the specified + /// export directory. + static Status Load(const std::string& export_dir, SavedModelV2Bundle* bundle); + + /// MetaGraphDef from the loaded SavedModel. + MetaGraphDef& meta_graph_def() { return meta_graph_def_; } + + /// SavedObjectGraph from the MetaGraphDef. + const SavedObjectGraph& saved_object_graph() { + return meta_graph_def().object_graph_def(); + } + + /// TrackableObjectGraph loaded from the variable_reader() checkpoint. + TrackableObjectGraph& trackable_object_graph() { + return trackable_object_graph_; + } + + /// BundleReader for accessing the variables bundle. + BundleReader* variable_reader() { return variable_reader_.get(); } + + /// The GraphDebugInfo (or nullptr if none). + GraphDebugInfo* debug_info() { return debug_info_.get(); } + + /// Restores objects, invoking the callback with the node id in the + /// saved_object_graph() and the corresponding TrackableObject from the + /// trackable_object_graph(). The callback may use the variable_reader() but + /// must not modify the underlying saved_object_graph(). + Status VisitObjectsToRestore(RestoreObjectsCallback callback); + + private: + Status RecurseObjectsToRestore( + const SavedObject* saved_object, int saved_object_node_id, + const TrackableObjectGraph::TrackableObject* trackable_object, + std::string object_name, + absl::flat_hash_set<int>* seen_trackable_node_ids, + RestoreObjectsCallback callback); + + MetaGraphDef meta_graph_def_; + TrackableObjectGraph trackable_object_graph_; + std::unique_ptr<BundleReader> variable_reader_; + std::unique_ptr<GraphDebugInfo> debug_info_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_ diff --git a/tensorflow/cc/saved_model/bundle_v2_test.cc b/tensorflow/cc/saved_model/bundle_v2_test.cc new file mode 100644 index 00000000000..81aeed90968 --- /dev/null +++ b/tensorflow/cc/saved_model/bundle_v2_test.cc @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/saved_model/bundle_v2.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +constexpr char kTestData[] = "cc/saved_model/testdata"; + +class BundleV2Test : public ::testing::Test { + protected: + BundleV2Test() {} + + void RestoreVarsAndVerify(SavedModelV2Bundle* bundle, + std::vector<std::string> expected_names) { + // Collect saved_node_id, full_name, checkpoint_key into a vector. + using RestoredVarType = std::tuple<int, std::string, std::string>; + std::vector<RestoredVarType> restored_vars; + TF_ASSERT_OK(bundle->VisitObjectsToRestore( + [&](int saved_node_id, + const TrackableObjectGraph::TrackableObject& trackable_object) + -> Status { + for (const auto& attr : trackable_object.attributes()) { + if (attr.name() == "VARIABLE_VALUE") { + restored_vars.emplace_back(saved_node_id, attr.full_name(), + attr.checkpoint_key()); + } + } + return Status::OK(); + })); + + // Should be one of each var name restored. + for (const auto& expected_name : expected_names) { + EXPECT_EQ(1, std::count_if(restored_vars.begin(), restored_vars.end(), + [&](RestoredVarType t) { + return std::get<1>(t) == expected_name; + })); + } + + for (const auto& restored_var : restored_vars) { + // Each restored var should match a SavedObjectGraph node with the same + // variable name. + const auto& saved_node = + bundle->saved_object_graph().nodes(std::get<0>(restored_var)); + EXPECT_EQ(std::get<1>(restored_var), saved_node.variable().name()); + + // And should be able to load it from the tensor_bundle. + Tensor value; + TF_ASSERT_OK( + bundle->variable_reader()->Lookup(std::get<2>(restored_var), &value)); + } + } +}; + +TEST_F(BundleV2Test, LoadsVarsAndArithmeticObjectGraph) { + const string export_dir = io::JoinPath( + testing::TensorFlowSrcRoot(), kTestData, "VarsAndArithmeticObjectGraph"); + + SavedModelV2Bundle bundle; + TF_ASSERT_OK(SavedModelV2Bundle::Load(export_dir, &bundle)); + + // Ensure that there are nodes in the trackable_object_graph. + EXPECT_GT(bundle.trackable_object_graph().nodes_size(), 0); + + RestoreVarsAndVerify(&bundle, {"variable_x", "variable_y", "child_variable"}); +} + +TEST_F(BundleV2Test, LoadsCyclicModule) { + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestData, "CyclicModule"); + + SavedModelV2Bundle bundle; + TF_ASSERT_OK(SavedModelV2Bundle::Load(export_dir, &bundle)); + + // Ensure that there are nodes in the trackable_object_graph. + EXPECT_GT(bundle.trackable_object_graph().nodes_size(), 0); + + RestoreVarsAndVerify(&bundle, {"MyVariable"}); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h index 6f00dc324bd..cdc87e3dcd9 100644 --- a/tensorflow/cc/saved_model/constants.h +++ b/tensorflow/cc/saved_model/constants.h @@ -50,6 +50,9 @@ constexpr char kSavedModelVariablesFilename[] = "variables"; constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op"; constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op"; +// Key in the TensorBundle for the object graph proto. +constexpr char kObjectGraphProtoKey[] = "_CHECKPOINTABLE_OBJECT_GRAPH"; + } // namespace tensorflow #endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ diff --git a/tensorflow/cc/saved_model/testdata/CyclicModule/debug/saved_model_debug_info.pb b/tensorflow/cc/saved_model/testdata/CyclicModule/debug/saved_model_debug_info.pb new file mode 100644 index 00000000000..9e937552355 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/CyclicModule/debug/saved_model_debug_info.pb @@ -0,0 +1,235 @@ + +N/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/absl/app.py +m/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/saved_model/save.py +generate_saved_models.pyb +:AssignVariableOp/MyVariable@__inference__traced_restore_45$ +� +� +A +� +� +GW +/AssignVariableOp@__inference__traced_restore_45$ +� +� +A +� +� +GO +'Identity@__inference__traced_restore_45$ +� +� +A +� +� +GL +$Identity@__inference__traced_save_30$ +� +� +A +� +� +GQ +)Identity_1@__inference__traced_restore_45$ +� +� +A +� +� +GN +&Identity_1@__inference__traced_save_30$ +� +� +A +� +� +GQ +)Identity_2@__inference__traced_restore_45$ +� +� +A +� +� +Gj +BMergeV2Checkpoints/checkpoint_prefixes@__inference__traced_save_30$ +� +� +A +� +� +GV +.MergeV2Checkpoints@__inference__traced_save_30$ +� +� +A +� +� +GK +#NoOp@__inference__traced_restore_45$ +� +� +A +� +� +Ga +9RestoreV2/shape_and_slices@__inference__traced_restore_45$ +� +� +A +� +� +G] +5RestoreV2/tensor_names@__inference__traced_restore_45$ +� +� +A +� +� +GP +(RestoreV2@__inference__traced_restore_45$ +� +� +A +� +� +Gc +;RestoreV2_1/shape_and_slices@__inference__traced_restore_45$ +� +� +A +� +� +G_ +7RestoreV2_1/tensor_names@__inference__traced_restore_45$ +� +� +A +� +� +GR +*RestoreV2_1@__inference__traced_restore_45$ +� +� +A +� +� +Gi +ASaveV2/MyVariable/Read/ReadVariableOp@__inference__traced_save_30$ +� +� +A +� +� +G[ +3SaveV2/shape_and_slices@__inference__traced_save_30$ +� +� +A +� +� +GW +/SaveV2/tensor_names@__inference__traced_save_30$ +� +� +A +� +� +GJ +"SaveV2@__inference__traced_save_30$ +� +� +A +� +� +GR +*SaveV2_1/Const@__inference__traced_save_30$ +� +� +A +� +� +G] +5SaveV2_1/shape_and_slices@__inference__traced_save_30$ +� +� +A +� +� +GY +1SaveV2_1/tensor_names@__inference__traced_save_30$ +� +� +A +� +� +GL +$SaveV2_1@__inference__traced_save_30$ +� +� +A +� +� +GY +1ShardedFilename/shard@__inference__traced_save_30$ +� +� +A +� +� +GS ++ShardedFilename@__inference__traced_save_30$ +� +� +A +� +� +G[ +3ShardedFilename_1/shard@__inference__traced_save_30$ +� +� +A +� +� +GU +-ShardedFilename_1@__inference__traced_save_30$ +� +� +A +� +� +GW +/StringJoin/inputs_1@__inference__traced_save_30$ +� +� +A +� +� +GN +&StringJoin@__inference__traced_save_30$ +� +� +A +� +� +GR +*file_prefix@__inference__traced_restore_45$ +� +� +A +� +� +GO +'file_prefix@__inference__traced_save_30$ +� +� +A +� +� +GN +&num_shards@__inference__traced_save_30$ +� +� +A +� +� +G \ No newline at end of file diff --git a/tensorflow/cc/saved_model/testdata/CyclicModule/saved_model.pb b/tensorflow/cc/saved_model/testdata/CyclicModule/saved_model.pb new file mode 100644 index 00000000000..7a0494ee132 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/CyclicModule/saved_model.pb differ diff --git a/tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..5aae6287294 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.index b/tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.index new file mode 100644 index 00000000000..3fc2cec2b60 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.index differ diff --git a/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/debug/saved_model_debug_info.pb b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/debug/saved_model_debug_info.pb new file mode 100644 index 00000000000..2eb233756b0 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/debug/saved_model_debug_info.pb @@ -0,0 +1,506 @@ + +j/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py +m/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/saved_model/save.py +�/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/saved_model/signature_serialization.py +i/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/ops/math_ops.py +N/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/absl/app.py +o/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py +./generate_saved_models.py +k/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py +q/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.pyg +;AssignVariableOp/variable_x@__inference__traced_restore_101( +� +� +0 +� +� +6\ +0AssignVariableOp@__inference__traced_restore_101( +� +� +0 +� +� +6i +=AssignVariableOp_1/variable_y@__inference__traced_restore_101( +� +� +0 +� +� +6^ +2AssignVariableOp_1@__inference__traced_restore_101( +� +� +0 +� +� +6m +AAssignVariableOp_2/child_variable@__inference__traced_restore_101( +� +� +0 +� +� +6^ +2AssignVariableOp_2@__inference__traced_restore_101( +� +� +0 +� +� +6T +(Identity@__inference__traced_restore_101( +� +� +0 +� +� +6P +$Identity@__inference__traced_save_80( +� +� +0 +� +� +6J +Identity@__inference_compute_34' +T +� +0 +� +� +6U +)Identity@__inference_signature_wrapper_45( +� +� +0 +� +� +6V +*Identity_1@__inference__traced_restore_101( +� +� +0 +� +� +6R +&Identity_1@__inference__traced_save_80( +� +� +0 +� +� +6V +*Identity_2@__inference__traced_restore_101( +� +� +0 +� +� +6V +*Identity_3@__inference__traced_restore_101( +� +� +0 +� +� +6V +*Identity_4@__inference__traced_restore_101( +� +� +0 +� +� +6n +BMergeV2Checkpoints/checkpoint_prefixes@__inference__traced_save_80( +� +� +0 +� +� +6Z +.MergeV2Checkpoints@__inference__traced_save_80( +� +� +0 +� +� +6P +$NoOp@__inference__traced_restore_101( +� +� +0 +� +� +6f +:RestoreV2/shape_and_slices@__inference__traced_restore_101( +� +� +0 +� +� +6b +6RestoreV2/tensor_names@__inference__traced_restore_101( +� +� +0 +� +� +6U +)RestoreV2@__inference__traced_restore_101( +� +� +0 +� +� +6h +<RestoreV2_1/shape_and_slices@__inference__traced_restore_101( +� +� +0 +� +� +6d +8RestoreV2_1/tensor_names@__inference__traced_restore_101( +� +� +0 +� +� +6W ++RestoreV2_1@__inference__traced_restore_101( +� +� +0 +� +� +6q +ESaveV2/child_variable/Read/ReadVariableOp@__inference__traced_save_80( +� +� +0 +� +� +6_ +3SaveV2/shape_and_slices@__inference__traced_save_80( +� +� +0 +� +� +6[ +/SaveV2/tensor_names@__inference__traced_save_80( +� +� +0 +� +� +6m +ASaveV2/variable_x/Read/ReadVariableOp@__inference__traced_save_80( +� +� +0 +� +� +6m +ASaveV2/variable_y/Read/ReadVariableOp@__inference__traced_save_80( +� +� +0 +� +� +6N +"SaveV2@__inference__traced_save_80( +� +� +0 +� +� +6X +,SaveV2_1/Const_1@__inference__traced_save_80( +� +� +0 +� +� +6a +5SaveV2_1/shape_and_slices@__inference__traced_save_80( +� +� +0 +� +� +6] +1SaveV2_1/tensor_names@__inference__traced_save_80( +� +� +0 +� +� +6P +$SaveV2_1@__inference__traced_save_80( +� +� +0 +� +� +6] +1ShardedFilename/shard@__inference__traced_save_80( +� +� +0 +� +� +6W ++ShardedFilename@__inference__traced_save_80( +� +� +0 +� +� +6_ +3ShardedFilename_1/shard@__inference__traced_save_80( +� +� +0 +� +� +6Y +-ShardedFilename_1@__inference__traced_save_80( +� +� +0 +� +� +6k +?StatefulPartitionedCall/args_2@__inference_signature_wrapper_45( +� +� +0 +� +� +6k +?StatefulPartitionedCall/args_3@__inference_signature_wrapper_45( +� +� +0 +� +� +6k +?StatefulPartitionedCall/args_4@__inference_signature_wrapper_45( +� +� +0 +� +� +6k +?StatefulPartitionedCall/args_5@__inference_signature_wrapper_45( +� +� +0 +� +� +6d +8StatefulPartitionedCall@__inference_signature_wrapper_45( +� +� +0 +� +� +6[ +/StringJoin/inputs_1@__inference__traced_save_80( +� +� +0 +� +� +6R +&StringJoin@__inference__traced_save_80( +� +� +0 +� +� +6C +a@__inference_compute_34' +T +� +0 +� +� +6N +"a@__inference_signature_wrapper_45( +� +� +0 +� +� +6y +2add/ReadVariableOp/resource@__inference_compute_34C +� +� + +� +� +� +� +� +� +�p +)add/ReadVariableOp@__inference_compute_34C +� +� + +� +� +� +� +� +� +�c +add@__inference_compute_34E +� +� + +� +� +� +� +� +� +�{ +4add_1/ReadVariableOp/resource@__inference_compute_34C +� +� + +� +� +� +� +� +� +�r ++add_1/ReadVariableOp@__inference_compute_34C +� +� + +� +� +� +� +� +� +�e +add_1@__inference_compute_34E +� +� + +� +� +� +� +� +� +�g +add_2/y@__inference_compute_34E +� +� + +� +� +� +� +� +� +�e +add_2@__inference_compute_34E +� +� + +� +� +� +� +� +� +�C +b@__inference_compute_34' +T +� +0 +� +� +6N +"b@__inference_signature_wrapper_45( +� +� +0 +� +� +6W ++file_prefix@__inference__traced_restore_101( +� +� +0 +� +� +6S +'file_prefix@__inference__traced_save_80( +� +� +0 +� +� +6c +mul@__inference_compute_34E +� +� + +� +� +� +� +� +� +�R +&num_shards@__inference__traced_save_80( +� +� +0 +� +� +6} +6truediv/ReadVariableOp/resource@__inference_compute_34C +� +� + +� +� +� +� +� +� +�t +-truediv/ReadVariableOp@__inference_compute_34C +� +� + +� +� +� +� +� +� +�g +truediv@__inference_compute_34E +� +� + +� +� +� +� +� +� +� \ No newline at end of file diff --git a/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/saved_model.pb b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/saved_model.pb new file mode 100644 index 00000000000..c7638afb596 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/saved_model.pb differ diff --git a/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..307f83f85c4 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.index b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.index new file mode 100644 index 00000000000..e025dc724ea Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.index differ diff --git a/tensorflow/cc/saved_model/testdata/generate_saved_models.py b/tensorflow/cc/saved_model/testdata/generate_saved_models.py new file mode 100644 index 00000000000..5f39ae0651d --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/generate_saved_models.py @@ -0,0 +1,97 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Standalone utility to generate some test saved models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl import app + +from tensorflow.python.compat import v2_compat +from tensorflow.python.eager import def_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec +from tensorflow.python.module import module +from tensorflow.python.ops import variables +from tensorflow.python.saved_model import save_options +from tensorflow.python.saved_model import saved_model + + +class VarsAndArithmeticObjectGraph(module.Module): + """Three vars (one in a sub-module) and compute method.""" + + def __init__(self): + self.x = variables.Variable(1.0, name="variable_x") + self.y = variables.Variable(2.0, name="variable_y") + self.child = module.Module() + self.child.z = variables.Variable(3.0, name="child_variable") + self.child.c = ops.convert_to_tensor(5.0) + + @def_function.function(input_signature=[ + tensor_spec.TensorSpec((), dtypes.float32), + tensor_spec.TensorSpec((), dtypes.float32) + ]) + def compute(self, a, b): + return (a + self.x) * (b + self.y) / (self.child.z) + self.child.c + + +class ReferencesParent(module.Module): + + def __init__(self, parent): + super(ReferencesParent, self).__init__() + self.parent = parent + self.my_variable = variables.Variable(3., name="MyVariable") + + +# Creates a cyclic object graph. +class CyclicModule(module.Module): + + def __init__(self): + super(CyclicModule, self).__init__() + self.child = ReferencesParent(self) + + +MODULE_CTORS = { + "VarsAndArithmeticObjectGraph": VarsAndArithmeticObjectGraph, + "CyclicModule": CyclicModule, +} + + +def main(args): + if len(args) != 3: + print("Expected: {export_path} {ModuleName}") + print("Allowed ModuleNames:", MODULE_CTORS.keys()) + return 1 + + _, export_path, module_name = args + module_ctor = MODULE_CTORS.get(module_name) + if not module_ctor: + print("Expected ModuleName to be one of:", MODULE_CTORS.keys()) + return 2 + os.makedirs(export_path) + + tf_module = module_ctor() + options = save_options.SaveOptions(save_debug_info=True) + saved_model.save(tf_module, export_path, options=options) + + +if __name__ == "__main__": + v2_compat.enable_v2_behavior() + app.run(main) diff --git a/tensorflow/compiler/mlir/python/mlir.i b/tensorflow/compiler/mlir/python/mlir.i index ba5bfb98948..2ecea47b3d3 100644 --- a/tensorflow/compiler/mlir/python/mlir.i +++ b/tensorflow/compiler/mlir/python/mlir.i @@ -83,28 +83,22 @@ string ExperimentalConvertSavedModelToMlir( const string &exported_names_str, bool show_debug_info, TF_Status* status) { - // Load the saved model into a SavedModelBundle. + // Load the saved model into a SavedModelV2Bundle. - // TODO(silvasean): Add support for tags, if needed. - // The default "serve" tag seems to be enough. - std::unordered_set<string> tags; - tags.insert("serve"); - SessionOptions session_options; - RunOptions run_options; - tensorflow::SavedModelBundle bundle; - auto load_status = LoadSavedModel(session_options, run_options, - saved_model_path, tags, &bundle); + tensorflow::SavedModelV2Bundle bundle; + auto load_status = tensorflow::SavedModelV2Bundle::Load( + saved_model_path, &bundle); if (!load_status.ok()) { Set_TF_Status_from_Status(status, load_status); return "// error"; } - // Convert the SavedModelBundle to an MLIR module. + // Convert the SavedModelV2Bundle to an MLIR module. std::vector<string> exported_names = absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); mlir::MLIRContext context; - auto module_or = ConvertSavedModelToMlir(bundle, &context, + auto module_or = ConvertSavedModelToMlir(&bundle, &context, absl::Span<std::string>(exported_names)); if (!module_or.status().ok()) { Set_TF_Status_from_Status(status, module_or.status()); diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index b1f41b4016e..683c66eced3 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -308,7 +308,7 @@ cc_library( ":mlir_roundtrip_flags", ":tensorflow", ":tensorflow_passes", - "//tensorflow/cc/saved_model:loader_lite", + "//tensorflow/cc/saved_model:bundle_v2", "//tensorflow/compiler/jit:shape_inference_helpers", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/tf2xla:functionalize_control_flow", diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 77d3fc4bbca..e735dfa1a4c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -1815,7 +1815,7 @@ class SavedModelImporter : public ImporterBase { // Main entry point: converts all functions in the given meta graph to an MLIR // Module. static StatusOr<mlir::OwningModuleRef> Convert( - const SavedModelBundle& saved_model, mlir::MLIRContext* context, + SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span<std::string> exported_names, bool add_default_attributes); private: @@ -2015,29 +2015,45 @@ llvm::StringRef ObjectNames::SaveString(const std::string& s) const { return llvm::StringRef(*saved_strings_.insert(s).first); } -StatusOr<Tensor> GetTensorFromSession(Session* session, std::string name) { - std::vector<Tensor> outputs; - TF_RETURN_IF_ERROR(session->Run(/*inputs=*/{}, /*output_tensor_names=*/{name}, - /*target_node_names=*/{}, &outputs)); - return outputs[0]; -} - -// Variable ops return resource types, but we want to read their contents. -// We need to find a "ReadVariableOp" that reads a given variable to get out a -// tensor value. These seem to always be present in the GraphDef's main graph. -// TODO(silvasean): Find a better way to do this. -StatusOr<Tensor> ReadVariableFromSession(const SavedModelBundle& saved_model, - std::string variable_name) { - const GraphDef& graph_def = saved_model.meta_graph_def.graph_def(); - // TODO(silvasean): Don't do linear search. - for (const NodeDef& node : graph_def.node()) { - if (node.op() == "ReadVariableOp" && node.input_size() == 1 && - node.input(0) == variable_name) { - return GetTensorFromSession(saved_model.session.get(), node.name()); +// Extracts a TensorProto for a Const op from a GraphDef, given an op_name. +// Returns nullptr on not found or other mismatch. +// This returns a pointer to the actual node within the graph_def so as to +// avoid expensive copies. +const TensorProto* ExtractConstTensorFromGraph(const GraphDef& graph_def, + const std::string& op_name) { + const NodeDef* match_node = nullptr; + for (const auto& node : graph_def.node()) { + if (node.name() == op_name) { + match_node = &node; } } - return errors::InvalidArgument("Could not find ReadVariableOp reading '", - variable_name, "'"); + + if (!match_node) { + return nullptr; + } + + auto value_it = match_node->attr().find("value"); + if (value_it == match_node->attr().end()) { + return nullptr; + } + + if (!value_it->second.has_tensor()) { + return nullptr; + } + + return &value_it->second.tensor(); +} + +const TrackableObjectGraph::TrackableObject::SerializedTensor* +FindSerializedTensorInTrackable( + const TrackableObjectGraph::TrackableObject& trackable_object, + StringPiece name) { + for (const auto& maybe_serialized_tensor : trackable_object.attributes()) { + if (maybe_serialized_tensor.name() == name) { + return &maybe_serialized_tensor; + } + } + return nullptr; } Status DiagnoseMultipleConcreteFunctions(const SavedObjectGraph& object_graph, @@ -2237,9 +2253,22 @@ Status CreateSavedModelIR( const ObjectNames& object_names, mlir::ModuleOp module, const SavedObjectGraph& object_graph, const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name, - const SavedModelBundle& saved_model) { + SavedModelV2Bundle* saved_model) { mlir::OpBuilder builder(module.getBodyRegion()); mlir::SymbolTable symbol_table(module); + + // Create a side data-structure, indexed by the object_graph node_id to + // a TrackableObject that is restorable. + absl::flat_hash_map<int, const TrackableObjectGraph::TrackableObject*> + restored_objects; + TF_RETURN_IF_ERROR(saved_model->VisitObjectsToRestore( + [&](int saved_node_id, + const TrackableObjectGraph::TrackableObject& trackable_object) { + restored_objects.insert( + std::make_pair(saved_node_id, &trackable_object)); + return Status::OK(); + })); + for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) { const SavedObject& object = object_graph.nodes(node_id); // For correctness, we cannot import functions that don't have exported @@ -2339,8 +2368,27 @@ Status CreateSavedModelIR( } } else if (object.kind_case() == SavedObject::kVariable) { const SavedVariable& variable = object.variable(); - TF_ASSIGN_OR_RETURN( - Tensor value, ReadVariableFromSession(saved_model, variable.name())); + // Find the trackable in the side data structure. + auto variable_trackable_it = restored_objects.find(node_id); + if (variable_trackable_it == restored_objects.end()) { + return errors::FailedPrecondition("Could not restore saved variable: ", + variable.name()); + } + const auto* serialized_tensor_attr = FindSerializedTensorInTrackable( + *variable_trackable_it->second, "VARIABLE_VALUE"); + if (!serialized_tensor_attr) { + return errors::FailedPrecondition( + "Could not find serialized tensor for saved variable: ", + variable.name()); + } + const auto& checkpoint_key = serialized_tensor_attr->checkpoint_key(); + + // Load it from the reader. + Tensor value; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + saved_model->variable_reader()->Lookup(checkpoint_key, &value), + "Could not read checkpoint key from variables bundle: ", + checkpoint_key); TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder)); // A variable can have a partially known type, such as tensor<?x27x?xf32>, // even if the initializer is a specific static shape. @@ -2358,10 +2406,15 @@ Status CreateSavedModelIR( builder.getStrArrayAttr(object_names.GetExportedNames(node_id))); } else if (object.kind_case() == SavedObject::kConstant) { const SavedConstant& constant = object.constant(); - TF_ASSIGN_OR_RETURN(Tensor value, - GetTensorFromSession(saved_model.session.get(), - constant.operation())); - TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder)); + const TensorProto* value = ExtractConstTensorFromGraph( + saved_model->meta_graph_def().graph_def(), constant.operation()); + if (!value) { + return errors::FailedPrecondition( + "Unable to find const node referenced in object graph: ", + constant.operation()); + } + TF_ASSIGN_OR_RETURN(auto value_attr, + ConvertTensorProto(*value, &builder)); auto op = builder.create<mlir::tf_saved_model::GlobalTensorOp>( builder.getUnknownLoc(), builder.getStringAttr(object_names.GetSymbolTableName(node_id)), @@ -2379,18 +2432,18 @@ Status CreateSavedModelIR( } StatusOr<mlir::OwningModuleRef> SavedModelImporter::Convert( - const SavedModelBundle& saved_model, mlir::MLIRContext* context, + SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span<std::string> exported_names, bool add_default_attributes) { GraphDebugInfo dummy_debug_info; const GraphDebugInfo& debug_info = - saved_model.debug_info ? *saved_model.debug_info : dummy_debug_info; + saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info; GraphImportConfig specs; mlir::OwningModuleRef module = mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); std::unordered_map<std::string, std::string> tf_name_to_mlir_name; - const auto& graphdef = saved_model.meta_graph_def.graph_def(); + const auto& graphdef = saved_model->meta_graph_def().graph_def(); GraphConstructorOptions options; options.allow_internal_ops = true; options.add_default_attributes = add_default_attributes; @@ -2413,11 +2466,11 @@ StatusOr<mlir::OwningModuleRef> SavedModelImporter::Convert( TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name)); } - if (!saved_model.meta_graph_def.has_object_graph_def()) { + if (!saved_model->meta_graph_def().has_object_graph_def()) { return errors::InvalidArgument( "SavedModel does not have an object graph. Please use TF2."); } - auto& object_graph = saved_model.meta_graph_def.object_graph_def(); + auto& object_graph = saved_model->meta_graph_def().object_graph_def(); ObjectNames object_names(object_graph, exported_names); // Clean up a couple func's that always seem to be present when importing a @@ -2484,7 +2537,7 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir( } StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir( - const SavedModelBundle& saved_model, mlir::MLIRContext* context, + SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span<std::string> exported_names, bool add_default_attributes) { return SavedModelImporter::Convert(saved_model, context, exported_names, add_default_attributes); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index 4f9b47795e2..d4b17073bd5 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -20,7 +20,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir -#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" @@ -47,7 +47,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir( // Given a SavedModel, returns a MLIR module containing the functions, expressed // with tf_executor dialect. stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir( - const SavedModelBundle& saved_model, mlir::MLIRContext* context, + SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span<std::string> exported_names, bool add_default_attributes = true); // Serialize a MLIR module to a string. diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index ff131cf2ec3..cd422a66bc5 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -109,20 +109,16 @@ mlir::OwningModuleRef SavedModelToMlirImport( absl::string_view saved_model_dir, const std::unordered_set<std::string>& tags, absl::Span<std::string> exported_names, mlir::MLIRContext* context) { - SessionOptions session_options; - RunOptions run_options; - tensorflow::SavedModelBundle bundle; - auto load_status = LoadSavedModel( - session_options, run_options, - std::string(saved_model_dir.data(), saved_model_dir.length()), tags, - &bundle); + tensorflow::SavedModelV2Bundle bundle; + auto load_status = tensorflow::SavedModelV2Bundle::Load( + std::string(saved_model_dir.data(), saved_model_dir.length()), &bundle); if (!load_status.ok()) { LOG(ERROR) << "Failed to load saved model '" << saved_model_dir << "': " << load_status; return nullptr; } - auto module_or = ConvertSavedModelToMlir(bundle, context, exported_names); + auto module_or = ConvertSavedModelToMlir(&bundle, context, exported_names); if (!module_or.status().ok()) { LOG(ERROR) << "SavedModel import failed: " << module_or.status(); return nullptr; diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 7a82c5ead22..9d007935314 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -626,6 +626,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:device_util", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:session_options", "@llvm//:support", ], alwayslink = 1, diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index ff7fb2bbf6d..2e13921520a 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/public/session_options.h" namespace tensorflow {