Support savedmodels that reference assets.
PiperOrigin-RevId: 332523998 Change-Id: Ia0f7cd75f79020c1b385ee6c02323c9d431ff86d
This commit is contained in:
parent
d58c96946b
commit
00a5ef689b
@ -619,6 +619,8 @@ tf_cuda_library(
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/c:tf_tensor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -17,12 +17,16 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
|
||||
using tensorflow::string;
|
||||
using tensorflow::tstring;
|
||||
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
|
||||
float data[] = {value};
|
||||
@ -36,6 +40,19 @@ TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
|
||||
const tensorflow::tstring& value) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_STRING, nullptr, 0, status);
|
||||
tstring* data = static_cast<tstring*>(TF_TensorData(t));
|
||||
*data = value;
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value) {
|
||||
int data[] = {value};
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
@ -28,6 +29,10 @@ TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value);
|
||||
// Return a tensor handle containing a bool scalar
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value);
|
||||
|
||||
// Return a tensor handle containing a tstring scalar
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
|
||||
const tensorflow::tstring& value);
|
||||
|
||||
// Return a tensor handle containing a 2x2 matrix of doubles
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx);
|
||||
|
||||
|
@ -62,6 +62,7 @@ cc_library(
|
||||
":function_metadata",
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:asset",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
|
||||
@ -69,6 +70,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
@ -138,7 +140,6 @@ cc_library(
|
||||
":saved_model_api",
|
||||
":saved_model_utils",
|
||||
":signature_def_function",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core/ops:restore_ops",
|
||||
@ -148,10 +149,10 @@ cc_library(
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
|
||||
"//tensorflow/cc/saved_model:bundle_v2",
|
||||
"//tensorflow/cc/saved_model:constants",
|
||||
"//tensorflow/cc/saved_model:loader_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
|
@ -8,6 +8,25 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "asset",
|
||||
srcs = [
|
||||
"asset.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"asset.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorhandle_convertible",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/cc/saved_model:constants",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "constant",
|
||||
srcs = [
|
||||
|
@ -0,0 +1,49 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Asset::Asset(ImmediateTensorHandlePtr handle)
|
||||
: TensorHandleConvertible(std::move(handle)) {}
|
||||
|
||||
Status Asset::Create(ImmediateExecutionContext* ctx,
|
||||
const std::string& saved_model_dir,
|
||||
const std::string& asset_filename,
|
||||
std::unique_ptr<Asset>* output) {
|
||||
std::string abs_path =
|
||||
io::JoinPath(saved_model_dir, kSavedModelAssetsDirectory, asset_filename);
|
||||
AbstractTensorPtr tensor(ctx->CreateStringScalar(abs_path));
|
||||
if (tensor.get() == nullptr) {
|
||||
return errors::Internal(
|
||||
"Failed to create scalar string tensor for Asset at path ", abs_path);
|
||||
}
|
||||
|
||||
ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get()));
|
||||
output->reset(new Asset(std::move(handle)));
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,50 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Asset : public TensorHandleConvertible {
|
||||
public:
|
||||
static Status Create(ImmediateExecutionContext* ctx,
|
||||
const std::string& saved_model_dir,
|
||||
const std::string& asset_filename,
|
||||
std::unique_ptr<Asset>* output);
|
||||
|
||||
// Asset is movable, but not copyable.
|
||||
Asset(Asset&& other) = default;
|
||||
Asset& operator=(Asset&& other) = default;
|
||||
|
||||
~Asset() override = default;
|
||||
|
||||
private:
|
||||
explicit Asset(ImmediateTensorHandlePtr handle);
|
||||
Asset(const Asset&) = delete;
|
||||
Asset& operator=(const Asset&) = delete;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_
|
@ -100,6 +100,20 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef(
|
||||
|
||||
} // namespace
|
||||
|
||||
Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset,
|
||||
const std::string& saved_model_dir,
|
||||
absl::Span<const AssetFileDef> assets,
|
||||
std::unique_ptr<Asset>* output) {
|
||||
int asset_index = asset.asset_file_def_index();
|
||||
if (asset_index >= assets.size()) {
|
||||
return errors::FailedPrecondition(
|
||||
"SavedAsset contained asset index ", asset_index,
|
||||
" but AssetFileDef only contains ", assets.size(), " # of assets");
|
||||
}
|
||||
const std::string& asset_filename = assets[asset_index].filename();
|
||||
return Asset::Create(ctx, saved_model_dir, asset_filename, output);
|
||||
}
|
||||
|
||||
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
|
||||
const TensorProto& proto,
|
||||
std::unique_ptr<Constant>* output) {
|
||||
|
@ -22,7 +22,9 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
@ -31,6 +33,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||
|
||||
@ -52,6 +55,11 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
|
||||
const SavedVariable& variable,
|
||||
std::unique_ptr<Variable>* output);
|
||||
|
||||
Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset,
|
||||
const std::string& saved_model_dir,
|
||||
absl::Span<const AssetFileDef> assets,
|
||||
std::unique_ptr<Asset>* output);
|
||||
|
||||
// Creates a TFConcreteFunction from a SavedConcreteFunction.
|
||||
Status LoadTFConcreteFunction(
|
||||
const SavedConcreteFunction& saved_concrete_function,
|
||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
|
||||
#include "tensorflow/cc/saved_model/bundle_v2.h"
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/cc/saved_model/loader_util.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
@ -108,12 +109,17 @@ Status ConstantFromSavedConstant(
|
||||
// SavedResources. These are returned via the `out` parameter.
|
||||
Status ReviveObjects(
|
||||
const MetaGraphDef& metagraph, ImmediateExecutionContext* context,
|
||||
const std::string& directory,
|
||||
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>*
|
||||
revived_objects) {
|
||||
// This is needed to restore "Constant" nodes by looking up their
|
||||
// "Value" attribute.
|
||||
NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def());
|
||||
|
||||
// These are needed for creating "Assets", by looking up their filenames.
|
||||
std::vector<AssetFileDef> assets;
|
||||
TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(metagraph, &assets));
|
||||
|
||||
// Iterate through all the saved objects, restoring objects as we go.
|
||||
// We don't recreate functions until all other objects have been created.
|
||||
for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
|
||||
@ -129,12 +135,10 @@ Status ReviveObjects(
|
||||
node_attr_map, &constant));
|
||||
(*revived_objects)[i] = std::move(constant);
|
||||
} else if (node.kind_case() == SavedObject::kAsset) {
|
||||
// TODO(bmzhao): Implement Asset C++ class. This should be just recreating
|
||||
// the full path to the asset file:
|
||||
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/saved_model/load.py#L395-L396
|
||||
// and storing it as a string tensor:
|
||||
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/training/tracking/tracking.py#L324-L325
|
||||
return errors::Unimplemented("SavedAsset loading is not implemented yet");
|
||||
std::unique_ptr<Asset> asset;
|
||||
TF_RETURN_IF_ERROR(internal::LoadSavedAsset(context, node.asset(),
|
||||
directory, assets, &asset));
|
||||
(*revived_objects)[i] = std::move(asset);
|
||||
} else if (node.kind_case() == SavedObject::kResource) {
|
||||
// TODO(bmzhao): Figure out how resource loading works and implement it
|
||||
return errors::Unimplemented(
|
||||
@ -352,8 +356,8 @@ Status TFSavedModelAPI::Load(
|
||||
// https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
|
||||
|
||||
RevivedObjectMap revived_objects;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReviveObjects(bundle.meta_graph_def(), context, &revived_objects));
|
||||
TF_RETURN_IF_ERROR(ReviveObjects(bundle.meta_graph_def(), context, directory,
|
||||
&revived_objects));
|
||||
|
||||
// TODO(bmzhao): When we later add support for loading resources, we need to
|
||||
// handle the case where materializing a function's captures requires invoking
|
||||
|
@ -27,9 +27,12 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using tensorflow::tstring;
|
||||
|
||||
constexpr char kTestData[] = "cc/saved_model/testdata";
|
||||
const char* kServeTag[] = {"serve"};
|
||||
|
||||
@ -137,6 +140,60 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TF_DeleteStatus(status);
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
std::string model_dir = SavedModelPath("AssetModule");
|
||||
|
||||
TF_SavedModel* saved_model =
|
||||
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
|
||||
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TF_ConcreteFunction* read_file_fn =
|
||||
TF_GetSavedModelConcreteFunction(saved_model, "read_file", status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* read_file_op =
|
||||
TF_ConcreteFunctionMakeCallOp(read_file_fn, nullptr, 0, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
|
||||
// inputs + outputs a function has.
|
||||
TFE_TensorHandle* read_file_fn_outputs[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
|
||||
TFE_Execute(read_file_op, &read_file_fn_outputs[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_Tensor* result = TFE_TensorHandleResolve(read_file_fn_outputs[0], status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
EXPECT_EQ(TF_NumDims(result), 0);
|
||||
tensorflow::tstring* output_value =
|
||||
static_cast<tensorflow::tstring*>(TF_TensorData(result));
|
||||
EXPECT_EQ(std::string(*output_value), "TEST ASSET FILE CONTENTS\n");
|
||||
|
||||
TF_DeleteTensor(result);
|
||||
TFE_DeleteTensorHandle(read_file_fn_outputs[0]);
|
||||
TFE_DeleteOp(read_file_op);
|
||||
TF_DeleteSavedModel(saved_model);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest,
|
||||
::testing::Bool());
|
||||
|
||||
|
@ -209,9 +209,13 @@ tf_cc_test(
|
||||
py_binary(
|
||||
name = "testdata/generate_saved_models",
|
||||
srcs = ["testdata/generate_saved_models.py"],
|
||||
data = [
|
||||
":saved_model_asset_data",
|
||||
],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
@ -221,6 +225,7 @@ py_binary(
|
||||
"//tensorflow/python/module",
|
||||
"//tensorflow/python/saved_model",
|
||||
"//tensorflow/python/saved_model:save_options",
|
||||
"//tensorflow/python/training/tracking",
|
||||
"@absl_py//absl:app",
|
||||
],
|
||||
)
|
||||
@ -229,6 +234,7 @@ py_binary(
|
||||
filegroup(
|
||||
name = "saved_model_test_files",
|
||||
srcs = glob([
|
||||
"testdata/AssetModule/**",
|
||||
"testdata/half_plus_two_pbtxt/**",
|
||||
"testdata/half_plus_two_main_op/**",
|
||||
"testdata/half_plus_two/**",
|
||||
@ -245,6 +251,13 @@ alias(
|
||||
actual = ":saved_model_test_files",
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "saved_model_asset_data",
|
||||
srcs = [
|
||||
"testdata/test_asset.txt",
|
||||
],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
glob([
|
||||
"testdata/half_plus_two_pbtxt/**",
|
||||
|
1
tensorflow/cc/saved_model/testdata/AssetModule/assets/test_asset.txt
vendored
Normal file
1
tensorflow/cc/saved_model/testdata/AssetModule/assets/test_asset.txt
vendored
Normal file
@ -0,0 +1 @@
|
||||
TEST ASSET FILE CONTENTS
|
BIN
tensorflow/cc/saved_model/testdata/AssetModule/saved_model.pb
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/AssetModule/saved_model.pb
vendored
Normal file
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/AssetModule/variables/variables.data-00000-of-00001
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/AssetModule/variables/variables.data-00000-of-00001
vendored
Normal file
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/AssetModule/variables/variables.index
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/AssetModule/variables/variables.index
vendored
Normal file
Binary file not shown.
@ -29,9 +29,12 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import save_options
|
||||
from tensorflow.python.saved_model import saved_model
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
|
||||
|
||||
class VarsAndArithmeticObjectGraph(module.Module):
|
||||
@ -68,9 +71,21 @@ class CyclicModule(module.Module):
|
||||
self.child = ReferencesParent(self)
|
||||
|
||||
|
||||
class AssetModule(module.Module):
|
||||
|
||||
def __init__(self):
|
||||
self.asset = tracking.Asset(
|
||||
test.test_src_dir_path("cc/saved_model/testdata/test_asset.txt"))
|
||||
|
||||
@def_function.function(input_signature=[])
|
||||
def read_file(self):
|
||||
return io_ops.read_file(self.asset)
|
||||
|
||||
|
||||
MODULE_CTORS = {
|
||||
"VarsAndArithmeticObjectGraph": VarsAndArithmeticObjectGraph,
|
||||
"CyclicModule": CyclicModule,
|
||||
"AssetModule": AssetModule,
|
||||
}
|
||||
|
||||
|
||||
|
1
tensorflow/cc/saved_model/testdata/test_asset.txt
vendored
Normal file
1
tensorflow/cc/saved_model/testdata/test_asset.txt
vendored
Normal file
@ -0,0 +1 @@
|
||||
TEST ASSET FILE CONTENTS
|
Loading…
Reference in New Issue
Block a user