Adding Variable class for Variable Reloading in the SavedModel C API.

PiperOrigin-RevId: 318199820
Change-Id: I901124780f8687d0f572cff4546f8792f1120e47
This commit is contained in:
Brian Zhao 2020-06-24 21:04:07 -07:00 committed by TensorFlower Gardener
parent 772b836fdb
commit 86fc04ef9b
11 changed files with 705 additions and 204 deletions

View File

@ -63,10 +63,34 @@ cc_library(
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "test_utils",
testonly = True,
srcs = [
"test_utils.cc",
],
hdrs = [
"test_utils.h",
],
deps = [
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core/common_runtime:core_cpu_lib",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "tf_saved_model_impl",
srcs = [
@ -106,15 +130,35 @@ filegroup(
)
tf_cc_test(
name = "saved_model_utils_test",
name = "constant_loading_test",
srcs = [
"saved_model_utils_test.cc",
"constant_loading_test.cc",
],
deps = [
":saved_model_utils",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
":test_utils",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:core_cpu_lib",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
],
)
tf_cc_test(
name = "saved_variable_loading_test",
srcs = [
"saved_variable_loading_test.cc",
],
deps = [
":saved_model_utils",
":test_utils",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/core:framework",

View File

@ -0,0 +1,111 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class ConstantTest : public ::testing::TestWithParam<
std::tuple<DataType, std::vector<int64>, bool>> {
public:
ConstantTest()
: device_mgr_(testing::CreateTestingDeviceMgr()),
ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
EagerContext* context() { return ctx_.get(); }
private:
std::unique_ptr<StaticDeviceMgr> device_mgr_;
EagerContextPtr ctx_;
};
// Basic sanity check that roundtripping a Tensor->Tensorproto->Constant
// preserves values.
TEST_P(ConstantTest, CreateConstantSuccessful) {
// Get test parameters
auto& test_params = GetParam();
DataType dtype = std::get<0>(test_params);
TensorShape shape(std::get<1>(test_params));
bool tensorproto_use_tensor_content = std::get<2>(test_params);
// Construct a Tensor with the given dtype + shape
Tensor expected(dtype, shape);
testing::FillNumericTensorBuffer(expected.dtype(), expected.NumElements(),
expected.data(), 42);
// Serialize it to a Tensorproto
TensorProto proto;
if (tensorproto_use_tensor_content) {
expected.AsProtoTensorContent(&proto);
} else {
expected.AsProtoField(&proto);
}
// Revival should succeed w/o errors
std::unique_ptr<Constant> revived;
TF_EXPECT_OK(internal::TensorProtoToConstant(context(), proto, &revived));
// The revived tensorhandle should have the exact same dtype, shape, +
// approx equivalent data to the original.
ImmediateExecutionTensorHandle* handle = revived->handle();
Status status;
AbstractTensorPtr revived_tensor(handle->Resolve(&status));
TF_EXPECT_OK(status) << "Failed to convert tensorhandle to tensor";
EXPECT_EQ(revived_tensor->Type(), expected.dtype());
EXPECT_EQ(revived_tensor->NumElements(), expected.NumElements());
EXPECT_EQ(revived_tensor->NumDims(), expected.dims());
for (int i = 0; i < expected.dims(); ++i) {
EXPECT_EQ(revived_tensor->Dim(i), expected.dim_size(i));
}
testing::CheckBufferDataIsEqual(expected.dtype(), expected.NumElements(),
revived_tensor->Data(), expected.data());
}
// Test against combinations of tensors that are
// 1. Varying dtypes
// 2. Varying shapes
// 3. TensorProto serialized using tensor_content vs repeated type
INSTANTIATE_TEST_SUITE_P(
ConstantIntegerDtypesTest, ConstantTest,
::testing::Combine(
::testing::ValuesIn(testing::DataTypeSetToVector(kDataTypeIsInteger)),
::testing::ValuesIn(testing::InterestingShapes()),
::testing::Values(false, true)));
INSTANTIATE_TEST_SUITE_P(
ConstantFloatingDtypesTest, ConstantTest,
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
::testing::ValuesIn(testing::InterestingShapes()),
::testing::Values(false, true)));
} // namespace
} // namespace tensorflow

View File

@ -28,6 +28,27 @@ cc_library(
],
)
cc_library(
name = "variable",
srcs = [
"variable.cc",
],
hdrs = [
"variable.h",
],
deps = [
":tensorhandle_convertible",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/ops:variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "tensorhandle_convertible",
hdrs = [

View File

@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include <memory>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
Variable::Variable(ImmediateExecutionContext* ctx, DataType dtype,
TensorShape shape, absl::optional<std::string> name,
ImmediateTensorHandlePtr handle)
: TensorHandleConvertible(std::move(handle)),
name_(name.has_value() ? *name : "Variable"),
dtype_(dtype),
shape_(shape),
ctx_(ctx) {}
Variable::~Variable() {
// If the handle is null (perhaps because variable was std::moved from), then
// we don't have to do anything.
if (handle_ == nullptr) {
return;
}
Status status = internal::DestroyResource(ctx_, handle_.get());
if (!status.ok()) {
LOG(ERROR) << "Error destroying variable: " << name_
<< "due to: " << status;
}
}
DataType Variable::dtype() { return dtype_; }
TensorShape Variable::shape() { return shape_; }
Status Variable::Assign(ImmediateExecutionTensorHandle* handle) {
return internal::AssignVariable(ctx_, handle_.get(), dtype_, handle);
}
Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
}
Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
absl::optional<std::string> name,
std::unique_ptr<Variable>* output) {
ImmediateTensorHandlePtr handle;
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
ctx, dtype, shape, &handle));
output->reset(
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
return Status();
}
} // namespace tensorflow

