From 087d3651ba775847ee76d139f2e8adc416f969ad Mon Sep 17 00:00:00 2001 From: Brian Zhao Date: Tue, 7 Jul 2020 20:30:42 -0700 Subject: [PATCH] Adding convenience functions for calling restore ops. These are needed for implementing the SavedModel C API. PiperOrigin-RevId: 320114439 Change-Id: I707e9e228e509ccf37202fa7862c92e7ee78f46d --- .../c/experimental/saved_model/core/ops/BUILD | 49 ++++++++ .../saved_model/core/ops/restore_ops.cc | 111 ++++++++++++++++++ .../saved_model/core/ops/restore_ops.h | 40 +++++++ .../saved_model/core/ops/restore_ops_test.cc | 111 ++++++++++++++++++ .../saved_model/core/test_utils.cc | 8 ++ .../saved_model/core/test_utils.h | 4 + 6 files changed, 323 insertions(+) create mode 100644 tensorflow/c/experimental/saved_model/core/ops/restore_ops.cc create mode 100644 tensorflow/c/experimental/saved_model/core/ops/restore_ops.h create mode 100644 tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc diff --git a/tensorflow/c/experimental/saved_model/core/ops/BUILD b/tensorflow/c/experimental/saved_model/core/ops/BUILD index 34439699522..673ea1a80e2 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/BUILD +++ b/tensorflow/c/experimental/saved_model/core/ops/BUILD @@ -14,6 +14,27 @@ package( licenses = ["notice"], # Apache 2.0 ) +cc_library( + name = "restore_ops", + srcs = [ + "restore_ops.cc", + ], + hdrs = [ + "restore_ops.h", + ], + deps = [ + "//tensorflow/c:tensor_interface", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/lib/llvm_rtti", + ], +) + cc_library( name = "variable_ops", srcs = [ @@ -36,6 +57,34 @@ cc_library( ], ) +tf_cc_test( + name = "restore_ops_test", + srcs = [ + "restore_ops_test.cc", + ], + data = [ + 
"//tensorflow/cc/saved_model:saved_model_half_plus_two", + ], + deps = [ + ":restore_ops", + "//tensorflow/c:tensor_interface", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core:test_utils", + "//tensorflow/cc/saved_model:constants", + "//tensorflow/core:all_kernels", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime:core_cpu_lib", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:core", + ], +) + tf_cc_test( name = "variable_ops_test", srcs = [ diff --git a/tensorflow/c/experimental/saved_model/core/ops/restore_ops.cc b/tensorflow/c/experimental/saved_model/core/ops/restore_ops.cc new file mode 100644 index 00000000000..6609ecee508 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/ops/restore_ops.cc @@ -0,0 +1,111 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h" + +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace internal { + +namespace { + +// Creates a scalar string tensorhandle containing a single string `s` +Status CreateStringScalarTensorHandle(ImmediateExecutionContext* ctx, + const std::string& s, + ImmediateTensorHandlePtr* out) { + AbstractTensorPtr tensor(ctx->CreateStringScalar(s)); + if (tensor.get() == nullptr) { + return errors::Internal( + "Failed to create scalar string tensor for checkpoint restore"); + } + + out->reset(ctx->CreateLocalHandle(tensor.get())); + return Status(); +} + +// Creates a Rank 1 string tensorhandle containing a single string `s` +Status CreateStringVectorTensorHandle(ImmediateExecutionContext* ctx, + const std::string& s, + ImmediateTensorHandlePtr* out) { + int64 flat_shape[] = {1}; + AbstractTensorPtr tensor(ctx->CreateTensor(DT_STRING, flat_shape)); + if (tensor.get() == nullptr) { + return errors::Internal( + "Failed to create vector string tensor for checkpoint restore"); + } + // Use placement new to construct the string, since we don't have + // access to Tensor::flat. 
This is conceptually equivalent to: + // tensor.flat()(0) = s + new (tensor->Data()) tstring(s); + + out->reset(ctx->CreateLocalHandle(tensor.get())); + return Status(); +} + +} // namespace + +Status SingleRestore(ImmediateExecutionContext* ctx, const std::string& prefix, + const std::string& checkpoint_key, DataType dtype, + ImmediateTensorHandlePtr* out) { + // Create the EagerOp + ImmediateOpPtr restore_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(restore_op->Reset("RestoreV2", "/cpu:0")); + TF_RETURN_IF_ERROR(restore_op->SetAttrTypeList("dtypes", &dtype, 1)); + + ImmediateTensorHandlePtr prefix_handle; + TF_RETURN_IF_ERROR( + CreateStringScalarTensorHandle(ctx, prefix, &prefix_handle)); + + ImmediateTensorHandlePtr names_handle; + TF_RETURN_IF_ERROR( + CreateStringVectorTensorHandle(ctx, checkpoint_key, &names_handle)); + + // Note that empty string is the slice spec used for a non-partitioned + // ResourceVariable: + // https://github.com/tensorflow/tensorflow/blob/06ff30f7ea35098cb68a231a9eb7ff3ff4be4e1e/tensorflow/python/training/saving/saveable_object_util.py#L194 + ImmediateTensorHandlePtr shapes_and_slices_handle; + TF_RETURN_IF_ERROR( + CreateStringVectorTensorHandle(ctx, "", &shapes_and_slices_handle)); + + TF_RETURN_IF_ERROR(restore_op->AddInput(prefix_handle.get())); + TF_RETURN_IF_ERROR(restore_op->AddInput(names_handle.get())); + TF_RETURN_IF_ERROR(restore_op->AddInput(shapes_and_slices_handle.get())); + + AbstractTensorHandle* restored_handle = nullptr; + int num_retvals = 1; + TF_RETURN_IF_ERROR(restore_op->Execute( + absl::MakeSpan(&restored_handle, num_retvals), &num_retvals)); + AbstractTensorHandlePtr owned_restored_handle(restored_handle); + if (!tensorflow::isa( + owned_restored_handle.get())) { + return errors::Internal("Unexpected tensor handle kind."); + } + out->reset(reinterpret_cast( + owned_restored_handle.release())); + return Status(); +} + +} // namespace internal +} // namespace tensorflow diff --git 
a/tensorflow/c/experimental/saved_model/core/ops/restore_ops.h b/tensorflow/c/experimental/saved_model/core/ops/restore_ops.h new file mode 100644 index 00000000000..f215bc9e7ab --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/ops/restore_ops.h @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_ + +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace internal { + +// TODO(bmzhao): Add a function to restore multiple tensors in one call. + +// Restores a single non-partitioned tensorhandle of dtype `dtype`, using +// checkpoint at `prefix`, with a value stored in `checkpoint_key`. 
+Status SingleRestore(ImmediateExecutionContext* ctx, const std::string& prefix, + const std::string& checkpoint_key, DataType dtype, + ImmediateTensorHandlePtr* out); + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_ diff --git a/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc new file mode 100644 index 00000000000..52a652a90ef --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc @@ -0,0 +1,111 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h" + +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/test_utils.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +std::string CheckpointPrefix(StringPiece saved_model_dir) { + return io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", + saved_model_dir, kSavedModelVariablesDirectory, + kSavedModelVariablesFilename); +} + +class RestoreOpsTest : public ::testing::Test { + public: + RestoreOpsTest() + : device_mgr_(testing::CreateTestingDeviceMgr()), + ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {} + + EagerContext* context() { return ctx_.get(); } + + private: + std::unique_ptr device_mgr_; + EagerContextPtr ctx_; +}; + +// One way of obtaining a checkpoint's tensor names is: +// bazel run //tensorflow/python/tools:inspect_checkpoint -- --all_tensors +// --file_name="$CKPT_PREFIX". 
+// Here are the values for VarsAndArithmeticObjectGraph: +// tensor: child/z/.ATTRIBUTES/VARIABLE_VALUE (float32) [] +// 3.0 +// tensor: x/.ATTRIBUTES/VARIABLE_VALUE (float32) [] +// 1.0 +// tensor: y/.ATTRIBUTES/VARIABLE_VALUE (float32) [] +// 2.0 + +TEST_F(RestoreOpsTest, RestoreSuccessful) { + ImmediateTensorHandlePtr x_handle; + TF_EXPECT_OK(internal::SingleRestore( + context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"), + "x/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &x_handle)); + AbstractTensorPtr x = testing::TensorHandleToTensor(x_handle.get()); + EXPECT_EQ(x->Type(), DT_FLOAT); + EXPECT_EQ(x->NumElements(), 1); + EXPECT_EQ(x->NumDims(), 0); + EXPECT_FLOAT_EQ(*reinterpret_cast(x->Data()), 1.0f); + + ImmediateTensorHandlePtr y_handle; + TF_EXPECT_OK(internal::SingleRestore( + context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"), + "y/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &y_handle)); + AbstractTensorPtr y = testing::TensorHandleToTensor(y_handle.get()); + EXPECT_EQ(y->Type(), DT_FLOAT); + EXPECT_EQ(y->NumElements(), 1); + EXPECT_EQ(y->NumDims(), 0); + EXPECT_FLOAT_EQ(*reinterpret_cast(y->Data()), 2.0f); + + ImmediateTensorHandlePtr z_handle; + TF_EXPECT_OK(internal::SingleRestore( + context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"), + "child/z/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &z_handle)); + AbstractTensorPtr z = testing::TensorHandleToTensor(z_handle.get()); + EXPECT_EQ(z->Type(), DT_FLOAT); + EXPECT_EQ(z->NumElements(), 1); + EXPECT_EQ(z->NumDims(), 0); + EXPECT_FLOAT_EQ(*reinterpret_cast(z->Data()), 3.0f); +} + +TEST_F(RestoreOpsTest, BadCheckpointPrefixShouldFail) { + ImmediateTensorHandlePtr x_handle; + Status status = internal::SingleRestore( + context(), CheckpointPrefix("unknown_bad_checkpoint_prefix"), + "x/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &x_handle); + EXPECT_FALSE(status.ok()) << status.error_message(); +} + +TEST_F(RestoreOpsTest, BadCheckpointKeyShouldFail) { + ImmediateTensorHandlePtr x_handle; + Status status 
= internal::SingleRestore( + context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"), + "bad_checkpoint_key", DT_FLOAT, &x_handle); + EXPECT_FALSE(status.ok()) << status.error_message(); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/test_utils.cc b/tensorflow/c/experimental/saved_model/core/test_utils.cc index 920b7dd0139..b803d129b90 100644 --- a/tensorflow/c/experimental/saved_model/core/test_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/test_utils.cc @@ -139,5 +139,13 @@ void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a, } } +AbstractTensorPtr TensorHandleToTensor(ImmediateExecutionTensorHandle* handle) { + Status status; + AbstractTensorPtr tensor(handle->Resolve(&status)); + CHECK(status.ok()) << status.error_message(); + CHECK_NE(tensor.get(), nullptr); + return tensor; +} + } // namespace testing } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/test_utils.h b/tensorflow/c/experimental/saved_model/core/test_utils.h index fe80a660649..bdc1ca762ee 100644 --- a/tensorflow/c/experimental/saved_model/core/test_utils.h +++ b/tensorflow/c/experimental/saved_model/core/test_utils.h @@ -69,6 +69,10 @@ void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer, void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a, void* b); +// Converts a TensorHandle to a Tensor, and dies if unsuccessful. This should +// only be used for testing purposes. +AbstractTensorPtr TensorHandleToTensor(ImmediateExecutionTensorHandle* handle); + } // namespace testing } // namespace tensorflow