Adding convenience functions for calling restore ops. These are needed for implementing the SavedModel C API.

PiperOrigin-RevId: 320114439
Change-Id: I707e9e228e509ccf37202fa7862c92e7ee78f46d
This commit is contained in:
Brian Zhao 2020-07-07 20:30:42 -07:00 committed by TensorFlower Gardener
parent 4f34a259d2
commit 087d3651ba
6 changed files with 323 additions and 0 deletions

View File

@ -14,6 +14,27 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "restore_ops",
srcs = [
"restore_ops.cc",
],
hdrs = [
"restore_ops.h",
],
deps = [
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
],
)
cc_library(
name = "variable_ops",
srcs = [
@ -36,6 +57,34 @@ cc_library(
],
)
tf_cc_test(
name = "restore_ops_test",
srcs = [
"restore_ops_test.cc",
],
data = [
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
],
deps = [
":restore_ops",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core:test_utils",
"//tensorflow/cc/saved_model:constants",
"//tensorflow/core:all_kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:core_cpu_lib",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
],
)
tf_cc_test(
name = "variable_ops_test",
srcs = [

View File

@ -0,0 +1,111 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace internal {
namespace {
// Creates a scalar string tensorhandle containing a single string `s`
Status CreateStringScalarTensorHandle(ImmediateExecutionContext* ctx,
const std::string& s,
ImmediateTensorHandlePtr* out) {
AbstractTensorPtr tensor(ctx->CreateStringScalar(s));
if (tensor.get() == nullptr) {
return errors::Internal(
"Failed to create scalar string tensor for checkpoint restore");
}
out->reset(ctx->CreateLocalHandle(tensor.get()));
return Status();
}
// Creates a Rank 1 string tensorhandle containing a single string `s`
Status CreateStringVectorTensorHandle(ImmediateExecutionContext* ctx,
const std::string& s,
ImmediateTensorHandlePtr* out) {
int64 flat_shape[] = {1};
AbstractTensorPtr tensor(ctx->CreateTensor(DT_STRING, flat_shape));
if (tensor.get() == nullptr) {
return errors::Internal(
"Failed to create vector string tensor for checkpoint restore");
}
// Use placement new to construct the string, since we don't have
// access to Tensor::flat. This is conceptually equivalent to:
// tensor.flat<tstring>()(0) = s
new (tensor->Data()) tstring(s);
out->reset(ctx->CreateLocalHandle(tensor.get()));
return Status();
}
} // namespace
Status SingleRestore(ImmediateExecutionContext* ctx, const std::string& prefix,
const std::string& checkpoint_key, DataType dtype,
ImmediateTensorHandlePtr* out) {
// Create the EagerOp
ImmediateOpPtr restore_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(restore_op->Reset("RestoreV2", "/cpu:0"));
TF_RETURN_IF_ERROR(restore_op->SetAttrTypeList("dtypes", &dtype, 1));
ImmediateTensorHandlePtr prefix_handle;
TF_RETURN_IF_ERROR(
CreateStringScalarTensorHandle(ctx, prefix, &prefix_handle));
ImmediateTensorHandlePtr names_handle;
TF_RETURN_IF_ERROR(
CreateStringVectorTensorHandle(ctx, checkpoint_key, &names_handle));
// Note that empty string is the slice spec used for a non-partitioned
// ResourceVariable:
// https://github.com/tensorflow/tensorflow/blob/06ff30f7ea35098cb68a231a9eb7ff3ff4be4e1e/tensorflow/python/training/saving/saveable_object_util.py#L194
ImmediateTensorHandlePtr shapes_and_slices_handle;
TF_RETURN_IF_ERROR(
CreateStringVectorTensorHandle(ctx, "", &shapes_and_slices_handle));
TF_RETURN_IF_ERROR(restore_op->AddInput(prefix_handle.get()));
TF_RETURN_IF_ERROR(restore_op->AddInput(names_handle.get()));
TF_RETURN_IF_ERROR(restore_op->AddInput(shapes_and_slices_handle.get()));
AbstractTensorHandle* restored_handle = nullptr;
int num_retvals = 1;
TF_RETURN_IF_ERROR(restore_op->Execute(
absl::MakeSpan(&restored_handle, num_retvals), &num_retvals));
AbstractTensorHandlePtr owned_restored_handle(restored_handle);
if (!tensorflow::isa<ImmediateExecutionTensorHandle>(
owned_restored_handle.get())) {
return errors::Internal("Unexpected tensor handle kind.");
}
out->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(
owned_restored_handle.release()));
return Status();
}
} // namespace internal
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_
#include <string>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace internal {
// TODO(bmzhao): Add a function to restore multiple tensors in one call.
// Restores a single non-partioned tensorhandle of dtype `dtype`, using
// checkpoint at `prefix`, with a value stored in `checkpoint_key`.
Status SingleRestore(ImmediateExecutionContext* ctx, const std::string& prefix,
const std::string& checkpoint_key, DataType dtype,
ImmediateTensorHandlePtr* out);
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_

View File

@ -0,0 +1,111 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
std::string CheckpointPrefix(StringPiece saved_model_dir) {
return io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata",
saved_model_dir, kSavedModelVariablesDirectory,
kSavedModelVariablesFilename);
}
class RestoreOpsTest : public ::testing::Test {
public:
RestoreOpsTest()
: device_mgr_(testing::CreateTestingDeviceMgr()),
ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
EagerContext* context() { return ctx_.get(); }
private:
std::unique_ptr<StaticDeviceMgr> device_mgr_;
EagerContextPtr ctx_;
};
// One way of obtaining the checkpointa checkpoint's tensor names is:
// bazel run //tensorflow/python/tools:inspect_checkpoint -- --all_tensors
// --file_name="$CKPT_PREFIX".
// Here are the values for VarsAndArithmeticObjectGraph:
// tensor: child/z/.ATTRIBUTES/VARIABLE_VALUE (float32) []
// 3.0
// tensor: x/.ATTRIBUTES/VARIABLE_VALUE (float32) []
// 1.0
// tensor: y/.ATTRIBUTES/VARIABLE_VALUE (float32) []
// 2.0
TEST_F(RestoreOpsTest, RestoreSuccessful) {
ImmediateTensorHandlePtr x_handle;
TF_EXPECT_OK(internal::SingleRestore(
context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
"x/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &x_handle));
AbstractTensorPtr x = testing::TensorHandleToTensor(x_handle.get());
EXPECT_EQ(x->Type(), DT_FLOAT);
EXPECT_EQ(x->NumElements(), 1);
EXPECT_EQ(x->NumDims(), 0);
EXPECT_FLOAT_EQ(*reinterpret_cast<float*>(x->Data()), 1.0f);
ImmediateTensorHandlePtr y_handle;
TF_EXPECT_OK(internal::SingleRestore(
context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
"y/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &y_handle));
AbstractTensorPtr y = testing::TensorHandleToTensor(y_handle.get());
EXPECT_EQ(y->Type(), DT_FLOAT);
EXPECT_EQ(y->NumElements(), 1);
EXPECT_EQ(y->NumDims(), 0);
EXPECT_FLOAT_EQ(*reinterpret_cast<float*>(y->Data()), 2.0f);
ImmediateTensorHandlePtr z_handle;
TF_EXPECT_OK(internal::SingleRestore(
context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
"child/z/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &z_handle));
AbstractTensorPtr z = testing::TensorHandleToTensor(z_handle.get());
EXPECT_EQ(z->Type(), DT_FLOAT);
EXPECT_EQ(z->NumElements(), 1);
EXPECT_EQ(z->NumDims(), 0);
EXPECT_FLOAT_EQ(*reinterpret_cast<float*>(z->Data()), 3.0f);
}
TEST_F(RestoreOpsTest, BadCheckpointPrefixShouldFail) {
ImmediateTensorHandlePtr x_handle;
Status status = internal::SingleRestore(
context(), CheckpointPrefix("unknown_bad_checkpoint_prefix"),
"x/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &x_handle);
EXPECT_FALSE(status.ok()) << status.error_message();
}
TEST_F(RestoreOpsTest, BadCheckpointKeyShouldFail) {
ImmediateTensorHandlePtr x_handle;
Status status = internal::SingleRestore(
context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
"bad_checkpoint_key", DT_FLOAT, &x_handle);
EXPECT_FALSE(status.ok()) << status.error_message();
}
} // namespace
} // namespace tensorflow

View File

@ -139,5 +139,13 @@ void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
}
}
AbstractTensorPtr TensorHandleToTensor(ImmediateExecutionTensorHandle* handle) {
Status status;
AbstractTensorPtr tensor(handle->Resolve(&status));
CHECK(status.ok()) << status.error_message();
CHECK_NE(tensor.get(), nullptr);
return tensor;
}
} // namespace testing
} // namespace tensorflow

View File

@ -69,6 +69,10 @@ void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
void* b);
// Converts a TensorHandle to a Tensor, and dies if unsuccessful. This should
// only be used for testing purposes.
AbstractTensorPtr TensorHandleToTensor(ImmediateExecutionTensorHandle* handle);
} // namespace testing
} // namespace tensorflow