View File

@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
#include <memory>
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
class Variable : public TensorHandleConvertible {
public:
// Creates an uninitialized resource variable. Note that a caller must
// call "assign" to associate a value with the variable.
static Status CreateUninitialized(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
absl::optional<std::string> name,
std::unique_ptr<Variable>* output);
// The dtype of the underlying variable.
DataType dtype();
// The shape of the underlying variable.
TensorShape shape();
// Updates the variable's contents with `handle`.
Status Assign(ImmediateExecutionTensorHandle* handle);
// Reads the value of the variable, and stores it in `out`
Status ReadValue(ImmediateTensorHandlePtr* out);
// Variable is movable, but not copyable.
Variable(Variable&& other) = default;
Variable& operator=(Variable&& other) = default;
~Variable() override;
private:
Variable(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
absl::optional<std::string> name, ImmediateTensorHandlePtr handle);
Variable(const Variable& variable) = delete;
Variable& operator=(const Variable&) = delete;
std::string name_;
DataType dtype_;
TensorShape shape_;
// ctx_ must outlive Variable.
ImmediateExecutionContext* ctx_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_

View File

@ -15,8 +15,13 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include <memory>
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
namespace internal {
@ -34,5 +39,20 @@ Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
return Constant::Create(ctx, &tensor_interface, output);
}
// This follows the python variable restoration logic:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407
Status LoadSavedVariable(ImmediateExecutionContext* ctx,
const SavedVariable& variable,
std::unique_ptr<Variable>* output) {
const std::string& name = variable.name();
tensorflow::TensorShape shape(variable.shape());
tensorflow::DataType dtype = variable.dtype();
TF_RETURN_IF_ERROR(
Variable::CreateUninitialized(ctx, dtype, shape, name, output));
return Status();
}
} // namespace internal
} // namespace tensorflow

View File

@ -21,7 +21,9 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
namespace internal {
@ -33,6 +35,14 @@ Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
const TensorProto& proto,
std::unique_ptr<Constant>* output);
// Creates a tensorflow::Variable from a SavedVariable. This is similar to the
// logic in:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407
// Note that the caller **must assign a value** to the loaded variable.
Status LoadSavedVariable(ImmediateExecutionContext* ctx,
const SavedVariable& variable,
std::unique_ptr<Variable>* output);
} // namespace internal
} // namespace tensorflow

View File

