Adding convenience functions for calling restore ops. These are needed for implementing the SavedModel C API.
PiperOrigin-RevId: 320114439 Change-Id: I707e9e228e509ccf37202fa7862c92e7ee78f46d
This commit is contained in:
parent
4f34a259d2
commit
087d3651ba
@ -14,6 +14,27 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "restore_ops",
|
||||
srcs = [
|
||||
"restore_ops.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"restore_ops.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "variable_ops",
|
||||
srcs = [
|
||||
@ -36,6 +57,34 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "restore_ops_test",
|
||||
srcs = [
|
||||
"restore_ops_test.cc",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
],
|
||||
deps = [
|
||||
":restore_ops",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core:test_utils",
|
||||
"//tensorflow/cc/saved_model:constants",
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime:core_cpu_lib",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:core",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "variable_ops_test",
|
||||
srcs = [
|
||||
|
111
tensorflow/c/experimental/saved_model/core/ops/restore_ops.cc
Normal file
111
tensorflow/c/experimental/saved_model/core/ops/restore_ops.cc
Normal file
@ -0,0 +1,111 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
|
||||
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
namespace {
|
||||
|
||||
// Creates a scalar string tensorhandle containing a single string `s`
|
||||
Status CreateStringScalarTensorHandle(ImmediateExecutionContext* ctx,
|
||||
const std::string& s,
|
||||
ImmediateTensorHandlePtr* out) {
|
||||
AbstractTensorPtr tensor(ctx->CreateStringScalar(s));
|
||||
if (tensor.get() == nullptr) {
|
||||
return errors::Internal(
|
||||
"Failed to create scalar string tensor for checkpoint restore");
|
||||
}
|
||||
|
||||
out->reset(ctx->CreateLocalHandle(tensor.get()));
|
||||
return Status();
|
||||
}
|
||||
|
||||
// Creates a Rank 1 string tensorhandle containing a single string `s`
|
||||
Status CreateStringVectorTensorHandle(ImmediateExecutionContext* ctx,
|
||||
const std::string& s,
|
||||
ImmediateTensorHandlePtr* out) {
|
||||
int64 flat_shape[] = {1};
|
||||
AbstractTensorPtr tensor(ctx->CreateTensor(DT_STRING, flat_shape));
|
||||
if (tensor.get() == nullptr) {
|
||||
return errors::Internal(
|
||||
"Failed to create vector string tensor for checkpoint restore");
|
||||
}
|
||||
// Use placement new to construct the string, since we don't have
|
||||
// access to Tensor::flat. This is conceptually equivalent to:
|
||||
// tensor.flat<tstring>()(0) = s
|
||||
new (tensor->Data()) tstring(s);
|
||||
|
||||
out->reset(ctx->CreateLocalHandle(tensor.get()));
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status SingleRestore(ImmediateExecutionContext* ctx, const std::string& prefix,
|
||||
const std::string& checkpoint_key, DataType dtype,
|
||||
ImmediateTensorHandlePtr* out) {
|
||||
// Create the EagerOp
|
||||
ImmediateOpPtr restore_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(restore_op->Reset("RestoreV2", "/cpu:0"));
|
||||
TF_RETURN_IF_ERROR(restore_op->SetAttrTypeList("dtypes", &dtype, 1));
|
||||
|
||||
ImmediateTensorHandlePtr prefix_handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateStringScalarTensorHandle(ctx, prefix, &prefix_handle));
|
||||
|
||||
ImmediateTensorHandlePtr names_handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateStringVectorTensorHandle(ctx, checkpoint_key, &names_handle));
|
||||
|
||||
// Note that empty string is the slice spec used for a non-partitioned
|
||||
// ResourceVariable:
|
||||
// https://github.com/tensorflow/tensorflow/blob/06ff30f7ea35098cb68a231a9eb7ff3ff4be4e1e/tensorflow/python/training/saving/saveable_object_util.py#L194
|
||||
ImmediateTensorHandlePtr shapes_and_slices_handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateStringVectorTensorHandle(ctx, "", &shapes_and_slices_handle));
|
||||
|
||||
TF_RETURN_IF_ERROR(restore_op->AddInput(prefix_handle.get()));
|
||||
TF_RETURN_IF_ERROR(restore_op->AddInput(names_handle.get()));
|
||||
TF_RETURN_IF_ERROR(restore_op->AddInput(shapes_and_slices_handle.get()));
|
||||
|
||||
AbstractTensorHandle* restored_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(restore_op->Execute(
|
||||
absl::MakeSpan(&restored_handle, num_retvals), &num_retvals));
|
||||
AbstractTensorHandlePtr owned_restored_handle(restored_handle);
|
||||
if (!tensorflow::isa<ImmediateExecutionTensorHandle>(
|
||||
owned_restored_handle.get())) {
|
||||
return errors::Internal("Unexpected tensor handle kind.");
|
||||
}
|
||||
out->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(
|
||||
owned_restored_handle.release()));
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
40
tensorflow/c/experimental/saved_model/core/ops/restore_ops.h
Normal file
40
tensorflow/c/experimental/saved_model/core/ops/restore_ops.h
Normal file
@ -0,0 +1,40 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
// TODO(bmzhao): Add a function to restore multiple tensors in one call.
|
||||
|
||||
// Restores a single non-partioned tensorhandle of dtype `dtype`, using
|
||||
// checkpoint at `prefix`, with a value stored in `checkpoint_key`.
|
||||
Status SingleRestore(ImmediateExecutionContext* ctx, const std::string& prefix,
|
||||
const std::string& checkpoint_key, DataType dtype,
|
||||
ImmediateTensorHandlePtr* out);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_
|
@ -0,0 +1,111 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
|
||||
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
std::string CheckpointPrefix(StringPiece saved_model_dir) {
|
||||
return io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata",
|
||||
saved_model_dir, kSavedModelVariablesDirectory,
|
||||
kSavedModelVariablesFilename);
|
||||
}
|
||||
|
||||
class RestoreOpsTest : public ::testing::Test {
|
||||
public:
|
||||
RestoreOpsTest()
|
||||
: device_mgr_(testing::CreateTestingDeviceMgr()),
|
||||
ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
|
||||
|
||||
EagerContext* context() { return ctx_.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<StaticDeviceMgr> device_mgr_;
|
||||
EagerContextPtr ctx_;
|
||||
};
|
||||
|
||||
// One way of obtaining the checkpointa checkpoint's tensor names is:
|
||||
// bazel run //tensorflow/python/tools:inspect_checkpoint -- --all_tensors
|
||||
// --file_name="$CKPT_PREFIX".
|
||||
// Here are the values for VarsAndArithmeticObjectGraph:
|
||||
// tensor: child/z/.ATTRIBUTES/VARIABLE_VALUE (float32) []
|
||||
// 3.0
|
||||
// tensor: x/.ATTRIBUTES/VARIABLE_VALUE (float32) []
|
||||
// 1.0
|
||||
// tensor: y/.ATTRIBUTES/VARIABLE_VALUE (float32) []
|
||||
// 2.0
|
||||
|
||||
TEST_F(RestoreOpsTest, RestoreSuccessful) {
|
||||
ImmediateTensorHandlePtr x_handle;
|
||||
TF_EXPECT_OK(internal::SingleRestore(
|
||||
context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
|
||||
"x/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &x_handle));
|
||||
AbstractTensorPtr x = testing::TensorHandleToTensor(x_handle.get());
|
||||
EXPECT_EQ(x->Type(), DT_FLOAT);
|
||||
EXPECT_EQ(x->NumElements(), 1);
|
||||
EXPECT_EQ(x->NumDims(), 0);
|
||||
EXPECT_FLOAT_EQ(*reinterpret_cast<float*>(x->Data()), 1.0f);
|
||||
|
||||
ImmediateTensorHandlePtr y_handle;
|
||||
TF_EXPECT_OK(internal::SingleRestore(
|
||||
context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
|
||||
"y/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &y_handle));
|
||||
AbstractTensorPtr y = testing::TensorHandleToTensor(y_handle.get());
|
||||
EXPECT_EQ(y->Type(), DT_FLOAT);
|
||||
EXPECT_EQ(y->NumElements(), 1);
|
||||
EXPECT_EQ(y->NumDims(), 0);
|
||||
EXPECT_FLOAT_EQ(*reinterpret_cast<float*>(y->Data()), 2.0f);
|
||||
|
||||
ImmediateTensorHandlePtr z_handle;
|
||||
TF_EXPECT_OK(internal::SingleRestore(
|
||||
context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
|
||||
"child/z/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &z_handle));
|
||||
AbstractTensorPtr z = testing::TensorHandleToTensor(z_handle.get());
|
||||
EXPECT_EQ(z->Type(), DT_FLOAT);
|
||||
EXPECT_EQ(z->NumElements(), 1);
|
||||
EXPECT_EQ(z->NumDims(), 0);
|
||||
EXPECT_FLOAT_EQ(*reinterpret_cast<float*>(z->Data()), 3.0f);
|
||||
}
|
||||
|
||||
TEST_F(RestoreOpsTest, BadCheckpointPrefixShouldFail) {
|
||||
ImmediateTensorHandlePtr x_handle;
|
||||
Status status = internal::SingleRestore(
|
||||
context(), CheckpointPrefix("unknown_bad_checkpoint_prefix"),
|
||||
"x/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &x_handle);
|
||||
EXPECT_FALSE(status.ok()) << status.error_message();
|
||||
}
|
||||
|
||||
TEST_F(RestoreOpsTest, BadCheckpointKeyShouldFail) {
|
||||
ImmediateTensorHandlePtr x_handle;
|
||||
Status status = internal::SingleRestore(
|
||||
context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
|
||||
"bad_checkpoint_key", DT_FLOAT, &x_handle);
|
||||
EXPECT_FALSE(status.ok()) << status.error_message();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -139,5 +139,13 @@ void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
|
||||
}
|
||||
}
|
||||
|
||||
AbstractTensorPtr TensorHandleToTensor(ImmediateExecutionTensorHandle* handle) {
|
||||
Status status;
|
||||
AbstractTensorPtr tensor(handle->Resolve(&status));
|
||||
CHECK(status.ok()) << status.error_message();
|
||||
CHECK_NE(tensor.get(), nullptr);
|
||||
return tensor;
|
||||
}
|
||||
|
||||
} // namespace testing
|
||||
} // namespace tensorflow
|
||||
|
@ -69,6 +69,10 @@ void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
|
||||
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
|
||||
void* b);
|
||||
|
||||
// Converts a TensorHandle to a Tensor, and dies if unsuccessful. This should
|
||||
// only be used for testing purposes.
|
||||
AbstractTensorPtr TensorHandleToTensor(ImmediateExecutionTensorHandle* handle);
|
||||
|
||||
} // namespace testing
|
||||
} // namespace tensorflow
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user