Support savedmodels that reference assets.

PiperOrigin-RevId: 332523998
Change-Id: Ia0f7cd75f79020c1b385ee6c02323c9d431ff86d
This commit is contained in:
Brian Zhao 2020-09-18 14:09:38 -07:00 committed by TensorFlower Gardener
parent d58c96946b
commit 00a5ef689b
18 changed files with 266 additions and 10 deletions

View File

@ -619,6 +619,8 @@ tf_cuda_library(
":c_api",
":c_api_experimental",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_datatype",
"//tensorflow/c:tf_tensor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -17,12 +17,16 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
using tensorflow::string;
using tensorflow::tstring;
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
float data[] = {value};
@ -36,6 +40,19 @@ TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
return th;
}
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
const tensorflow::tstring& value) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_STRING, nullptr, 0, status);
tstring* data = static_cast<tstring*>(TF_TensorData(t));
*data = value;
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value) {
int data[] = {value};
TF_Status* status = TF_NewStatus();

View File

@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@ -28,6 +29,10 @@ TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value);
// Return a tensor handle containing a bool scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value);
// Return a tensor handle containing a tstring scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
const tensorflow::tstring& value);
// Return a tensor handle containing a 2x2 matrix of doubles
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx);

View File

@ -62,6 +62,7 @@ cc_library(
":function_metadata",
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/experimental/saved_model/core/revived_types:asset",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
@ -69,6 +70,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
@ -138,7 +140,6 @@ cc_library(
":saved_model_api",
":saved_model_utils",
":signature_def_function",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/ops:restore_ops",
@ -148,10 +149,10 @@ cc_library(
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
"//tensorflow/cc/saved_model:bundle_v2",
"//tensorflow/cc/saved_model:constants",
"//tensorflow/cc/saved_model:loader_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",

View File

@ -8,6 +8,25 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "asset",
srcs = [
"asset.cc",
],
hdrs = [
"asset.h",
],
deps = [
":tensorhandle_convertible",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/cc/saved_model:constants",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "constant",
srcs = [

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
#include <string>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/path.h"
namespace tensorflow {
Asset::Asset(ImmediateTensorHandlePtr handle)
: TensorHandleConvertible(std::move(handle)) {}
Status Asset::Create(ImmediateExecutionContext* ctx,
const std::string& saved_model_dir,
const std::string& asset_filename,
std::unique_ptr<Asset>* output) {
std::string abs_path =
io::JoinPath(saved_model_dir, kSavedModelAssetsDirectory, asset_filename);
AbstractTensorPtr tensor(ctx->CreateStringScalar(abs_path));
if (tensor.get() == nullptr) {
return errors::Internal(
"Failed to create scalar string tensor for Asset at path ", abs_path);
}
ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get()));
output->reset(new Asset(std::move(handle)));
return Status();
}
} // namespace tensorflow

View File

@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_
#include <string>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/tensor.pb.h"
namespace tensorflow {
class Asset : public TensorHandleConvertible {
public:
static Status Create(ImmediateExecutionContext* ctx,
const std::string& saved_model_dir,
const std::string& asset_filename,
std::unique_ptr<Asset>* output);
// Asset is movable, but not copyable.
Asset(Asset&& other) = default;
Asset& operator=(Asset&& other) = default;
~Asset() override = default;
private:
explicit Asset(ImmediateTensorHandlePtr handle);
Asset(const Asset&) = delete;
Asset& operator=(const Asset&) = delete;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_

View File

@ -100,6 +100,20 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef(
} // namespace
Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset,
const std::string& saved_model_dir,
absl::Span<const AssetFileDef> assets,
std::unique_ptr<Asset>* output) {
int asset_index = asset.asset_file_def_index();
if (asset_index >= assets.size()) {
return errors::FailedPrecondition(
"SavedAsset contained asset index ", asset_index,
" but AssetFileDef only contains ", assets.size(), " # of assets");
}
const std::string& asset_filename = assets[asset_index].filename();
return Asset::Create(ctx, saved_model_dir, asset_filename, output);
}
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
const TensorProto& proto,
std::unique_ptr<Constant>* output) {

View File

@ -22,7 +22,9 @@ limitations under the License.
#include <memory>
#include <unordered_map>
#include "absl/types/span.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
@ -31,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
@ -52,6 +55,11 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
const SavedVariable& variable,
std::unique_ptr<Variable>* output);
Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset,
const std::string& saved_model_dir,
absl::Span<const AssetFileDef> assets,
std::unique_ptr<Asset>* output);
// Creates a TFConcreteFunction from a SavedConcreteFunction.
Status LoadTFConcreteFunction(
const SavedConcreteFunction& saved_concrete_function,

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
@ -108,12 +109,17 @@ Status ConstantFromSavedConstant(
// SavedResources. These are returned via the `out` parameter.
Status ReviveObjects(
const MetaGraphDef& metagraph, ImmediateExecutionContext* context,
const std::string& directory,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>*
revived_objects) {
// This is needed to restore "Constant" nodes by looking up their
// "Value" attribute.
NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def());
// These are needed for creating "Assets", by looking up their filenames.
std::vector<AssetFileDef> assets;
TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(metagraph, &assets));
// Iterate through all the saved objects, restoring objects as we go.
// We don't recreate functions until all other objects have been created.
for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
@ -129,12 +135,10 @@ Status ReviveObjects(
node_attr_map, &constant));
(*revived_objects)[i] = std::move(constant);
} else if (node.kind_case() == SavedObject::kAsset) {
// TODO(bmzhao): Implement Asset C++ class. This should be just recreating
// the full path to the asset file:
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/saved_model/load.py#L395-L396
// and storing it as a string tensor:
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/training/tracking/tracking.py#L324-L325
return errors::Unimplemented("SavedAsset loading is not implemented yet");
std::unique_ptr<Asset> asset;
TF_RETURN_IF_ERROR(internal::LoadSavedAsset(context, node.asset(),
directory, assets, &asset));
(*revived_objects)[i] = std::move(asset);
} else if (node.kind_case() == SavedObject::kResource) {
// TODO(bmzhao): Figure out how resource loading works and implement it
return errors::Unimplemented(
@ -352,8 +356,8 @@ Status TFSavedModelAPI::Load(
// https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
RevivedObjectMap revived_objects;
TF_RETURN_IF_ERROR(
ReviveObjects(bundle.meta_graph_def(), context, &revived_objects));
TF_RETURN_IF_ERROR(ReviveObjects(bundle.meta_graph_def(), context, directory,
&revived_objects));
// TODO(bmzhao): When we later add support for loading resources, we need to
// handle the case where materializing a function's captures requires invoking

View File

@ -27,9 +27,12 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/tstring.h"
namespace {
using tensorflow::tstring;
constexpr char kTestData[] = "cc/saved_model/testdata";
const char* kServeTag[] = {"serve"};
@ -137,6 +140,60 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
TFE_DeleteContext(ctx);
}
TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
bool use_tfrt = GetParam();
if (use_tfrt) {
TFE_DeleteContextOptions(opts);
TF_DeleteStatus(status);
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
std::string model_dir = SavedModelPath("AssetModule");
TF_SavedModel* saved_model =
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_ConcreteFunction* read_file_fn =
TF_GetSavedModelConcreteFunction(saved_model, "read_file", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* read_file_op =
TF_ConcreteFunctionMakeCallOp(read_file_fn, nullptr, 0, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
// inputs + outputs a function has.
TFE_TensorHandle* read_file_fn_outputs[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(read_file_op, &read_file_fn_outputs[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* result = TFE_TensorHandleResolve(read_file_fn_outputs[0], status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TF_NumDims(result), 0);
tensorflow::tstring* output_value =
static_cast<tensorflow::tstring*>(TF_TensorData(result));
EXPECT_EQ(std::string(*output_value), "TEST ASSET FILE CONTENTS\n");
TF_DeleteTensor(result);
TFE_DeleteTensorHandle(read_file_fn_outputs[0]);
TFE_DeleteOp(read_file_op);
TF_DeleteSavedModel(saved_model);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
}
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest,
::testing::Bool());

View File

@ -209,9 +209,13 @@ tf_cc_test(
py_binary(
name = "testdata/generate_saved_models",
srcs = ["testdata/generate_saved_models.py"],
data = [
":saved_model_asset_data",
],
python_version = "PY3",
srcs_version = "PY3",
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_spec",
@ -221,6 +225,7 @@ py_binary(
"//tensorflow/python/module",
"//tensorflow/python/saved_model",
"//tensorflow/python/saved_model:save_options",
"//tensorflow/python/training/tracking",
"@absl_py//absl:app",
],
)
@ -229,6 +234,7 @@ py_binary(
filegroup(
name = "saved_model_test_files",
srcs = glob([
"testdata/AssetModule/**",
"testdata/half_plus_two_pbtxt/**",
"testdata/half_plus_two_main_op/**",
"testdata/half_plus_two/**",
@ -245,6 +251,13 @@ alias(
actual = ":saved_model_test_files",
)
filegroup(
name = "saved_model_asset_data",
srcs = [
"testdata/test_asset.txt",
],
)
exports_files(
glob([
"testdata/half_plus_two_pbtxt/**",

View File

@ -0,0 +1 @@
TEST ASSET FILE CONTENTS

Binary file not shown.

View File

@ -29,9 +29,12 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.tracking import tracking
class VarsAndArithmeticObjectGraph(module.Module):
@ -68,9 +71,21 @@ class CyclicModule(module.Module):
self.child = ReferencesParent(self)
class AssetModule(module.Module):
def __init__(self):
self.asset = tracking.Asset(
test.test_src_dir_path("cc/saved_model/testdata/test_asset.txt"))
@def_function.function(input_signature=[])
def read_file(self):
return io_ops.read_file(self.asset)
MODULE_CTORS = {
"VarsAndArithmeticObjectGraph": VarsAndArithmeticObjectGraph,
"CyclicModule": CyclicModule,
"AssetModule": AssetModule,
}

View File

@ -0,0 +1 @@
TEST ASSET FILE CONTENTS