@ -1,199 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include <string.h>
#include <memory>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
// Converts a tensorflow::DatatypeSet to std::vector<DataType>.
// This is needed for GTest's ::testing::ValuesIn, since
// DataTypeSet doesn't fullfill all the constraints of an STL-like iterable.
std::vector<DataType> DataTypeSetToVector(DataTypeSet set) {
std::vector<DataType> result;
result.reserve(set.size());
for (DataType dt : set) {
result.push_back(dt);
}
return result;
}
// Returns a vector of shapes intended to be "interesting" test cases.
std::vector<std::vector<int64>> InterestingShapes() {
std::vector<std::vector<int64>> interesting_shapes;
interesting_shapes.push_back({}); // Scalar
interesting_shapes.push_back({10}); // 1D Vector
interesting_shapes.push_back({3, 3}); // 2D Matrix
interesting_shapes.push_back({1, 4, 6, 10}); // Higher Dimension Tensor
return interesting_shapes;
}
// Fills a numeric tensor with `value`.
void FillNumericTensor(Tensor* tensor, int8 value) {
switch (tensor->dtype()) {
#define CASE(type) \
case DataTypeToEnum<type>::value: { \
const auto& flattened = tensor->flat<type>(); \
for (int i = 0; i < tensor->NumElements(); ++i) { \
flattened(i) = value; \
} \
break; \
}
TF_CALL_INTEGRAL_TYPES(CASE);
TF_CALL_double(CASE);
TF_CALL_float(CASE);
#undef CASE
default:
CHECK(false) << "Unsupported data type: "
<< DataTypeString(tensor->dtype());
break;
}
}
// Checks the underlying data is equal for the buffers for two numeric tensors.
// Note: The caller must ensure to check that the dtypes and sizes of the
// underlying buffers are the same before calling this.
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
void* b) {
switch (dtype) {
#define CASE(type) \
case DataTypeToEnum<type>::value: { \
type* typed_a = static_cast<type*>(a); \
type* typed_b = static_cast<type*>(b); \
for (int64 i = 0; i < num_elements; ++i) { \
if (DataTypeIsFloating(dtype)) { \
EXPECT_FLOAT_EQ(typed_a[i], typed_b[i]); \
} else { \
EXPECT_EQ(typed_a[i], typed_b[i]); \
} \
} \
break; \
}
TF_CALL_INTEGRAL_TYPES(CASE);
TF_CALL_double(CASE);
TF_CALL_float(CASE);
#undef CASE
default:
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
}
}
class ConstantTest : public ::testing::TestWithParam<
std::tuple<DataType, std::vector<int64>, bool>> {
public:
ConstantTest()
: device_mgr_(std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
"CPU", {}, "/job:localhost/replica:0/task:0"))),
ctx_(new EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
/* async= */ false,
/* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(),
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
/* custom_kernel_creator= */ nullptr,
/* cluster_flr= */ nullptr)) {}
EagerContext* context() { return ctx_.get(); }
private:
std::unique_ptr<StaticDeviceMgr> device_mgr_;
EagerContextPtr ctx_;
};
// Basic sanity check that roundtripping a Tensor->Tensorproto->Constant
// preserves values.
TEST_P(ConstantTest, CreateConstantSuccessful) {
// Get test parameters
auto& test_params = GetParam();
DataType dtype = std::get<0>(test_params);
TensorShape shape(std::get<1>(test_params));
bool tensorproto_use_tensor_content = std::get<2>(test_params);
// Construct a Tensor with the given dtype + shape
Tensor expected(dtype, shape);
FillNumericTensor(&expected, 42);
// Serialize it to a Tensorproto
TensorProto proto;
if (tensorproto_use_tensor_content) {
expected.AsProtoTensorContent(&proto);
} else {
expected.AsProtoField(&proto);
}
// Revival should succeed w/o errors
std::unique_ptr<Constant> revived;
TF_EXPECT_OK(internal::TensorProtoToConstant(context(), proto, &revived));
// The revived tensorhandle should have the exact same dtype, shape, +
// approx equivalent data to the original.
ImmediateExecutionTensorHandle* handle = revived->handle();
Status status;
AbstractTensorPtr revived_tensor(handle->Resolve(&status));
TF_EXPECT_OK(status) << "Failed to convert tensorhandle to tensor";
EXPECT_EQ(revived_tensor->Type(), expected.dtype());
EXPECT_EQ(revived_tensor->NumElements(), expected.NumElements());
EXPECT_EQ(revived_tensor->NumDims(), expected.dims());
for (int i = 0; i < expected.dims(); ++i) {
EXPECT_EQ(revived_tensor->Dim(i), expected.dim_size(i));
}
CheckBufferDataIsEqual(expected.dtype(), expected.NumElements(),
revived_tensor->Data(), expected.data());
}
// Test against combinations of tensors that are
// 1. Varying dtypes
// 2. Varying shapes
// 3. TensorProto serialized using tensor_content vs repeated type
INSTANTIATE_TEST_SUITE_P(
ConstantIntegerDtypesTest, ConstantTest,
::testing::Combine(
::testing::ValuesIn(DataTypeSetToVector(kDataTypeIsInteger)),
::testing::ValuesIn(InterestingShapes()),
::testing::Values(false, true)));
INSTANTIATE_TEST_SUITE_P(
ConstantFloatingDtypesTest, ConstantTest,
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
::testing::ValuesIn(InterestingShapes()),
::testing::Values(false, true)));
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,122 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
namespace {
class SavedVariableLoadingTest : public ::testing::TestWithParam<
std::tuple<DataType, std::vector<int64>>> {
public:
SavedVariableLoadingTest()
: device_mgr_(testing::CreateTestingDeviceMgr()),
ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
EagerContext* context() { return ctx_.get(); }
private:
std::unique_ptr<StaticDeviceMgr> device_mgr_;
EagerContextPtr ctx_;
};
// Sanity check that constructing a tensorflow::Variable from a SavedVariable
// 1. does not cause an error
// 2. preserves dtype and shape.
TEST_P(SavedVariableLoadingTest, LoadSavedVariableSuccessful) {
auto& test_params = GetParam();
DataType dtype = std::get<0>(test_params);
TensorShape shape(std::get<1>(test_params));
SavedVariable saved_variable;
saved_variable.set_dtype(dtype);
shape.AsProto(saved_variable.mutable_shape());
std::unique_ptr<Variable> var;
TF_EXPECT_OK(internal::LoadSavedVariable(context(), saved_variable, &var));
EXPECT_EQ(var->dtype(), dtype);
EXPECT_EQ(var->shape(), shape);
}
// Assigning and reading values should yield
// consistent results.
TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
auto& test_params = GetParam();
DataType dtype = std::get<0>(test_params);
std::vector<int64> shape_vector = std::get<1>(test_params);
TensorShape shape(shape_vector);
// Create the variable.
Status status;
std::unique_ptr<Variable> var;
TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
absl::nullopt, &var));
// Create a TensorHandle
ImmediateTensorHandlePtr expected_handle =
testing::CreateTensorHandle(context(), dtype, shape_vector, 42);
AbstractTensorPtr expected_tensor(expected_handle->Resolve(&status));
TF_EXPECT_OK(status) << status.error_message();
// Assign the tensorhandle to the variable.
TF_EXPECT_OK(var->Assign(expected_handle.get()));
// Read back the value from the variable
ImmediateTensorHandlePtr output_handle;
TF_EXPECT_OK(var->ReadValue(&output_handle));
AbstractTensorPtr output_tensor(output_handle->Resolve(&status));
TF_EXPECT_OK(status) << status.error_message();
// Check that output_tensor == expected_tensor
EXPECT_EQ(output_tensor->Type(), expected_tensor->Type());
EXPECT_EQ(output_tensor->NumElements(), expected_tensor->NumElements());
testing::CheckBufferDataIsEqual(
output_tensor->Type(), output_tensor->NumElements(),
output_tensor->Data(), expected_tensor->Data());
}
// Test against combinations of SavedVariables of
// 1. Varying dtypes
// 2. Varying shapes
INSTANTIATE_TEST_SUITE_P(
SavedVariableIntegerDtypesTest, SavedVariableLoadingTest,
::testing::Combine(
::testing::ValuesIn(testing::DataTypeSetToVector(kDataTypeIsInteger)),
::testing::ValuesIn(testing::InterestingShapes())));
INSTANTIATE_TEST_SUITE_P(
SavedVariableFloatingDtypesTest, SavedVariableLoadingTest,
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
::testing::ValuesIn(testing::InterestingShapes())));
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,143 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include <memory>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace testing {
std::unique_ptr<StaticDeviceMgr> CreateTestingDeviceMgr() {
return std::make_unique<StaticDeviceMgr>(
DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
}
EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr) {
return EagerContextPtr(new EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
/* async= */ false,
/* lazy_copy_function_remote_inputs= */ false, device_mgr,
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
/* custom_kernel_creator= */ nullptr,
/* cluster_flr= */ nullptr));
}
std::vector<DataType> DataTypeSetToVector(DataTypeSet set) {
std::vector<DataType> result;
result.reserve(set.size());
for (DataType dt : set) {
result.push_back(dt);
}
return result;
}
std::vector<std::vector<int64>> InterestingShapes() {
std::vector<std::vector<int64>> interesting_shapes;
interesting_shapes.push_back({}); // Scalar
interesting_shapes.push_back({10}); // 1D Vector
interesting_shapes.push_back({3, 3}); // 2D Matrix
interesting_shapes.push_back({1, 4, 6, 10}); // Higher Dimension Tensor
return interesting_shapes;
}
ImmediateTensorHandlePtr CreateTensorHandle(ImmediateExecutionContext* ctx,
DataType dtype,
absl::Span<const int64> shape,
int8 value) {
AbstractTensorPtr tensor(ctx->CreateTensor(dtype, shape));
CHECK_NE(tensor.get(), nullptr)
<< "Tensor creation failed for tensor of dtype: "
<< DataTypeString(dtype);
CHECK_EQ(tensor->Type(), dtype);
for (int i = 0; i < shape.size(); ++i) {
CHECK_EQ(tensor->Dim(i), shape[i]);
}
FillNumericTensorBuffer(tensor->Type(), tensor->NumElements(), tensor->Data(),
value);
ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get()));
CHECK_NE(handle.get(), nullptr);
return handle;
}
void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
int8 value) {
switch (dtype) {
#define CASE(type) \
case DataTypeToEnum<type>::value: { \
type* typed_buffer = static_cast<type*>(buffer); \
for (size_t i = 0; i < num_elements; ++i) { \
typed_buffer[i] = value; \
} \
break; \
}
TF_CALL_INTEGRAL_TYPES(CASE);
TF_CALL_double(CASE);
TF_CALL_float(CASE);
#undef CASE
default:
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
break;
}
}
// Checks the underlying data is equal for the buffers for two numeric tensors.
// Note: The caller must ensure to check that the dtypes and sizes of the
// underlying buffers are the same before calling this.
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
void* b) {
switch (dtype) {
#define CASE(type) \
case DataTypeToEnum<type>::value: { \
type* typed_a = static_cast<type*>(a); \
type* typed_b = static_cast<type*>(b); \
for (int64 i = 0; i < num_elements; ++i) { \
if (DataTypeIsFloating(dtype)) { \
EXPECT_FLOAT_EQ(typed_a[i], typed_b[i]); \
} else { \
EXPECT_EQ(typed_a[i], typed_b[i]); \
} \
} \
break; \
}
TF_CALL_INTEGRAL_TYPES(CASE);
TF_CALL_double(CASE);
TF_CALL_float(CASE);
#undef CASE
default:
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
}
}
} // namespace testing
} // namespace tensorflow

