Adding RevivedConstant class for Constant reloading in the SavedModelAPI C API.

PiperOrigin-RevId: 317920112
Change-Id: I2dc84de102c1edc5513df319e66ee20351bdb725
This commit is contained in:
Brian Zhao 2020-06-23 12:33:36 -07:00 committed by TensorFlower Gardener
parent 7d025c63c5
commit 95428e83f5
8 changed files with 508 additions and 0 deletions

View File

@ -3,6 +3,10 @@
# Targets in this directory are pure C++ "Classes" underlying the C API types
# under tf/c/experimental/saved_model/public/. They are subject to change and
# have visibility limited to Tensorflow's implementation only.
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
package(
default_visibility = [
@ -47,6 +51,22 @@ cc_library(
],
)
cc_library(
name = "saved_model_utils",
srcs = [
"saved_model_utils.cc",
],
hdrs = [
"saved_model_utils.h",
],
deps = [
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "tf_saved_model_impl",
srcs = [
@ -84,3 +104,26 @@ filegroup(
],
visibility = ["//tensorflow/core:__pkg__"],
)
tf_cc_test(
name = "saved_model_utils_test",
srcs = [
"saved_model_utils_test.cc",
],
deps = [
":saved_model_utils",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:core_cpu_lib",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
],
)

View File

@ -0,0 +1,39 @@
# This package contains classes corresponding to Revived SavedObjectGraph types
# used by SavedModel. See https://cs.opensource.google/tensorflow/tensorflow/+/c575e2ba93c442121d98d3f125d83fed1339924d:tensorflow/core/protobuf/saved_object_graph.proto;l=56-62
package(
default_visibility = [
# Restricting visibility for now
"//tensorflow/c/experimental/saved_model/core:__pkg__",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "constant",
srcs = [
"constant.cc",
],
hdrs = [
"constant.h",
],
deps = [
":tensorhandle_convertible",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:tensor_handle",
],
)
cc_library(
name = "tensorhandle_convertible",
hdrs = [
"tensorhandle_convertible.h",
],
deps = [
"//tensorflow/c/eager:immediate_execution_tensor_handle",
],
)

View File

@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include <memory>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
Constant::Constant(ImmediateTensorHandlePtr handle)
: TensorHandleConvertible(std::move(handle)) {}
Status Constant::Create(ImmediateExecutionContext* ctx,
AbstractTensorInterface* tensor,
std::unique_ptr<Constant>* output) {
ImmediateExecutionTensorHandle* handle = ctx->CreateLocalHandle(tensor);
if (handle == nullptr) {
return errors::Internal("Failed to convert tensor to tensorhandle");
}
output->reset(new Constant(ImmediateTensorHandlePtr(handle)));
return Status();
}
} // namespace tensorflow

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
#include <memory>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/tensor.pb.h"
namespace tensorflow {
// This class corresponds to python's tf.constant, which is effectively a
// TensorHandle explicitly initialized to some value.
// For now this doesn't do much beyond wrap Context's CreateLocalHandle method,
// and offer a subclass of TensorHandleConvertible. Note that similar to
// the python's eager mode logic, we bypass calling the "Const" op:
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/framework/constant_op.py#L301
class Constant : public TensorHandleConvertible {
public:
static Status Create(ImmediateExecutionContext* ctx,
AbstractTensorInterface* tensor,
std::unique_ptr<Constant>* output);
// RevivedConstant is movable, but not copyable.
Constant(Constant&& other) = default;
Constant& operator=(Constant&& other) = default;
~Constant() override = default;
private:
explicit Constant(ImmediateTensorHandlePtr handle);
Constant(const Constant&) = delete;
Constant& operator=(const Constant&) = delete;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
namespace tensorflow {
// A common interface for objects that can be converted to a TensorHandle.
// Examples of objects that implement this include Variables, Constants, Assets,
// etc. This is used to convert captured objects into a ConcreteFunction's
// captured TensorHandles:
// https://github.com/tensorflow/tensorflow/blob/676a68963ea4b64fe479b9cede06aa8f5b290ab8/tensorflow/python/saved_model/load.py#L229-L240
class TensorHandleConvertible {
public:
explicit TensorHandleConvertible(ImmediateTensorHandlePtr handle)
: handle_(std::move(handle)) {}
ImmediateExecutionTensorHandle* handle() { return handle_.get(); }
// TensorHandleConvertible is movable, but not copyable.
TensorHandleConvertible(TensorHandleConvertible&& other) = default;
TensorHandleConvertible& operator=(TensorHandleConvertible&& other) = default;
virtual ~TensorHandleConvertible() = default;
protected:
TensorHandleConvertible(const TensorHandleConvertible&) = delete;
TensorHandleConvertible& operator=(const TensorHandleConvertible&) = delete;
ImmediateTensorHandlePtr handle_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/tf_tensor_internal.h"
namespace tensorflow {
namespace internal {
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
const TensorProto& proto,
std::unique_ptr<Constant>* output) {
tensorflow::Tensor tensor;
bool parse_result = tensor.FromProto(proto);
if (!parse_result) {
return errors::Internal("Failed to parse tensor from tensorproto");
}
TensorInterface tensor_interface(std::move(tensor));
return Constant::Create(ctx, &tensor_interface, output);
}
} // namespace internal
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
// Some internal utility functions for the SavedModelAPI, factored out into a
// separately unit-testable header.
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/core/framework/tensor.pb.h"
namespace tensorflow {
namespace internal {
// Load a TensorProto into a tensorflow::Constant. This is similar to the
// constant loading logic in python:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L437
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
const TensorProto& proto,
std::unique_ptr<Constant>* output);
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_

View File

@ -0,0 +1,199 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include <string.h>
#include <memory>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
// Converts a tensorflow::DatatypeSet to std::vector<DataType>.
// This is needed for GTest's ::testing::ValuesIn, since
// DataTypeSet doesn't fullfill all the constraints of an STL-like iterable.
std::vector<DataType> DataTypeSetToVector(DataTypeSet set) {
std::vector<DataType> result;
result.reserve(set.size());
for (DataType dt : set) {
result.push_back(dt);
}
return result;
}
// Returns a vector of shapes intended to be "interesting" test cases.
std::vector<std::vector<int64>> InterestingShapes() {
std::vector<std::vector<int64>> interesting_shapes;
interesting_shapes.push_back({}); // Scalar
interesting_shapes.push_back({10}); // 1D Vector
interesting_shapes.push_back({3, 3}); // 2D Matrix
interesting_shapes.push_back({1, 4, 6, 10}); // Higher Dimension Tensor
return interesting_shapes;
}
// Fills a numeric tensor with `value`.
void FillNumericTensor(Tensor* tensor, int8 value) {
switch (tensor->dtype()) {
#define CASE(type) \
case DataTypeToEnum<type>::value: { \
const auto& flattened = tensor->flat<type>(); \
for (int i = 0; i < tensor->NumElements(); ++i) { \
flattened(i) = value; \
} \
break; \
}
TF_CALL_INTEGRAL_TYPES(CASE);
TF_CALL_double(CASE);
TF_CALL_float(CASE);
#undef CASE
default:
CHECK(false) << "Unsupported data type: "
<< DataTypeString(tensor->dtype());
break;
}
}
// Checks the underlying data is equal for the buffers for two numeric tensors.
// Note: The caller must ensure to check that the dtypes and sizes of the
// underlying buffers are the same before calling this.
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
void* b) {
switch (dtype) {
#define CASE(type) \
case DataTypeToEnum<type>::value: { \
type* typed_a = static_cast<type*>(a); \
type* typed_b = static_cast<type*>(b); \
for (int64 i = 0; i < num_elements; ++i) { \
if (DataTypeIsFloating(dtype)) { \
EXPECT_FLOAT_EQ(typed_a[i], typed_b[i]); \
} else { \
EXPECT_EQ(typed_a[i], typed_b[i]); \
} \
} \
break; \
}
TF_CALL_INTEGRAL_TYPES(CASE);
TF_CALL_double(CASE);
TF_CALL_float(CASE);
#undef CASE
default:
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
}
}
class ConstantTest : public ::testing::TestWithParam<
std::tuple<DataType, std::vector<int64>, bool>> {
public:
ConstantTest()
: device_mgr_(std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
"CPU", {}, "/job:localhost/replica:0/task:0"))),
ctx_(new EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
/* async= */ false,
/* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(),
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
/* custom_kernel_creator= */ nullptr,
/* cluster_flr= */ nullptr)) {}
EagerContext* context() { return ctx_.get(); }
private:
std::unique_ptr<StaticDeviceMgr> device_mgr_;
EagerContextPtr ctx_;
};
// Basic sanity check that roundtripping a Tensor->Tensorproto->Constant
// preserves values.
TEST_P(ConstantTest, CreateConstantSuccessful) {
// Get test parameters
auto& test_params = GetParam();
DataType dtype = std::get<0>(test_params);
TensorShape shape(std::get<1>(test_params));
bool tensorproto_use_tensor_content = std::get<2>(test_params);
// Construct a Tensor with the given dtype + shape
Tensor expected(dtype, shape);
FillNumericTensor(&expected, 42);
// Serialize it to a Tensorproto
TensorProto proto;
if (tensorproto_use_tensor_content) {
expected.AsProtoTensorContent(&proto);
} else {
expected.AsProtoField(&proto);
}
// Revival should succeed w/o errors
std::unique_ptr<Constant> revived;
TF_EXPECT_OK(internal::TensorProtoToConstant(context(), proto, &revived));
// The revived tensorhandle should have the exact same dtype, shape, +
// approx equivalent data to the original.
ImmediateExecutionTensorHandle* handle = revived->handle();
Status status;
AbstractTensorPtr revived_tensor(handle->Resolve(&status));
TF_EXPECT_OK(status) << "Failed to convert tensorhandle to tensor";
EXPECT_EQ(revived_tensor->Type(), expected.dtype());
EXPECT_EQ(revived_tensor->NumElements(), expected.NumElements());
EXPECT_EQ(revived_tensor->NumDims(), expected.dims());
for (int i = 0; i < expected.dims(); ++i) {
EXPECT_EQ(revived_tensor->Dim(i), expected.dim_size(i));
}
CheckBufferDataIsEqual(expected.dtype(), expected.NumElements(),
revived_tensor->Data(), expected.data());
}
// Test against combinations of tensors that are
// 1. Varying dtypes
// 2. Varying shapes
// 3. TensorProto serialized using tensor_content vs repeated type
INSTANTIATE_TEST_SUITE_P(
ConstantIntegerDtypesTest, ConstantTest,
::testing::Combine(
::testing::ValuesIn(DataTypeSetToVector(kDataTypeIsInteger)),
::testing::ValuesIn(InterestingShapes()),
::testing::Values(false, true)));
INSTANTIATE_TEST_SUITE_P(
ConstantFloatingDtypesTest, ConstantTest,
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
::testing::ValuesIn(InterestingShapes()),
::testing::Values(false, true)));
} // namespace
} // namespace tensorflow