Load SavedModelV2 variables directly from the checkpoint.
* This fixes a correctness problem where we were relying on the variable's "name" attribute, which is merely advisory. * Should match the load heuristics in Python by reading the TrackableObjectGraph from the checkpoint, re-associating it with the SavedObjectGraph and using that to restore the variable. * Has a side-effect of eliminating the dependency of the importer on the CPU runtime and kernels, which should reduce necessary dependencies to compile by tens of megabytes. Open questions: * I created a new bundle_v2.h because V2 is so substantially different from V1. Also, it has a different dependency surface area and I therefore wanted to use it as a different (lighter) library. Advise if you would like this organized differently. I tried to factor it so that a future C++ SavedModel loader/runner could be built on this as well (with some incremental work). * There didn't seem to be a facility for generating the testdata SavedModels. I ended up just writing a small standalone python script that I can invoke from the command line to generate the new ones. It isn't wired up in any way but it should be possible to do so at some point (I'm not familiar enough with the test infra on this side of the tree, so please advise). * I could see factoring the SavedModelV2Bundle::RestoreObjects method into a dedicated facility for loading arbitrary checkpoints. Not sure what the roadmap here is so just kept is simple for now. PiperOrigin-RevId: 282439157 Change-Id: I685df47ab621917eed270319cde220498f191b74
This commit is contained in:
parent
2a16758ee4
commit
ae07e29a16
@ -118,6 +118,39 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "bundle_v2",
|
||||
srcs = ["bundle_v2.cc"],
|
||||
hdrs = ["bundle_v2.h"],
|
||||
deps = [
|
||||
":constants",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"//tensorflow/core/platform:env",
|
||||
"//tensorflow/core/platform:strcat",
|
||||
"//tensorflow/core/util/tensor_bundle",
|
||||
] + if_not_mobile([
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
]),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "bundle_v2_test",
|
||||
srcs = ["bundle_v2_test.cc"],
|
||||
data = [
|
||||
":saved_model_half_plus_two",
|
||||
],
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
":bundle_v2",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/platform:test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "saved_model_bundle_test",
|
||||
srcs = ["saved_model_bundle_test.cc"],
|
||||
@ -160,6 +193,26 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
# A subset of the TF2 saved models can be generated with this tool.
|
||||
py_binary(
|
||||
name = "testdata/generate_saved_models",
|
||||
srcs = ["testdata/generate_saved_models.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/module",
|
||||
"//tensorflow/python/saved_model",
|
||||
"//tensorflow/python/saved_model:save_options",
|
||||
"@absl_py//absl:app",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO(b/32673259): add a test to continuously validate these files.
|
||||
filegroup(
|
||||
name = "saved_model_half_plus_two",
|
||||
@ -169,5 +222,7 @@ filegroup(
|
||||
"testdata/half_plus_two/**",
|
||||
"testdata/half_plus_two_v2/**",
|
||||
"testdata/x_plus_y_v2_debuginfo/**",
|
||||
"testdata/CyclicModule/**",
|
||||
"testdata/VarsAndArithmeticObjectGraph/**",
|
||||
]),
|
||||
)
|
||||
|
223
tensorflow/cc/saved_model/bundle_v2.cc
Normal file
223
tensorflow/cc/saved_model/bundle_v2.cc
Normal file
@ -0,0 +1,223 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/saved_model/bundle_v2.h"
|
||||
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/protobuf/saved_model.pb.h"
|
||||
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
Status ReadSavedModelProto(const string& export_dir,
|
||||
SavedModel* saved_model_proto) {
|
||||
LOG(INFO) << "Reading SavedModel from: " << export_dir;
|
||||
|
||||
const string saved_model_pb_path =
|
||||
io::JoinPath(export_dir, kSavedModelFilenamePb);
|
||||
if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
|
||||
return ReadBinaryProto(Env::Default(), saved_model_pb_path,
|
||||
saved_model_proto);
|
||||
}
|
||||
const string saved_model_pbtxt_path =
|
||||
io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
|
||||
if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
|
||||
return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
|
||||
saved_model_proto);
|
||||
}
|
||||
return Status(error::Code::NOT_FOUND,
|
||||
"Could not find SavedModel .pb or .pbtxt at supplied export "
|
||||
"directory path: " +
|
||||
export_dir);
|
||||
}
|
||||
|
||||
Status ReadSavedModelDebugInfoIfPresent(
|
||||
const string& export_dir,
|
||||
std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
|
||||
LOG(INFO) << "Reading SavedModel debug info (if present) from: "
|
||||
<< export_dir;
|
||||
|
||||
const string debug_info_pb_path =
|
||||
io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
|
||||
if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
|
||||
GraphDebugInfo debug_info;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
|
||||
*debug_info_proto =
|
||||
absl::make_unique<GraphDebugInfo>(std::move(debug_info));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ReadCheckpointObjectGraph(BundleReader* bundle_reader,
|
||||
TrackableObjectGraph* object_graph) {
|
||||
Tensor object_graph_tensor;
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
bundle_reader->Lookup(kObjectGraphProtoKey, &object_graph_tensor),
|
||||
"SavedModel checkpoint does not contain object graph.");
|
||||
if (object_graph_tensor.dtype() != DT_STRING ||
|
||||
object_graph_tensor.dims() != 0 ||
|
||||
object_graph_tensor.NumElements() != 1) {
|
||||
return Status(
|
||||
error::Code::FAILED_PRECONDITION,
|
||||
"SavedModel checkpoint object graph was not the correct type.");
|
||||
}
|
||||
|
||||
const tstring* object_graph_string = reinterpret_cast<const tstring*>(
|
||||
object_graph_tensor.tensor_data().data());
|
||||
if (!object_graph->ParseFromString(*object_graph_string)) {
|
||||
return Status(
|
||||
error::Code::FAILED_PRECONDITION,
|
||||
"SavedModel checkpoint object graph could not be deserialized.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status SavedModelV2Bundle::Load(const std::string& export_dir,
|
||||
SavedModelV2Bundle* const bundle) {
|
||||
SavedModel saved_model_proto;
|
||||
TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto));
|
||||
|
||||
// Load MetaGraphDef.
|
||||
// In version 2 SavedModels, there is only one MetaGraphDef.
|
||||
if (saved_model_proto.meta_graphs_size() != 1) {
|
||||
return Status(
|
||||
error::Code::INVALID_ARGUMENT,
|
||||
strings::StrCat(
|
||||
"SavedModelV2 should have exactly one MetaGraphDef but actually ",
|
||||
"contains ", saved_model_proto.meta_graphs_size()));
|
||||
}
|
||||
bundle->meta_graph_def_ =
|
||||
std::move(*saved_model_proto.mutable_meta_graphs(0));
|
||||
|
||||
// Load GraphDebugInfo.
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_));
|
||||
|
||||
// Load the variables checkpoint reader.
|
||||
const std::string variables_prefix = io::JoinPath(
|
||||
export_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename);
|
||||
bundle->variable_reader_.reset(
|
||||
new BundleReader(Env::Default(), variables_prefix));
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
bundle->variable_reader_->status(),
|
||||
"Unable to load SavedModel variables checkpoint from ", variables_prefix);
|
||||
|
||||
// Deserialize the object graph proto from the tensor bundle.
|
||||
TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph(
|
||||
bundle->variable_reader_.get(), &bundle->trackable_object_graph_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SavedModelV2Bundle::VisitObjectsToRestore(
|
||||
RestoreObjectsCallback callback) {
|
||||
if (saved_object_graph().nodes_size() == 0 ||
|
||||
trackable_object_graph().nodes_size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Start from root nodes of both the SavedObjectGraph and TrackableObjectGraph
|
||||
// and descend to leaves. Note that the TrackableObjectGraph can have cycles
|
||||
// (as can the SavedObjectGraph).
|
||||
// This is detected and cycle edges are skipped.
|
||||
const SavedObject* root_saved_object = &saved_object_graph().nodes(0);
|
||||
const TrackableObjectGraph::TrackableObject* root_trackable_object =
|
||||
&trackable_object_graph().nodes(0);
|
||||
absl::flat_hash_set<int> trackable_node_ids;
|
||||
return RecurseObjectsToRestore(root_saved_object, 0, root_trackable_object,
|
||||
std::string(), &trackable_node_ids,
|
||||
std::move(callback));
|
||||
}
|
||||
|
||||
Status SavedModelV2Bundle::RecurseObjectsToRestore(
|
||||
const SavedObject* saved_object, int saved_object_node_id,
|
||||
const TrackableObjectGraph::TrackableObject* trackable_object,
|
||||
std::string object_name, absl::flat_hash_set<int>* seen_trackable_node_ids,
|
||||
RestoreObjectsCallback callback) {
|
||||
// Callback if any attributes or slot variables.
|
||||
// Note that the root is always excluded from the search (it can never
|
||||
// be a restorable object). This matches some logic on the Python side.
|
||||
if (saved_object_node_id != 0 &&
|
||||
(trackable_object->attributes_size() > 0 ||
|
||||
trackable_object->slot_variables_size() > 0)) {
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
callback(saved_object_node_id, *trackable_object), "Unable to restore ",
|
||||
object_name);
|
||||
}
|
||||
|
||||
for (const auto& trackable_child_ref : trackable_object->children()) {
|
||||
const auto& local_name = trackable_child_ref.local_name();
|
||||
|
||||
// Compute the full child name.
|
||||
std::string child_name;
|
||||
if (object_name.empty()) {
|
||||
child_name = local_name;
|
||||
} else {
|
||||
child_name = strings::StrCat(object_name, ".", local_name);
|
||||
}
|
||||
|
||||
// Descend down the trackable graph.
|
||||
int trackable_child_node_id = trackable_child_ref.node_id();
|
||||
if (!seen_trackable_node_ids->insert(trackable_child_node_id).second) {
|
||||
// Cycle or duplicate detected - ignore this branch.
|
||||
continue;
|
||||
}
|
||||
if (trackable_child_node_id < 0 ||
|
||||
trackable_child_node_id >= trackable_object_graph().nodes_size()) {
|
||||
return Status(
|
||||
errors::Code::FAILED_PRECONDITION,
|
||||
strings::StrCat("Illegal trackable child node id for ", child_name));
|
||||
}
|
||||
const auto* trackable_child =
|
||||
&trackable_object_graph().nodes(trackable_child_node_id);
|
||||
|
||||
// Descend down the saved object graph.
|
||||
int saved_child_node_id = -1;
|
||||
const SavedObject* saved_child = nullptr;
|
||||
for (const auto& saved_child_ref : saved_object->children()) {
|
||||
if (saved_child_ref.local_name() == local_name) {
|
||||
// Found.
|
||||
saved_child_node_id = saved_child_ref.node_id();
|
||||
if (saved_child_node_id >= 0 &&
|
||||
saved_child_node_id < saved_object_graph().nodes_size()) {
|
||||
saved_child = &saved_object_graph().nodes(saved_child_node_id);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!saved_child) {
|
||||
return Status(
|
||||
errors::Code::FAILED_PRECONDITION,
|
||||
strings::StrCat("Could not find saved object to restore for ",
|
||||
child_name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(RecurseObjectsToRestore(
|
||||
saved_child, saved_child_node_id, trackable_child, child_name,
|
||||
seen_trackable_node_ids, callback));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
87
tensorflow/cc/saved_model/bundle_v2.h
Normal file
87
tensorflow/cc/saved_model/bundle_v2.h
Normal file
@ -0,0 +1,87 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Helpers for loading the persistent representation of a SavedModelV2.
|
||||
// Please note that this is depended on by code that does not make use of
|
||||
// the full runtime and its dependencies should be restricted.
|
||||
|
||||
#ifndef TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_
|
||||
#define TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
|
||||
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
/// Represents a version 2 SavedModel that is loaded from storage (but not yet
|
||||
/// loaded into an executable in-memory representation).
|
||||
class SavedModelV2Bundle {
|
||||
public:
|
||||
using RestoreObjectsCallback =
|
||||
std::function<Status(int, const TrackableObjectGraph::TrackableObject&)>;
|
||||
|
||||
/// Loads persistent representations for a SavedModelV2 from the specified
|
||||
/// export directory.
|
||||
static Status Load(const std::string& export_dir, SavedModelV2Bundle* bundle);
|
||||
|
||||
/// MetaGraphDef from the loaded SavedModel.
|
||||
MetaGraphDef& meta_graph_def() { return meta_graph_def_; }
|
||||
|
||||
/// SavedObjectGraph from the MetaGraphDef.
|
||||
const SavedObjectGraph& saved_object_graph() {
|
||||
return meta_graph_def().object_graph_def();
|
||||
}
|
||||
|
||||
/// TrackableObjectGraph loaded from the variable_reader() checkpoint.
|
||||
TrackableObjectGraph& trackable_object_graph() {
|
||||
return trackable_object_graph_;
|
||||
}
|
||||
|
||||
/// BundleReader for accessing the variables bundle.
|
||||
BundleReader* variable_reader() { return variable_reader_.get(); }
|
||||
|
||||
/// The GraphDebugInfo (or nullptr if none).
|
||||
GraphDebugInfo* debug_info() { return debug_info_.get(); }
|
||||
|
||||
/// Restores objects, invoking the callback with the node id in the
|
||||
/// saved_object_graph() and the corresponding TrackableObject from the
|
||||
/// trackable_object_graph(). The callback may use the variable_reader() but
|
||||
/// must not modify the underlying saved_object_graph().
|
||||
Status VisitObjectsToRestore(RestoreObjectsCallback callback);
|
||||
|
||||
private:
|
||||
Status RecurseObjectsToRestore(
|
||||
const SavedObject* saved_object, int saved_object_node_id,
|
||||
const TrackableObjectGraph::TrackableObject* trackable_object,
|
||||
std::string object_name,
|
||||
absl::flat_hash_set<int>* seen_trackable_node_ids,
|
||||
RestoreObjectsCallback callback);
|
||||
|
||||
MetaGraphDef meta_graph_def_;
|
||||
TrackableObjectGraph trackable_object_graph_;
|
||||
std::unique_ptr<BundleReader> variable_reader_;
|
||||
std::unique_ptr<GraphDebugInfo> debug_info_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_
|
99
tensorflow/cc/saved_model/bundle_v2_test.cc
Normal file
99
tensorflow/cc/saved_model/bundle_v2_test.cc
Normal file
@ -0,0 +1,99 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/saved_model/bundle_v2.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
constexpr char kTestData[] = "cc/saved_model/testdata";
|
||||
|
||||
class BundleV2Test : public ::testing::Test {
|
||||
protected:
|
||||
BundleV2Test() {}
|
||||
|
||||
void RestoreVarsAndVerify(SavedModelV2Bundle* bundle,
|
||||
std::vector<std::string> expected_names) {
|
||||
// Collect saved_node_id, full_name, checkpoint_key into a vector.
|
||||
using RestoredVarType = std::tuple<int, std::string, std::string>;
|
||||
std::vector<RestoredVarType> restored_vars;
|
||||
TF_ASSERT_OK(bundle->VisitObjectsToRestore(
|
||||
[&](int saved_node_id,
|
||||
const TrackableObjectGraph::TrackableObject& trackable_object)
|
||||
-> Status {
|
||||
for (const auto& attr : trackable_object.attributes()) {
|
||||
if (attr.name() == "VARIABLE_VALUE") {
|
||||
restored_vars.emplace_back(saved_node_id, attr.full_name(),
|
||||
attr.checkpoint_key());
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
// Should be one of each var name restored.
|
||||
for (const auto& expected_name : expected_names) {
|
||||
EXPECT_EQ(1, std::count_if(restored_vars.begin(), restored_vars.end(),
|
||||
[&](RestoredVarType t) {
|
||||
return std::get<1>(t) == expected_name;
|
||||
}));
|
||||
}
|
||||
|
||||
for (const auto& restored_var : restored_vars) {
|
||||
// Each restored var should match a SavedObjectGraph node with the same
|
||||
// variable name.
|
||||
const auto& saved_node =
|
||||
bundle->saved_object_graph().nodes(std::get<0>(restored_var));
|
||||
EXPECT_EQ(std::get<1>(restored_var), saved_node.variable().name());
|
||||
|
||||
// And should be able to load it from the tensor_bundle.
|
||||
Tensor value;
|
||||
TF_ASSERT_OK(
|
||||
bundle->variable_reader()->Lookup(std::get<2>(restored_var), &value));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(BundleV2Test, LoadsVarsAndArithmeticObjectGraph) {
|
||||
const string export_dir = io::JoinPath(
|
||||
testing::TensorFlowSrcRoot(), kTestData, "VarsAndArithmeticObjectGraph");
|
||||
|
||||
SavedModelV2Bundle bundle;
|
||||
TF_ASSERT_OK(SavedModelV2Bundle::Load(export_dir, &bundle));
|
||||
|
||||
// Ensure that there are nodes in the trackable_object_graph.
|
||||
EXPECT_GT(bundle.trackable_object_graph().nodes_size(), 0);
|
||||
|
||||
RestoreVarsAndVerify(&bundle, {"variable_x", "variable_y", "child_variable"});
|
||||
}
|
||||
|
||||
TEST_F(BundleV2Test, LoadsCyclicModule) {
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestData, "CyclicModule");
|
||||
|
||||
SavedModelV2Bundle bundle;
|
||||
TF_ASSERT_OK(SavedModelV2Bundle::Load(export_dir, &bundle));
|
||||
|
||||
// Ensure that there are nodes in the trackable_object_graph.
|
||||
EXPECT_GT(bundle.trackable_object_graph().nodes_size(), 0);
|
||||
|
||||
RestoreVarsAndVerify(&bundle, {"MyVariable"});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -50,6 +50,9 @@ constexpr char kSavedModelVariablesFilename[] = "variables";
|
||||
constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op";
|
||||
constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op";
|
||||
|
||||
// Key in the TensorBundle for the object graph proto.
|
||||
constexpr char kObjectGraphProtoKey[] = "_CHECKPOINTABLE_OBJECT_GRAPH";
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
|
||||
|
235
tensorflow/cc/saved_model/testdata/CyclicModule/debug/saved_model_debug_info.pb
vendored
Normal file
235
tensorflow/cc/saved_model/testdata/CyclicModule/debug/saved_model_debug_info.pb
vendored
Normal file
@ -0,0 +1,235 @@
|
||||
|
||||
N/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/absl/app.py
|
||||
m/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/saved_model/save.py
|
||||
generate_saved_models.pyb
|
||||
:AssignVariableOp/MyVariable@__inference__traced_restore_45$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GW
|
||||
/AssignVariableOp@__inference__traced_restore_45$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GO
|
||||
'Identity@__inference__traced_restore_45$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GL
|
||||
$Identity@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GQ
|
||||
)Identity_1@__inference__traced_restore_45$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GN
|
||||
&Identity_1@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GQ
|
||||
)Identity_2@__inference__traced_restore_45$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
Gj
|
||||
BMergeV2Checkpoints/checkpoint_prefixes@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GV
|
||||
.MergeV2Checkpoints@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GK
|
||||
#NoOp@__inference__traced_restore_45$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
Ga
|
||||
9RestoreV2/shape_and_slices@__inference__traced_restore_45$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
G]
|
||||
5RestoreV2/tensor_names@__inference__traced_restore_45$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GP
|
||||
(RestoreV2@__inference__traced_restore_45$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
Gc
|
||||
;RestoreV2_1/shape_and_slices@__inference__traced_restore_45$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
G_
|
||||
7RestoreV2_1/tensor_names@__inference__traced_restore_45$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GR
|
||||
*RestoreV2_1@__inference__traced_restore_45$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
Gi
|
||||
ASaveV2/MyVariable/Read/ReadVariableOp@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
G[
|
||||
3SaveV2/shape_and_slices@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GW
|
||||
/SaveV2/tensor_names@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GJ
|
||||
"SaveV2@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GR
|
||||
*SaveV2_1/Const@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
G]
|
||||
5SaveV2_1/shape_and_slices@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GY
|
||||
1SaveV2_1/tensor_names@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GL
|
||||
$SaveV2_1@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GY
|
||||
1ShardedFilename/shard@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GS
|
||||
+ShardedFilename@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
G[
|
||||
3ShardedFilename_1/shard@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GU
|
||||
-ShardedFilename_1@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GW
|
||||
/StringJoin/inputs_1@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GN
|
||||
&StringJoin@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GR
|
||||
*file_prefix@__inference__traced_restore_45$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GO
|
||||
'file_prefix@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
GN
|
||||
&num_shards@__inference__traced_save_30$
|
||||
Ö
|
||||
•
|
||||
A
|
||||
ú
|
||||
«
|
||||
G
|
BIN
tensorflow/cc/saved_model/testdata/CyclicModule/saved_model.pb
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/CyclicModule/saved_model.pb
vendored
Normal file
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.data-00000-of-00001
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.data-00000-of-00001
vendored
Normal file
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.index
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/CyclicModule/variables/variables.index
vendored
Normal file
Binary file not shown.
506
tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/debug/saved_model_debug_info.pb
vendored
Normal file
506
tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/debug/saved_model_debug_info.pb
vendored
Normal file
@ -0,0 +1,506 @@
|
||||
|
||||
j/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py
|
||||
m/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/saved_model/save.py
|
||||
€/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/saved_model/signature_serialization.py
|
||||
i/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/ops/math_ops.py
|
||||
N/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/absl/app.py
|
||||
o/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py
|
||||
./generate_saved_models.py
|
||||
k/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py
|
||||
q/usr/local/google/home/laurenzo/.local/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.pyg
|
||||
;AssignVariableOp/variable_x@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6\
|
||||
0AssignVariableOp@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6i
|
||||
=AssignVariableOp_1/variable_y@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6^
|
||||
2AssignVariableOp_1@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6m
|
||||
AAssignVariableOp_2/child_variable@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6^
|
||||
2AssignVariableOp_2@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6T
|
||||
(Identity@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6P
|
||||
$Identity@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6J
|
||||
Identity@__inference_compute_34'
|
||||
T
|
||||
þ
|
||||
0
|
||||
ú
|
||||
«
|
||||
6U
|
||||
)Identity@__inference_signature_wrapper_45(
|
||||
ƒ
|
||||
€
|
||||
0
|
||||
ú
|
||||
«
|
||||
6V
|
||||
*Identity_1@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6R
|
||||
&Identity_1@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6V
|
||||
*Identity_2@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6V
|
||||
*Identity_3@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6V
|
||||
*Identity_4@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6n
|
||||
BMergeV2Checkpoints/checkpoint_prefixes@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6Z
|
||||
.MergeV2Checkpoints@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6P
|
||||
$NoOp@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6f
|
||||
:RestoreV2/shape_and_slices@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6b
|
||||
6RestoreV2/tensor_names@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6U
|
||||
)RestoreV2@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6h
|
||||
<RestoreV2_1/shape_and_slices@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6d
|
||||
8RestoreV2_1/tensor_names@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6W
|
||||
+RestoreV2_1@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6q
|
||||
ESaveV2/child_variable/Read/ReadVariableOp@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6_
|
||||
3SaveV2/shape_and_slices@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6[
|
||||
/SaveV2/tensor_names@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6m
|
||||
ASaveV2/variable_x/Read/ReadVariableOp@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6m
|
||||
ASaveV2/variable_y/Read/ReadVariableOp@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6N
|
||||
"SaveV2@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6X
|
||||
,SaveV2_1/Const_1@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6a
|
||||
5SaveV2_1/shape_and_slices@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6]
|
||||
1SaveV2_1/tensor_names@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6P
|
||||
$SaveV2_1@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6]
|
||||
1ShardedFilename/shard@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6W
|
||||
+ShardedFilename@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6_
|
||||
3ShardedFilename_1/shard@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6Y
|
||||
-ShardedFilename_1@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6k
|
||||
?StatefulPartitionedCall/args_2@__inference_signature_wrapper_45(
|
||||
ƒ
|
||||
€
|
||||
0
|
||||
ú
|
||||
«
|
||||
6k
|
||||
?StatefulPartitionedCall/args_3@__inference_signature_wrapper_45(
|
||||
ƒ
|
||||
€
|
||||
0
|
||||
ú
|
||||
«
|
||||
6k
|
||||
?StatefulPartitionedCall/args_4@__inference_signature_wrapper_45(
|
||||
ƒ
|
||||
€
|
||||
0
|
||||
ú
|
||||
«
|
||||
6k
|
||||
?StatefulPartitionedCall/args_5@__inference_signature_wrapper_45(
|
||||
ƒ
|
||||
€
|
||||
0
|
||||
ú
|
||||
«
|
||||
6d
|
||||
8StatefulPartitionedCall@__inference_signature_wrapper_45(
|
||||
ƒ
|
||||
€
|
||||
0
|
||||
ú
|
||||
«
|
||||
6[
|
||||
/StringJoin/inputs_1@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6R
|
||||
&StringJoin@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6C
|
||||
a@__inference_compute_34'
|
||||
T
|
||||
þ
|
||||
0
|
||||
ú
|
||||
«
|
||||
6N
|
||||
"a@__inference_signature_wrapper_45(
|
||||
ƒ
|
||||
€
|
||||
0
|
||||
ú
|
||||
«
|
||||
6y
|
||||
2add/ReadVariableOp/resource@__inference_compute_34C
|
||||
è
|
||||
‰
|
||||
|
||||
Ä
|
||||
ð
|
||||
·
|
||||
Ò
|
||||
†
|
||||
ô
|
||||
¿p
|
||||
)add/ReadVariableOp@__inference_compute_34C
|
||||
è
|
||||
‰
|
||||
|
||||
Ä
|
||||
ð
|
||||
·
|
||||
Ò
|
||||
†
|
||||
ô
|
||||
¿c
|
||||
add@__inference_compute_34E
|
||||
©
|
||||
’
|
||||
|
||||
Ä
|
||||
ð
|
||||
·
|
||||
Ò
|
||||
†
|
||||
ô
|
||||
¿{
|
||||
4add_1/ReadVariableOp/resource@__inference_compute_34C
|
||||
è
|
||||
‰
|
||||
|
||||
Ä
|
||||
ð
|
||||
·
|
||||
Ò
|
||||
†
|
||||
ô
|
||||
¿r
|
||||
+add_1/ReadVariableOp@__inference_compute_34C
|
||||
è
|
||||
‰
|
||||
|
||||
Ä
|
||||
ð
|
||||
·
|
||||
Ò
|
||||
†
|
||||
ô
|
||||
¿e
|
||||
add_1@__inference_compute_34E
|
||||
©
|
||||
’
|
||||
|
||||
Ä
|
||||
ð
|
||||
·
|
||||
Ò
|
||||
†
|
||||
ô
|
||||
¿g
|
||||
add_2/y@__inference_compute_34E
|
||||
©
|
||||
…
|
||||
|
||||
Ä
|
||||
ð
|
||||
·
|
||||
Ò
|
||||
†
|
||||
ô
|
||||
¿e
|
||||
add_2@__inference_compute_34E
|
||||
©
|
||||
…
|
||||
|
||||
Ä
|
||||
ð
|
||||
·
|
||||
Ò
|
||||
†
|
||||
ô
|
||||
¿C
|
||||
b@__inference_compute_34'
|
||||
T
|
||||
þ
|
||||
0
|
||||
ú
|
||||
«
|
||||
6N
|
||||
"b@__inference_signature_wrapper_45(
|
||||
ƒ
|
||||
€
|
||||
0
|
||||
ú
|
||||
«
|
||||
6W
|
||||
+file_prefix@__inference__traced_restore_101(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6S
|
||||
'file_prefix@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6c
|
||||
mul@__inference_compute_34E
|
||||
°
|
||||
…
|
||||
|
||||
Ä
|
||||
ð
|
||||
·
|
||||
Ò
|
||||
†
|
||||
ô
|
||||
¿R
|
||||
&num_shards@__inference__traced_save_80(
|
||||
Ö
|
||||
•
|
||||
0
|
||||
ú
|
||||
«
|
||||
6}
|
||||
6truediv/ReadVariableOp/resource@__inference_compute_34C
|
||||
è
|
||||
‰
|
||||
|
||||
Ä
|
||||
ð
|
||||
·
|
||||
Ò
|
||||
†
|
||||
ô
|
||||
¿t
|
||||
-truediv/ReadVariableOp@__inference_compute_34C
|
||||
è
|
||||
‰
|
||||
|
||||
Ä
|
||||
ð
|
||||
·
|
||||
Ò
|
||||
†
|
||||
ô
|
||||
¿g
|
||||
truediv@__inference_compute_34E
|
||||
ï
|
||||
’
|
||||
|
||||
Ä
|
||||
ð
|
||||
·
|
||||
Ò
|
||||
†
|
||||
ô
|
||||
¿
|
BIN
tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/saved_model.pb
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/saved_model.pb
vendored
Normal file
Binary file not shown.
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.index
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/variables/variables.index
vendored
Normal file
Binary file not shown.
97
tensorflow/cc/saved_model/testdata/generate_saved_models.py
vendored
Normal file
97
tensorflow/cc/saved_model/testdata/generate_saved_models.py
vendored
Normal file
@ -0,0 +1,97 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Standalone utility to generate some test saved models."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.saved_model import save_options
|
||||
from tensorflow.python.saved_model import saved_model
|
||||
|
||||
|
||||
class VarsAndArithmeticObjectGraph(module.Module):
|
||||
"""Three vars (one in a sub-module) and compute method."""
|
||||
|
||||
def __init__(self):
|
||||
self.x = variables.Variable(1.0, name="variable_x")
|
||||
self.y = variables.Variable(2.0, name="variable_y")
|
||||
self.child = module.Module()
|
||||
self.child.z = variables.Variable(3.0, name="child_variable")
|
||||
self.child.c = ops.convert_to_tensor(5.0)
|
||||
|
||||
@def_function.function(input_signature=[
|
||||
tensor_spec.TensorSpec((), dtypes.float32),
|
||||
tensor_spec.TensorSpec((), dtypes.float32)
|
||||
])
|
||||
def compute(self, a, b):
|
||||
return (a + self.x) * (b + self.y) / (self.child.z) + self.child.c
|
||||
|
||||
|
||||
class ReferencesParent(module.Module):
|
||||
|
||||
def __init__(self, parent):
|
||||
super(ReferencesParent, self).__init__()
|
||||
self.parent = parent
|
||||
self.my_variable = variables.Variable(3., name="MyVariable")
|
||||
|
||||
|
||||
# Creates a cyclic object graph.
|
||||
class CyclicModule(module.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(CyclicModule, self).__init__()
|
||||
self.child = ReferencesParent(self)
|
||||
|
||||
|
||||
MODULE_CTORS = {
|
||||
"VarsAndArithmeticObjectGraph": VarsAndArithmeticObjectGraph,
|
||||
"CyclicModule": CyclicModule,
|
||||
}
|
||||
|
||||
|
||||
def main(args):
|
||||
if len(args) != 3:
|
||||
print("Expected: {export_path} {ModuleName}")
|
||||
print("Allowed ModuleNames:", MODULE_CTORS.keys())
|
||||
return 1
|
||||
|
||||
_, export_path, module_name = args
|
||||
module_ctor = MODULE_CTORS.get(module_name)
|
||||
if not module_ctor:
|
||||
print("Expected ModuleName to be one of:", MODULE_CTORS.keys())
|
||||
return 2
|
||||
os.makedirs(export_path)
|
||||
|
||||
tf_module = module_ctor()
|
||||
options = save_options.SaveOptions(save_debug_info=True)
|
||||
saved_model.save(tf_module, export_path, options=options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
v2_compat.enable_v2_behavior()
|
||||
app.run(main)
|
@ -83,28 +83,22 @@ string ExperimentalConvertSavedModelToMlir(
|
||||
const string &exported_names_str,
|
||||
bool show_debug_info,
|
||||
TF_Status* status) {
|
||||
// Load the saved model into a SavedModelBundle.
|
||||
// Load the saved model into a SavedModelV2Bundle.
|
||||
|
||||
// TODO(silvasean): Add support for tags, if needed.
|
||||
// The default "serve" tag seems to be enough.
|
||||
std::unordered_set<string> tags;
|
||||
tags.insert("serve");
|
||||
SessionOptions session_options;
|
||||
RunOptions run_options;
|
||||
tensorflow::SavedModelBundle bundle;
|
||||
auto load_status = LoadSavedModel(session_options, run_options,
|
||||
saved_model_path, tags, &bundle);
|
||||
tensorflow::SavedModelV2Bundle bundle;
|
||||
auto load_status = tensorflow::SavedModelV2Bundle::Load(
|
||||
saved_model_path, &bundle);
|
||||
if (!load_status.ok()) {
|
||||
Set_TF_Status_from_Status(status, load_status);
|
||||
return "// error";
|
||||
}
|
||||
|
||||
// Convert the SavedModelBundle to an MLIR module.
|
||||
// Convert the SavedModelV2Bundle to an MLIR module.
|
||||
|
||||
std::vector<string> exported_names =
|
||||
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
|
||||
mlir::MLIRContext context;
|
||||
auto module_or = ConvertSavedModelToMlir(bundle, &context,
|
||||
auto module_or = ConvertSavedModelToMlir(&bundle, &context,
|
||||
absl::Span<std::string>(exported_names));
|
||||
if (!module_or.status().ok()) {
|
||||
Set_TF_Status_from_Status(status, module_or.status());
|
||||
|
@ -308,7 +308,7 @@ cc_library(
|
||||
":mlir_roundtrip_flags",
|
||||
":tensorflow",
|
||||
":tensorflow_passes",
|
||||
"//tensorflow/cc/saved_model:loader_lite",
|
||||
"//tensorflow/cc/saved_model:bundle_v2",
|
||||
"//tensorflow/compiler/jit:shape_inference_helpers",
|
||||
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||
"//tensorflow/compiler/tf2xla:functionalize_control_flow",
|
||||
|
@ -1815,7 +1815,7 @@ class SavedModelImporter : public ImporterBase {
|
||||
// Main entry point: converts all functions in the given meta graph to an MLIR
|
||||
// Module.
|
||||
static StatusOr<mlir::OwningModuleRef> Convert(
|
||||
const SavedModelBundle& saved_model, mlir::MLIRContext* context,
|
||||
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
||||
absl::Span<std::string> exported_names, bool add_default_attributes);
|
||||
|
||||
private:
|
||||
@ -2015,29 +2015,45 @@ llvm::StringRef ObjectNames::SaveString(const std::string& s) const {
|
||||
return llvm::StringRef(*saved_strings_.insert(s).first);
|
||||
}
|
||||
|
||||
StatusOr<Tensor> GetTensorFromSession(Session* session, std::string name) {
|
||||
std::vector<Tensor> outputs;
|
||||
TF_RETURN_IF_ERROR(session->Run(/*inputs=*/{}, /*output_tensor_names=*/{name},
|
||||
/*target_node_names=*/{}, &outputs));
|
||||
return outputs[0];
|
||||
}
|
||||
|
||||
// Variable ops return resource types, but we want to read their contents.
|
||||
// We need to find a "ReadVariableOp" that reads a given variable to get out a
|
||||
// tensor value. These seem to always be present in the GraphDef's main graph.
|
||||
// TODO(silvasean): Find a better way to do this.
|
||||
StatusOr<Tensor> ReadVariableFromSession(const SavedModelBundle& saved_model,
|
||||
std::string variable_name) {
|
||||
const GraphDef& graph_def = saved_model.meta_graph_def.graph_def();
|
||||
// TODO(silvasean): Don't do linear search.
|
||||
for (const NodeDef& node : graph_def.node()) {
|
||||
if (node.op() == "ReadVariableOp" && node.input_size() == 1 &&
|
||||
node.input(0) == variable_name) {
|
||||
return GetTensorFromSession(saved_model.session.get(), node.name());
|
||||
// Extracts a TensorProto for a Const op from a GraphDef, given an op_name.
|
||||
// Returns nullptr on not found or other mismatch.
|
||||
// This returns a pointer to the actual node within the graph_def so as to
|
||||
// avoid expensive copies.
|
||||
const TensorProto* ExtractConstTensorFromGraph(const GraphDef& graph_def,
|
||||
const std::string& op_name) {
|
||||
const NodeDef* match_node = nullptr;
|
||||
for (const auto& node : graph_def.node()) {
|
||||
if (node.name() == op_name) {
|
||||
match_node = &node;
|
||||
}
|
||||
}
|
||||
return errors::InvalidArgument("Could not find ReadVariableOp reading '",
|
||||
variable_name, "'");
|
||||
|
||||
if (!match_node) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto value_it = match_node->attr().find("value");
|
||||
if (value_it == match_node->attr().end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!value_it->second.has_tensor()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return &value_it->second.tensor();
|
||||
}
|
||||
|
||||
const TrackableObjectGraph::TrackableObject::SerializedTensor*
|
||||
FindSerializedTensorInTrackable(
|
||||
const TrackableObjectGraph::TrackableObject& trackable_object,
|
||||
StringPiece name) {
|
||||
for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
|
||||
if (maybe_serialized_tensor.name() == name) {
|
||||
return &maybe_serialized_tensor;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Status DiagnoseMultipleConcreteFunctions(const SavedObjectGraph& object_graph,
|
||||
@ -2237,9 +2253,22 @@ Status CreateSavedModelIR(
|
||||
const ObjectNames& object_names, mlir::ModuleOp module,
|
||||
const SavedObjectGraph& object_graph,
|
||||
const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name,
|
||||
const SavedModelBundle& saved_model) {
|
||||
SavedModelV2Bundle* saved_model) {
|
||||
mlir::OpBuilder builder(module.getBodyRegion());
|
||||
mlir::SymbolTable symbol_table(module);
|
||||
|
||||
// Create a side data-structure, indexed by the object_graph node_id to
|
||||
// a TrackableObject that is restorable.
|
||||
absl::flat_hash_map<int, const TrackableObjectGraph::TrackableObject*>
|
||||
restored_objects;
|
||||
TF_RETURN_IF_ERROR(saved_model->VisitObjectsToRestore(
|
||||
[&](int saved_node_id,
|
||||
const TrackableObjectGraph::TrackableObject& trackable_object) {
|
||||
restored_objects.insert(
|
||||
std::make_pair(saved_node_id, &trackable_object));
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) {
|
||||
const SavedObject& object = object_graph.nodes(node_id);
|
||||
// For correctness, we cannot import functions that don't have exported
|
||||
@ -2339,8 +2368,27 @@ Status CreateSavedModelIR(
|
||||
}
|
||||
} else if (object.kind_case() == SavedObject::kVariable) {
|
||||
const SavedVariable& variable = object.variable();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Tensor value, ReadVariableFromSession(saved_model, variable.name()));
|
||||
// Find the trackable in the side data structure.
|
||||
auto variable_trackable_it = restored_objects.find(node_id);
|
||||
if (variable_trackable_it == restored_objects.end()) {
|
||||
return errors::FailedPrecondition("Could not restore saved variable: ",
|
||||
variable.name());
|
||||
}
|
||||
const auto* serialized_tensor_attr = FindSerializedTensorInTrackable(
|
||||
*variable_trackable_it->second, "VARIABLE_VALUE");
|
||||
if (!serialized_tensor_attr) {
|
||||
return errors::FailedPrecondition(
|
||||
"Could not find serialized tensor for saved variable: ",
|
||||
variable.name());
|
||||
}
|
||||
const auto& checkpoint_key = serialized_tensor_attr->checkpoint_key();
|
||||
|
||||
// Load it from the reader.
|
||||
Tensor value;
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
saved_model->variable_reader()->Lookup(checkpoint_key, &value),
|
||||
"Could not read checkpoint key from variables bundle: ",
|
||||
checkpoint_key);
|
||||
TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder));
|
||||
// A variable can have a partially known type, such as tensor<?x27x?xf32>,
|
||||
// even if the initializer is a specific static shape.
|
||||
@ -2358,10 +2406,15 @@ Status CreateSavedModelIR(
|
||||
builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
|
||||
} else if (object.kind_case() == SavedObject::kConstant) {
|
||||
const SavedConstant& constant = object.constant();
|
||||
TF_ASSIGN_OR_RETURN(Tensor value,
|
||||
GetTensorFromSession(saved_model.session.get(),
|
||||
constant.operation()));
|
||||
TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder));
|
||||
const TensorProto* value = ExtractConstTensorFromGraph(
|
||||
saved_model->meta_graph_def().graph_def(), constant.operation());
|
||||
if (!value) {
|
||||
return errors::FailedPrecondition(
|
||||
"Unable to find const node referenced in object graph: ",
|
||||
constant.operation());
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(auto value_attr,
|
||||
ConvertTensorProto(*value, &builder));
|
||||
auto op = builder.create<mlir::tf_saved_model::GlobalTensorOp>(
|
||||
builder.getUnknownLoc(),
|
||||
builder.getStringAttr(object_names.GetSymbolTableName(node_id)),
|
||||
@ -2379,18 +2432,18 @@ Status CreateSavedModelIR(
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> SavedModelImporter::Convert(
|
||||
const SavedModelBundle& saved_model, mlir::MLIRContext* context,
|
||||
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
||||
absl::Span<std::string> exported_names, bool add_default_attributes) {
|
||||
GraphDebugInfo dummy_debug_info;
|
||||
const GraphDebugInfo& debug_info =
|
||||
saved_model.debug_info ? *saved_model.debug_info : dummy_debug_info;
|
||||
saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;
|
||||
|
||||
GraphImportConfig specs;
|
||||
mlir::OwningModuleRef module =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
|
||||
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
|
||||
|
||||
const auto& graphdef = saved_model.meta_graph_def.graph_def();
|
||||
const auto& graphdef = saved_model->meta_graph_def().graph_def();
|
||||
GraphConstructorOptions options;
|
||||
options.allow_internal_ops = true;
|
||||
options.add_default_attributes = add_default_attributes;
|
||||
@ -2413,11 +2466,11 @@ StatusOr<mlir::OwningModuleRef> SavedModelImporter::Convert(
|
||||
TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name));
|
||||
}
|
||||
|
||||
if (!saved_model.meta_graph_def.has_object_graph_def()) {
|
||||
if (!saved_model->meta_graph_def().has_object_graph_def()) {
|
||||
return errors::InvalidArgument(
|
||||
"SavedModel does not have an object graph. Please use TF2.");
|
||||
}
|
||||
auto& object_graph = saved_model.meta_graph_def.object_graph_def();
|
||||
auto& object_graph = saved_model->meta_graph_def().object_graph_def();
|
||||
ObjectNames object_names(object_graph, exported_names);
|
||||
|
||||
// Clean up a couple func's that always seem to be present when importing a
|
||||
@ -2484,7 +2537,7 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
||||
const SavedModelBundle& saved_model, mlir::MLIRContext* context,
|
||||
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
||||
absl::Span<std::string> exported_names, bool add_default_attributes) {
|
||||
return SavedModelImporter::Convert(saved_model, context, exported_names,
|
||||
add_default_attributes);
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "tensorflow/cc/saved_model/loader.h"
|
||||
#include "tensorflow/cc/saved_model/bundle_v2.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
@ -47,7 +47,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
|
||||
// Given a SavedModel, returns a MLIR module containing the functions, expressed
|
||||
// with tf_executor dialect.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
||||
const SavedModelBundle& saved_model, mlir::MLIRContext* context,
|
||||
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
||||
absl::Span<std::string> exported_names, bool add_default_attributes = true);
|
||||
|
||||
// Serialize a MLIR module to a string.
|
||||
|
@ -109,20 +109,16 @@ mlir::OwningModuleRef SavedModelToMlirImport(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
|
||||
SessionOptions session_options;
|
||||
RunOptions run_options;
|
||||
tensorflow::SavedModelBundle bundle;
|
||||
auto load_status = LoadSavedModel(
|
||||
session_options, run_options,
|
||||
std::string(saved_model_dir.data(), saved_model_dir.length()), tags,
|
||||
&bundle);
|
||||
tensorflow::SavedModelV2Bundle bundle;
|
||||
auto load_status = tensorflow::SavedModelV2Bundle::Load(
|
||||
std::string(saved_model_dir.data(), saved_model_dir.length()), &bundle);
|
||||
if (!load_status.ok()) {
|
||||
LOG(ERROR) << "Failed to load saved model '" << saved_model_dir
|
||||
<< "': " << load_status;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto module_or = ConvertSavedModelToMlir(bundle, context, exported_names);
|
||||
auto module_or = ConvertSavedModelToMlir(&bundle, context, exported_names);
|
||||
if (!module_or.status().ok()) {
|
||||
LOG(ERROR) << "SavedModel import failed: " << module_or.status();
|
||||
return nullptr;
|
||||
|
@ -626,6 +626,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:device_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:session_options",
|
||||
"@llvm//:support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user