View File

@ -0,0 +1,75 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
#include <memory>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace testing {
// Creates a DeviceMgr suitable for local tests.
std::unique_ptr<StaticDeviceMgr> CreateTestingDeviceMgr();
// Creates an EagerContext suitable for local tests. Does not take ownership
// of `device_mgr`.
EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr);
// Converts a tensorflow::DatatypeSet to std::vector<DataType>.
// This is useful for tests using GTest's ::testing::ValuesIn, since
// DataTypeSet doesn't fullfill all the constraints of an STL-like iterable.
std::vector<DataType> DataTypeSetToVector(DataTypeSet set);
// Returns a vector of shapes intended to be "interesting" test cases.
// Currently, this returns scalar, 1D vector, 2D matrix, and a 4D tensor shapes
std::vector<std::vector<int64>> InterestingShapes();
// Returns a TensorHandle of `dtype` and `shape`, filled with `value`.
// `dtype` must be an integer dtype, float, or double.
// If a TensorHandle cannot be created successfully, this function will
// CHECK fail. This should only be used for testing purposes.
ImmediateTensorHandlePtr CreateTensorHandle(ImmediateExecutionContext* ctx,
DataType dtype,
absl::Span<const int64> shape,
int8 value);
// Fills a numeric tensor's buffer with `value`.
// dtype must be any integer dtype, float or double.
void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
int8 value);
// Checks the underlying data is equal for the buffers for two numeric tensors.
// Note: The caller must ensure to check that the dtypes and sizes of the
// underlying buffers are the same before calling this.
// dtype must be any integer dtype, float, or double.
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
void* b);
} // namespace testing
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_