From ae07e29a16a04c4127f6ce60d40de964b9bcc600 Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Mon, 25 Nov 2019 14:54:44 -0800
Subject: [PATCH] Load SavedModelV2 variables directly from the checkpoint.

* This fixes a correctness problem where we were relying on the variable's
  "name" attribute, which is merely advisory.
* Should match the load heuristics in Python by reading the
  TrackableObjectGraph from the checkpoint, re-associating it with the
  SavedObjectGraph and using that to restore the variables.
* Has a side-effect of eliminating the dependency of the importer on the CPU
  runtime and kernels, which should reduce the dependencies needed to compile
  it by tens of megabytes.

Open questions:

* I created a new bundle_v2.h because V2 is so substantially different from
  V1. Also, it has a different dependency surface area and I therefore wanted
  to use it as a different (lighter) library. Advise if you would like this
  organized differently. I tried to factor it so that a future C++ SavedModel
  loader/runner could be built on this as well (with some incremental work).
* There didn't seem to be a facility for generating the testdata SavedModels.
  I ended up just writing a small standalone python script that I can invoke
  from the command line to generate the new ones. It isn't wired up in any
  way, but it should be possible to do so at some point (I'm not familiar
  enough with the test infra on this side of the tree, so please advise).
* I could see factoring the SavedModelV2Bundle::RestoreObjects method into a
  dedicated facility for loading arbitrary checkpoints. Not sure what the
  roadmap here is, so I just kept it simple for now.

PiperOrigin-RevId: 282439157
Change-Id: I685df47ab621917eed270319cde220498f191b74
---
 tensorflow/cc/saved_model/BUILD                |  55 ++
 tensorflow/cc/saved_model/bundle_v2.cc         | 223 ++++
 tensorflow/cc/saved_model/bundle_v2.h          |  87 +++
 tensorflow/cc/saved_model/bundle_v2_test.cc    |  99 ++++
 tensorflow/cc/saved_model/constants.h          |   3 +
 .../debug/saved_model_debug_info.pb            | 235 ++++
 .../testdata/CyclicModule/saved_model.pb       | Bin 0 -> 7016 bytes
 .../variables/variables.data-00000-of-00001    | Bin 0 -> 148 bytes
 .../CyclicModule/variables/variables.index     | Bin 0 -> 205 bytes
 .../debug/saved_model_debug_info.pb            | 506 ++++++++++++++++++
 .../saved_model.pb                             | Bin 0 -> 12498 bytes
 .../variables/variables.data-00000-of-00001    | Bin 0 -> 268 bytes
 .../variables/variables.index                  | Bin 0 -> 284 bytes
 .../testdata/generate_saved_models.py          |  97 ++++
 tensorflow/compiler/mlir/python/mlir.i         |  18 +-
 tensorflow/compiler/mlir/tensorflow/BUILD      |   2 +-
 .../mlir/tensorflow/translate/import_model.cc  | 123 +++--
 .../mlir/tensorflow/translate/import_model.h   |   4 +-
 .../tensorflow/translate/tf_mlir_translate.cc  |  12 +-
 tensorflow/compiler/tf2xla/BUILD               |   1 +
 .../compiler/tf2xla/mlir_bridge_pass.cc        |   1 +
 21 files changed, 1408 insertions(+), 58 deletions(-)
 create mode 100644 tensorflow/cc/saved_model/bundle_v2.cc
 create mode 100644 tensorflow/cc/saved_model/bundle_v2.h
 create mode 100644 tensorflow/cc/saved_model/bundle_v2_test.cc
 create mode 100644 tensorflow/cc/saved_model/testdata/CyclicModule/debug/saved_model_debug_info.pb
 create mode 100644 tensorflow/cc/saved_model/testdata/CyclicModule/saved_model.pb
 create mode 100644 tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.data-00000-of-00001
 create mode 100644 tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.index
 create mode 100644 tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/debug/saved_model_debug_info.pb
create mode 100644 tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/saved_model.pb create mode 100644 tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.data-00000-of-00001 create mode 100644 tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.index create mode 100644 tensorflow/cc/saved_model/testdata/generate_saved_models.py diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 31078e92ce5..e9b69501c8a 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -118,6 +118,39 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "bundle_v2", + srcs = ["bundle_v2.cc"], + hdrs = ["bundle_v2.h"], + deps = [ + ":constants", + "@com_google_absl//absl/container:flat_hash_set", + "//tensorflow/core/platform:env", + "//tensorflow/core/platform:strcat", + "//tensorflow/core/util/tensor_bundle", + ] + if_not_mobile([ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ]), +) + +tf_cc_test( + name = "bundle_v2_test", + srcs = ["bundle_v2_test.cc"], + data = [ + ":saved_model_half_plus_two", + ], + linkstatic = 1, + deps = [ + ":bundle_v2", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/platform:test", + ], +) + tf_cc_test( name = "saved_model_bundle_test", srcs = ["saved_model_bundle_test.cc"], @@ -160,6 +193,26 @@ tf_cc_test( ], ) +# A subset of the TF2 saved models can be generated with this tool. +py_binary( + name = "testdata/generate_saved_models", + srcs = ["testdata/generate_saved_models.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_spec", + "//tensorflow/python:variables", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/module", + "//tensorflow/python/saved_model", + "//tensorflow/python/saved_model:save_options", + "@absl_py//absl:app", + ], +) + # TODO(b/32673259): add a test to continuously validate these files. filegroup( name = "saved_model_half_plus_two", @@ -169,5 +222,7 @@ filegroup( "testdata/half_plus_two/**", "testdata/half_plus_two_v2/**", "testdata/x_plus_y_v2_debuginfo/**", + "testdata/CyclicModule/**", + "testdata/VarsAndArithmeticObjectGraph/**", ]), ) diff --git a/tensorflow/cc/saved_model/bundle_v2.cc b/tensorflow/cc/saved_model/bundle_v2.cc new file mode 100644 index 00000000000..b6daece84ab --- /dev/null +++ b/tensorflow/cc/saved_model/bundle_v2.cc @@ -0,0 +1,223 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/cc/saved_model/bundle_v2.h"
+
+#include "tensorflow/cc/saved_model/constants.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/strcat.h"
+#include "tensorflow/core/protobuf/saved_model.pb.h"
+#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+Status ReadSavedModelProto(const string& export_dir,
+                           SavedModel* saved_model_proto) {
+  LOG(INFO) << "Reading SavedModel from: " << export_dir;
+
+  const string saved_model_pb_path =
+      io::JoinPath(export_dir, kSavedModelFilenamePb);
+  if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
+    return ReadBinaryProto(Env::Default(), saved_model_pb_path,
+                           saved_model_proto);
+  }
+  const string saved_model_pbtxt_path =
+      io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
+  if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
+    return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
+                         saved_model_proto);
+  }
+  return Status(error::Code::NOT_FOUND,
+                "Could not find SavedModel .pb or .pbtxt at supplied export "
+                "directory path: " +
+                    export_dir);
+}
+
+Status ReadSavedModelDebugInfoIfPresent(
+    const string& export_dir,
+    std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
+  LOG(INFO) << "Reading SavedModel debug info (if present) from: "
+            << export_dir;
+
+  const string debug_info_pb_path =
+      io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
+  if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
+    GraphDebugInfo debug_info;
+    TF_RETURN_IF_ERROR(
+        ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
+    *debug_info_proto =
+        absl::make_unique<GraphDebugInfo>(std::move(debug_info));
+  }
+  return Status::OK();
+}
+
+Status ReadCheckpointObjectGraph(BundleReader* bundle_reader,
+                                 TrackableObjectGraph* object_graph) {
+  Tensor object_graph_tensor;
+  TF_RETURN_WITH_CONTEXT_IF_ERROR(
+      bundle_reader->Lookup(kObjectGraphProtoKey, &object_graph_tensor),
+      "SavedModel checkpoint does not contain object graph.");
+  if (object_graph_tensor.dtype() != DT_STRING ||
+      object_graph_tensor.dims() != 0 ||
+      object_graph_tensor.NumElements() != 1) {
+    return Status(
+        error::Code::FAILED_PRECONDITION,
+        "SavedModel checkpoint object graph was not the correct type.");
+  }
+
+  const tstring* object_graph_string = reinterpret_cast<const tstring*>(
+      object_graph_tensor.tensor_data().data());
+  if (!object_graph->ParseFromString(*object_graph_string)) {
+    return Status(
+        error::Code::FAILED_PRECONDITION,
+        "SavedModel checkpoint object graph could not be deserialized.");
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+Status SavedModelV2Bundle::Load(const std::string& export_dir,
+                                SavedModelV2Bundle* const bundle) {
+  SavedModel saved_model_proto;
+  TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto));
+
+  // Load MetaGraphDef.
+  // In version 2 SavedModels, there is only one MetaGraphDef.
+  if (saved_model_proto.meta_graphs_size() != 1) {
+    return Status(
+        error::Code::INVALID_ARGUMENT,
+        strings::StrCat(
+            "SavedModelV2 should have exactly one MetaGraphDef but actually ",
+            "contains ", saved_model_proto.meta_graphs_size()));
+  }
+  bundle->meta_graph_def_ =
+      std::move(*saved_model_proto.mutable_meta_graphs(0));
+
+  // Load GraphDebugInfo.
+  TF_RETURN_IF_ERROR(
+      ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_));
+
+  // Load the variables checkpoint reader.
+  const std::string variables_prefix = io::JoinPath(
+      export_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename);
+  bundle->variable_reader_.reset(
+      new BundleReader(Env::Default(), variables_prefix));
+  TF_RETURN_WITH_CONTEXT_IF_ERROR(
+      bundle->variable_reader_->status(),
+      "Unable to load SavedModel variables checkpoint from ", variables_prefix);
+
+  // Deserialize the object graph proto from the tensor bundle.
+  TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph(
+      bundle->variable_reader_.get(), &bundle->trackable_object_graph_));
+  return Status::OK();
+}
+
+Status SavedModelV2Bundle::VisitObjectsToRestore(
+    RestoreObjectsCallback callback) {
+  if (saved_object_graph().nodes_size() == 0 ||
+      trackable_object_graph().nodes_size() == 0) {
+    return Status::OK();
+  }
+
+  // Start from root nodes of both the SavedObjectGraph and TrackableObjectGraph
+  // and descend to leaves. Note that the TrackableObjectGraph can have cycles
+  // (as can the SavedObjectGraph).
+  // This is detected and cycle edges are skipped.
+  const SavedObject* root_saved_object = &saved_object_graph().nodes(0);
+  const TrackableObjectGraph::TrackableObject* root_trackable_object =
+      &trackable_object_graph().nodes(0);
+  absl::flat_hash_set<int> trackable_node_ids;
+  return RecurseObjectsToRestore(root_saved_object, 0, root_trackable_object,
+                                 std::string(), &trackable_node_ids,
+                                 std::move(callback));
+}
+
+Status SavedModelV2Bundle::RecurseObjectsToRestore(
+    const SavedObject* saved_object, int saved_object_node_id,
+    const TrackableObjectGraph::TrackableObject* trackable_object,
+    std::string object_name, absl::flat_hash_set<int>* seen_trackable_node_ids,
+    RestoreObjectsCallback callback) {
+  // Callback if any attributes or slot variables.
+  // Note that the root is always excluded from the search (it can never
+  // be a restorable object). This matches some logic on the Python side.
+  if (saved_object_node_id != 0 &&
+      (trackable_object->attributes_size() > 0 ||
+       trackable_object->slot_variables_size() > 0)) {
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(
+        callback(saved_object_node_id, *trackable_object), "Unable to restore ",
+        object_name);
+  }
+
+  for (const auto& trackable_child_ref : trackable_object->children()) {
+    const auto& local_name = trackable_child_ref.local_name();
+
+    // Compute the full child name.
+    std::string child_name;
+    if (object_name.empty()) {
+      child_name = local_name;
+    } else {
+      child_name = strings::StrCat(object_name, ".", local_name);
+    }
+
+    // Descend down the trackable graph.
+    int trackable_child_node_id = trackable_child_ref.node_id();
+    if (!seen_trackable_node_ids->insert(trackable_child_node_id).second) {
+      // Cycle or duplicate detected - ignore this branch.
+      continue;
+    }
+    if (trackable_child_node_id < 0 ||
+        trackable_child_node_id >= trackable_object_graph().nodes_size()) {
+      return Status(
+          errors::Code::FAILED_PRECONDITION,
+          strings::StrCat("Illegal trackable child node id for ", child_name));
+    }
+    const auto* trackable_child =
+        &trackable_object_graph().nodes(trackable_child_node_id);
+
+    // Descend down the saved object graph.
+    int saved_child_node_id = -1;
+    const SavedObject* saved_child = nullptr;
+    for (const auto& saved_child_ref : saved_object->children()) {
+      if (saved_child_ref.local_name() == local_name) {
+        // Found.
+        saved_child_node_id = saved_child_ref.node_id();
+        if (saved_child_node_id >= 0 &&
+            saved_child_node_id < saved_object_graph().nodes_size()) {
+          saved_child = &saved_object_graph().nodes(saved_child_node_id);
+        }
+        break;
+      }
+    }
+
+    if (!saved_child) {
+      return Status(
+          errors::Code::FAILED_PRECONDITION,
+          strings::StrCat("Could not find saved object to restore for ",
+                          child_name));
+    }
+
+    TF_RETURN_IF_ERROR(RecurseObjectsToRestore(
+        saved_child, saved_child_node_id, trackable_child, child_name,
+        seen_trackable_node_ids, callback));
+  }
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/cc/saved_model/bundle_v2.h b/tensorflow/cc/saved_model/bundle_v2.h
new file mode 100644
index 00000000000..d376b7b4c88
--- /dev/null
+++ b/tensorflow/cc/saved_model/bundle_v2.h
@@ -0,0 +1,87 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Helpers for loading the persistent representation of a SavedModelV2.
+// Please note that this is depended on by code that does not make use of
+// the full runtime and its dependencies should be restricted.
+
+#ifndef TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_
+#define TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_
+
+#include <string>
+
+#include "absl/container/flat_hash_set.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
+#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
+#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
+
+namespace tensorflow {
+
+/// Represents a version 2 SavedModel that is loaded from storage (but not yet
+/// loaded into an executable in-memory representation).
+class SavedModelV2Bundle {
+ public:
+  using RestoreObjectsCallback =
+      std::function<Status(int, const TrackableObjectGraph::TrackableObject&)>;
+
+  /// Loads persistent representations for a SavedModelV2 from the specified
+  /// export directory.
+  static Status Load(const std::string& export_dir, SavedModelV2Bundle* bundle);
+
+  /// MetaGraphDef from the loaded SavedModel.
+  MetaGraphDef& meta_graph_def() { return meta_graph_def_; }
+
+  /// SavedObjectGraph from the MetaGraphDef.
+  const SavedObjectGraph& saved_object_graph() {
+    return meta_graph_def().object_graph_def();
+  }
+
+  /// TrackableObjectGraph loaded from the variable_reader() checkpoint.
+  TrackableObjectGraph& trackable_object_graph() {
+    return trackable_object_graph_;
+  }
+
+  /// BundleReader for accessing the variables bundle.
+  BundleReader* variable_reader() { return variable_reader_.get(); }
+
+  /// The GraphDebugInfo (or nullptr if none).
+  GraphDebugInfo* debug_info() { return debug_info_.get(); }
+
+  /// Restores objects, invoking the callback with the node id in the
+  /// saved_object_graph() and the corresponding TrackableObject from the
+  /// trackable_object_graph().
+  /// The callback may use the variable_reader() but must not modify the
+  /// underlying saved_object_graph().
+  Status VisitObjectsToRestore(RestoreObjectsCallback callback);
+
+ private:
+  Status RecurseObjectsToRestore(
+      const SavedObject* saved_object, int saved_object_node_id,
+      const TrackableObjectGraph::TrackableObject* trackable_object,
+      std::string object_name,
+      absl::flat_hash_set<int>* seen_trackable_node_ids,
+      RestoreObjectsCallback callback);
+
+  MetaGraphDef meta_graph_def_;
+  TrackableObjectGraph trackable_object_graph_;
+  std::unique_ptr<BundleReader> variable_reader_;
+  std::unique_ptr<GraphDebugInfo> debug_info_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_
diff --git a/tensorflow/cc/saved_model/bundle_v2_test.cc b/tensorflow/cc/saved_model/bundle_v2_test.cc
new file mode 100644
index 00000000000..81aeed90968
--- /dev/null
+++ b/tensorflow/cc/saved_model/bundle_v2_test.cc
@@ -0,0 +1,99 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/saved_model/bundle_v2.h"
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+constexpr char kTestData[] = "cc/saved_model/testdata";
+
+class BundleV2Test : public ::testing::Test {
+ protected:
+  BundleV2Test() {}
+
+  void RestoreVarsAndVerify(SavedModelV2Bundle* bundle,
+                            std::vector<std::string> expected_names) {
+    // Collect saved_node_id, full_name, checkpoint_key into a vector.
+    using RestoredVarType = std::tuple<int, std::string, std::string>;
+    std::vector<RestoredVarType> restored_vars;
+    TF_ASSERT_OK(bundle->VisitObjectsToRestore(
+        [&](int saved_node_id,
+            const TrackableObjectGraph::TrackableObject& trackable_object)
+            -> Status {
+          for (const auto& attr : trackable_object.attributes()) {
+            if (attr.name() == "VARIABLE_VALUE") {
+              restored_vars.emplace_back(saved_node_id, attr.full_name(),
+                                         attr.checkpoint_key());
+            }
+          }
+          return Status::OK();
+        }));
+
+    // Should be one of each var name restored.
+    for (const auto& expected_name : expected_names) {
+      EXPECT_EQ(1, std::count_if(restored_vars.begin(), restored_vars.end(),
+                                 [&](RestoredVarType t) {
+                                   return std::get<1>(t) == expected_name;
+                                 }));
+    }
+
+    for (const auto& restored_var : restored_vars) {
+      // Each restored var should match a SavedObjectGraph node with the same
+      // variable name.
+      const auto& saved_node =
+          bundle->saved_object_graph().nodes(std::get<0>(restored_var));
+      EXPECT_EQ(std::get<1>(restored_var), saved_node.variable().name());
+
+      // And should be able to load it from the tensor_bundle.
+      Tensor value;
+      TF_ASSERT_OK(
+          bundle->variable_reader()->Lookup(std::get<2>(restored_var), &value));
+    }
+  }
+};
+
+TEST_F(BundleV2Test, LoadsVarsAndArithmeticObjectGraph) {
+  const string export_dir = io::JoinPath(
+      testing::TensorFlowSrcRoot(), kTestData, "VarsAndArithmeticObjectGraph");
+
+  SavedModelV2Bundle bundle;
+  TF_ASSERT_OK(SavedModelV2Bundle::Load(export_dir, &bundle));
+
+  // Ensure that there are nodes in the trackable_object_graph.
+  EXPECT_GT(bundle.trackable_object_graph().nodes_size(), 0);
+
+  RestoreVarsAndVerify(&bundle, {"variable_x", "variable_y", "child_variable"});
+}
+
+TEST_F(BundleV2Test, LoadsCyclicModule) {
+  const string export_dir =
+      io::JoinPath(testing::TensorFlowSrcRoot(), kTestData, "CyclicModule");
+
+  SavedModelV2Bundle bundle;
+  TF_ASSERT_OK(SavedModelV2Bundle::Load(export_dir, &bundle));
+
+  // Ensure that there are nodes in the trackable_object_graph.
+  EXPECT_GT(bundle.trackable_object_graph().nodes_size(), 0);
+
+  RestoreVarsAndVerify(&bundle, {"MyVariable"});
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h
index 6f00dc324bd..cdc87e3dcd9 100644
--- a/tensorflow/cc/saved_model/constants.h
+++ b/tensorflow/cc/saved_model/constants.h
@@ -50,6 +50,9 @@ constexpr char kSavedModelVariablesFilename[] = "variables";
 constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op";
 constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op";
 
+// Key in the TensorBundle for the object graph proto.
+constexpr char kObjectGraphProtoKey[] = "_CHECKPOINTABLE_OBJECT_GRAPH";
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
diff --git a/tensorflow/cc/saved_model/testdata/CyclicModule/debug/saved_model_debug_info.pb b/tensorflow/cc/saved_model/testdata/CyclicModule/debug/saved_model_debug_info.pb
new file mode 100644
index 00000000000..9e937552355
Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/CyclicModule/debug/saved_model_debug_info.pb differ
diff --git a/tensorflow/cc/saved_model/testdata/CyclicModule/saved_model.pb b/tensorflow/cc/saved_model/testdata/CyclicModule/saved_model.pb
new file mode 100644
index 0000000000000000000000000000000000000000..7a0494ee132d787fecb576c2aa9a5777c3992cbb
Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/CyclicModule/saved_model.pb differ
diff --git a/tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.data-00000-of-00001
new file mode 100644
Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.data-00000-of-00001 differ
diff --git a/tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.index b/tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.index
new file mode 100644
index 0000000000000000000000000000000000000000..3fc2cec2b60b94e2da079b69e8d035962dd297cd
Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.index differ
diff --git a/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/debug/saved_model_debug_info.pb b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/debug/saved_model_debug_info.pb
new file mode 100644
Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/debug/saved_model_debug_info.pb differ
diff --git a/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/saved_model.pb b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/saved_model.pb
new file mode 100644
Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/saved_model.pb differ
diff --git a/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.data-00000-of-00001
new file mode 100644
index 0000000000000000000000000000000000000000..307f83f85c4ef7c106c87961ac09b68989f8ba80
Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.data-00000-of-00001 differ
diff --git a/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.index b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.index
new file mode 100644
index 0000000000000000000000000000000000000000..e025dc724eac6ac99bfe951ad82a2ad226571e73
Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.index differ
diff --git a/tensorflow/compiler/mlir/python/mlir.i b/tensorflow/compiler/mlir/python/mlir.i
--- a/tensorflow/compiler/mlir/python/mlir.i
+++ b/tensorflow/compiler/mlir/python/mlir.i
-  std::unordered_set<string> tags;
-  tags.insert("serve");
-  SessionOptions session_options;
-  RunOptions run_options;
-  tensorflow::SavedModelBundle bundle;
-  auto load_status = LoadSavedModel(session_options, run_options,
-                                    saved_model_path, tags, &bundle);
+  tensorflow::SavedModelV2Bundle bundle;
+  auto load_status = tensorflow::SavedModelV2Bundle::Load(
+      saved_model_path, &bundle);
   if (!load_status.ok()) {
     Set_TF_Status_from_Status(status, load_status);
     return "// error";
   }
 
-  // Convert the SavedModelBundle to an MLIR module.
+  // Convert the SavedModelV2Bundle to an MLIR module.
   std::vector<std::string> exported_names =
       absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
   mlir::MLIRContext context;
-  auto module_or = ConvertSavedModelToMlir(bundle, &context,
+  auto module_or = ConvertSavedModelToMlir(&bundle, &context,
       absl::Span<std::string>(exported_names));
   if (!module_or.status().ok()) {
     Set_TF_Status_from_Status(status, module_or.status());
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index b1f41b4016e..683c66eced3 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -308,7 +308,7 @@ cc_library(
         ":mlir_roundtrip_flags",
         ":tensorflow",
         ":tensorflow_passes",
-        "//tensorflow/cc/saved_model:loader_lite",
+        "//tensorflow/cc/saved_model:bundle_v2",
        "//tensorflow/compiler/jit:shape_inference_helpers",
        "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
        "//tensorflow/compiler/tf2xla:functionalize_control_flow",
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 77d3fc4bbca..e735dfa1a4c 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -1815,7 +1815,7 @@ class SavedModelImporter : public ImporterBase {
   // Main entry point: converts all functions in the given meta graph to an MLIR
   // Module.
   static StatusOr<mlir::OwningModuleRef> Convert(
-      const SavedModelBundle& saved_model, mlir::MLIRContext* context,
+      SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
       absl::Span<std::string> exported_names, bool add_default_attributes);
 
  private:
@@ -2015,29 +2015,45 @@ llvm::StringRef ObjectNames::SaveString(const std::string& s) const {
   return llvm::StringRef(*saved_strings_.insert(s).first);
 }
 
-StatusOr<Tensor> GetTensorFromSession(Session* session, std::string name) {
-  std::vector<Tensor> outputs;
-  TF_RETURN_IF_ERROR(session->Run(/*inputs=*/{}, /*output_tensor_names=*/{name},
-                                  /*target_node_names=*/{}, &outputs));
-  return outputs[0];
-}
-
-// Variable ops return resource types, but we want to read their contents.
-// We need to find a "ReadVariableOp" that reads a given variable to get out a
-// tensor value. These seem to always be present in the GraphDef's main graph.
-// TODO(silvasean): Find a better way to do this.
-StatusOr<Tensor> ReadVariableFromSession(const SavedModelBundle& saved_model,
-                                         std::string variable_name) {
-  const GraphDef& graph_def = saved_model.meta_graph_def.graph_def();
-  // TODO(silvasean): Don't do linear search.
-  for (const NodeDef& node : graph_def.node()) {
-    if (node.op() == "ReadVariableOp" && node.input_size() == 1 &&
-        node.input(0) == variable_name) {
-      return GetTensorFromSession(saved_model.session.get(), node.name());
+// Extracts a TensorProto for a Const op from a GraphDef, given an op_name.
+// Returns nullptr on not found or other mismatch.
+// This returns a pointer to the actual node within the graph_def so as to
+// avoid expensive copies.
+const TensorProto* ExtractConstTensorFromGraph(const GraphDef& graph_def,
+                                               const std::string& op_name) {
+  const NodeDef* match_node = nullptr;
+  for (const auto& node : graph_def.node()) {
+    if (node.name() == op_name) {
+      match_node = &node;
     }
   }
-  return errors::InvalidArgument("Could not find ReadVariableOp reading '",
-                                 variable_name, "'");
+
+  if (!match_node) {
+    return nullptr;
+  }
+
+  auto value_it = match_node->attr().find("value");
+  if (value_it == match_node->attr().end()) {
+    return nullptr;
+  }
+
+  if (!value_it->second.has_tensor()) {
+    return nullptr;
+  }
+
+  return &value_it->second.tensor();
+}
+
+const TrackableObjectGraph::TrackableObject::SerializedTensor*
+FindSerializedTensorInTrackable(
+    const TrackableObjectGraph::TrackableObject& trackable_object,
+    StringPiece name) {
+  for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
+    if (maybe_serialized_tensor.name() == name) {
+      return &maybe_serialized_tensor;
+    }
+  }
+  return nullptr;
 }
 
 Status DiagnoseMultipleConcreteFunctions(const SavedObjectGraph& object_graph,
@@ -2237,9 +2253,22 @@ Status CreateSavedModelIR(
     const ObjectNames& object_names, mlir::ModuleOp module,
     const SavedObjectGraph& object_graph,
     const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name,
-    const SavedModelBundle& saved_model) {
+    SavedModelV2Bundle* saved_model) {
   mlir::OpBuilder builder(module.getBodyRegion());
   mlir::SymbolTable symbol_table(module);
+
+  // Create a side data-structure, indexed by the object_graph node_id to
+  // a TrackableObject that is restorable.
+  absl::flat_hash_map<int, const TrackableObjectGraph::TrackableObject*>
+      restored_objects;
+  TF_RETURN_IF_ERROR(saved_model->VisitObjectsToRestore(
+      [&](int saved_node_id,
+          const TrackableObjectGraph::TrackableObject& trackable_object) {
+        restored_objects.insert(
+            std::make_pair(saved_node_id, &trackable_object));
+        return Status::OK();
+      }));
+
   for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) {
     const SavedObject& object = object_graph.nodes(node_id);
     // For correctness, we cannot import functions that don't have exported
@@ -2339,8 +2368,27 @@ Status CreateSavedModelIR(
       }
     } else if (object.kind_case() == SavedObject::kVariable) {
       const SavedVariable& variable = object.variable();
-      TF_ASSIGN_OR_RETURN(
-          Tensor value, ReadVariableFromSession(saved_model, variable.name()));
+      // Find the trackable in the side data structure.
+      auto variable_trackable_it = restored_objects.find(node_id);
+      if (variable_trackable_it == restored_objects.end()) {
+        return errors::FailedPrecondition("Could not restore saved variable: ",
+                                          variable.name());
+      }
+      const auto* serialized_tensor_attr = FindSerializedTensorInTrackable(
+          *variable_trackable_it->second, "VARIABLE_VALUE");
+      if (!serialized_tensor_attr) {
+        return errors::FailedPrecondition(
+            "Could not find serialized tensor for saved variable: ",
+            variable.name());
+      }
+      const auto& checkpoint_key = serialized_tensor_attr->checkpoint_key();
+
+      // Load it from the reader.
+      Tensor value;
+      TF_RETURN_WITH_CONTEXT_IF_ERROR(
+          saved_model->variable_reader()->Lookup(checkpoint_key, &value),
+          "Could not read checkpoint key from variables bundle: ",
+          checkpoint_key);
       TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder));
       // A variable can have a partially known type, such as tensor<?x27x?xf32>,
       // even if the initializer is a specific static shape.
@@ -2358,10 +2406,15 @@ Status CreateSavedModelIR(
           builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
     } else if (object.kind_case() == SavedObject::kConstant) {
       const SavedConstant& constant = object.constant();
-      TF_ASSIGN_OR_RETURN(Tensor value,
-                          GetTensorFromSession(saved_model.session.get(),
-                                               constant.operation()));
-      TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder));
+      const TensorProto* value = ExtractConstTensorFromGraph(
+          saved_model->meta_graph_def().graph_def(), constant.operation());
+      if (!value) {
+        return errors::FailedPrecondition(
+            "Unable to find const node referenced in object graph: ",
+            constant.operation());
+      }
+      TF_ASSIGN_OR_RETURN(auto value_attr,
+                          ConvertTensorProto(*value, &builder));
       auto op = builder.create<mlir::tf_saved_model::GlobalTensorOp>(
           builder.getUnknownLoc(),
           builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
@@ -2379,18 +2432,18 @@
 }
 
 StatusOr<mlir::OwningModuleRef> SavedModelImporter::Convert(
-    const SavedModelBundle& saved_model, mlir::MLIRContext* context,
+    SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
     absl::Span<std::string> exported_names, bool add_default_attributes) {
   GraphDebugInfo dummy_debug_info;
   const GraphDebugInfo& debug_info =
-      saved_model.debug_info ? *saved_model.debug_info : dummy_debug_info;
+      saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;
 
   GraphImportConfig specs;
   mlir::OwningModuleRef module =
       mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
   std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
-  const auto& graphdef = saved_model.meta_graph_def.graph_def();
+  const auto& graphdef = saved_model->meta_graph_def().graph_def();
   GraphConstructorOptions options;
   options.allow_internal_ops = true;
   options.add_default_attributes = add_default_attributes;
@@ -2413,11 +2466,11 @@ StatusOr<mlir::OwningModuleRef> SavedModelImporter::Convert(
     TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name));
   }
 
-  if (!saved_model.meta_graph_def.has_object_graph_def()) {
+  if (!saved_model->meta_graph_def().has_object_graph_def()) {
     return errors::InvalidArgument(
         "SavedModel does not have an object graph. Please use TF2.");
   }
-  auto& object_graph = saved_model.meta_graph_def.object_graph_def();
+  auto& object_graph = saved_model->meta_graph_def().object_graph_def();
   ObjectNames object_names(object_graph, exported_names);
 
   // Clean up a couple func's that always seem to be present when importing a
@@ -2484,7 +2537,7 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
 }
 
 StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
-    const SavedModelBundle& saved_model, mlir::MLIRContext* context,
+    SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
     absl::Span<std::string> exported_names, bool add_default_attributes) {
   return SavedModelImporter::Convert(saved_model, context, exported_names,
                                      add_default_attributes);
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h
index 4f9b47795e2..d4b17073bd5 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h
@@ -20,7 +20,7 @@ limitations under the License.
 
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "mlir/IR/Module.h"  // TF:local_config_mlir
-#include "tensorflow/cc/saved_model/loader.h"
+#include "tensorflow/cc/saved_model/bundle_v2.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph.pb.h"
@@ -47,7 +47,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
 // Given a SavedModel, returns a MLIR module containing the functions, expressed
 // with tf_executor dialect.
 stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
-    const SavedModelBundle& saved_model, mlir::MLIRContext* context,
+    SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
    absl::Span<std::string> exported_names, bool add_default_attributes = true);
 
 // Serialize a MLIR module to a string.
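
For orientation, the converted entry point above takes a pointer to an already-loaded SavedModelV2Bundle rather than a SavedModelBundle. A minimal caller-side sketch follows; the helper name and the empty exported_names list are illustrative assumptions and are not part of this change.

#include <string>
#include <vector>

#include "absl/types/span.h"
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/core/lib/core/errors.h"

// Hypothetical helper showing the intended call sequence: load the persistent
// SavedModelV2 representation, then hand it to the importer.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportV2ForExample(
    const std::string& export_dir, mlir::MLIRContext* context) {
  tensorflow::SavedModelV2Bundle bundle;
  TF_RETURN_IF_ERROR(tensorflow::SavedModelV2Bundle::Load(export_dir, &bundle));
  // An empty exported_names list mirrors what mlir.i passes when no filter
  // string is supplied.
  std::vector<std::string> exported_names;
  return tensorflow::ConvertSavedModelToMlir(&bundle, context,
                                             absl::MakeSpan(exported_names));
}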
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
index ff131cf2ec3..cd422a66bc5 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
@@ -109,20 +109,16 @@ mlir::OwningModuleRef SavedModelToMlirImport(
     absl::string_view saved_model_dir, const std::unordered_set<std::string>& tags,
     absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
-  SessionOptions session_options;
-  RunOptions run_options;
-  tensorflow::SavedModelBundle bundle;
-  auto load_status = LoadSavedModel(
-      session_options, run_options,
-      std::string(saved_model_dir.data(), saved_model_dir.length()), tags,
-      &bundle);
+  tensorflow::SavedModelV2Bundle bundle;
+  auto load_status = tensorflow::SavedModelV2Bundle::Load(
+      std::string(saved_model_dir.data(), saved_model_dir.length()), &bundle);
   if (!load_status.ok()) {
     LOG(ERROR) << "Failed to load saved model '" << saved_model_dir
                << "': " << load_status;
     return nullptr;
   }
 
-  auto module_or = ConvertSavedModelToMlir(bundle, context, exported_names);
+  auto module_or = ConvertSavedModelToMlir(&bundle, context, exported_names);
   if (!module_or.status().ok()) {
     LOG(ERROR) << "SavedModel import failed: " << module_or.status();
     return nullptr;
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 7a82c5ead22..9d007935314 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -626,6 +626,7 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow:device_util",
         "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
         "//tensorflow/core:core_cpu_lib",
+        "//tensorflow/core:session_options",
         "@llvm//:support",
     ],
     alwayslink = 1,
diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
index ff7fb2bbf6d..2e13921520a 100644
--- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
+++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
 #include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/public/session_options.h"
 
 namespace tensorflow {
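
As a standalone illustration of the mechanism this patch switches to (reading variable contents via the checkpoint keys recorded in the TrackableObjectGraph, rather than evaluating ReadVariableOp nodes in a Session), here is a hedged sketch built on the new bundle_v2 API; the function name and logging are illustrative only and not part of the patch.

#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"

// Walks the restorable objects of a loaded SavedModelV2Bundle and reads each
// variable's value directly from the variables checkpoint, keyed by the
// checkpoint_key stored in the TrackableObjectGraph. This mirrors what the
// CreateSavedModelIR change above does for kVariable nodes.
tensorflow::Status DumpVariableValues(tensorflow::SavedModelV2Bundle* bundle) {
  return bundle->VisitObjectsToRestore(
      [&](int saved_node_id,
          const tensorflow::TrackableObjectGraph::TrackableObject& object)
          -> tensorflow::Status {
        for (const auto& attr : object.attributes()) {
          if (attr.name() != "VARIABLE_VALUE") continue;
          tensorflow::Tensor value;
          TF_RETURN_IF_ERROR(
              bundle->variable_reader()->Lookup(attr.checkpoint_key(), &value));
          LOG(INFO) << "node " << saved_node_id << " (" << attr.full_name()
                    << "): " << value.DebugString();
        }
        return tensorflow::Status::OK();
      });
}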