Adding Variable class for Variable Reloading in the SavedModel C API.
PiperOrigin-RevId: 318199820 Change-Id: I901124780f8687d0f572cff4546f8792f1120e47
This commit is contained in:
parent
772b836fdb
commit
86fc04ef9b
tensorflow/c/experimental/saved_model/core
@ -63,10 +63,34 @@ cc_library(
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "test_utils",
|
||||
testonly = True,
|
||||
srcs = [
|
||||
"test_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"test_utils.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core/common_runtime:core_cpu_lib",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_saved_model_impl",
|
||||
srcs = [
|
||||
@ -106,15 +130,35 @@ filegroup(
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "saved_model_utils_test",
|
||||
name = "constant_loading_test",
|
||||
srcs = [
|
||||
"saved_model_utils_test.cc",
|
||||
"constant_loading_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":saved_model_utils",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
":test_utils",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime:core_cpu_lib",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:core",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "saved_variable_loading_test",
|
||||
srcs = [
|
||||
"saved_variable_loading_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":saved_model_utils",
|
||||
":test_utils",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -0,0 +1,111 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class ConstantTest : public ::testing::TestWithParam<
|
||||
std::tuple<DataType, std::vector<int64>, bool>> {
|
||||
public:
|
||||
ConstantTest()
|
||||
: device_mgr_(testing::CreateTestingDeviceMgr()),
|
||||
ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
|
||||
|
||||
EagerContext* context() { return ctx_.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<StaticDeviceMgr> device_mgr_;
|
||||
EagerContextPtr ctx_;
|
||||
};
|
||||
|
||||
// Basic sanity check that roundtripping a Tensor->Tensorproto->Constant
|
||||
// preserves values.
|
||||
TEST_P(ConstantTest, CreateConstantSuccessful) {
|
||||
// Get test parameters
|
||||
auto& test_params = GetParam();
|
||||
DataType dtype = std::get<0>(test_params);
|
||||
TensorShape shape(std::get<1>(test_params));
|
||||
bool tensorproto_use_tensor_content = std::get<2>(test_params);
|
||||
|
||||
// Construct a Tensor with the given dtype + shape
|
||||
Tensor expected(dtype, shape);
|
||||
testing::FillNumericTensorBuffer(expected.dtype(), expected.NumElements(),
|
||||
expected.data(), 42);
|
||||
|
||||
// Serialize it to a Tensorproto
|
||||
TensorProto proto;
|
||||
if (tensorproto_use_tensor_content) {
|
||||
expected.AsProtoTensorContent(&proto);
|
||||
} else {
|
||||
expected.AsProtoField(&proto);
|
||||
}
|
||||
|
||||
// Revival should succeed w/o errors
|
||||
std::unique_ptr<Constant> revived;
|
||||
TF_EXPECT_OK(internal::TensorProtoToConstant(context(), proto, &revived));
|
||||
|
||||
// The revived tensorhandle should have the exact same dtype, shape, +
|
||||
// approx equivalent data to the original.
|
||||
ImmediateExecutionTensorHandle* handle = revived->handle();
|
||||
Status status;
|
||||
AbstractTensorPtr revived_tensor(handle->Resolve(&status));
|
||||
TF_EXPECT_OK(status) << "Failed to convert tensorhandle to tensor";
|
||||
EXPECT_EQ(revived_tensor->Type(), expected.dtype());
|
||||
EXPECT_EQ(revived_tensor->NumElements(), expected.NumElements());
|
||||
EXPECT_EQ(revived_tensor->NumDims(), expected.dims());
|
||||
for (int i = 0; i < expected.dims(); ++i) {
|
||||
EXPECT_EQ(revived_tensor->Dim(i), expected.dim_size(i));
|
||||
}
|
||||
|
||||
testing::CheckBufferDataIsEqual(expected.dtype(), expected.NumElements(),
|
||||
revived_tensor->Data(), expected.data());
|
||||
}
|
||||
|
||||
// Test against combinations of tensors that are
|
||||
// 1. Varying dtypes
|
||||
// 2. Varying shapes
|
||||
// 3. TensorProto serialized using tensor_content vs repeated type
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ConstantIntegerDtypesTest, ConstantTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(testing::DataTypeSetToVector(kDataTypeIsInteger)),
|
||||
::testing::ValuesIn(testing::InterestingShapes()),
|
||||
::testing::Values(false, true)));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ConstantFloatingDtypesTest, ConstantTest,
|
||||
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
|
||||
::testing::ValuesIn(testing::InterestingShapes()),
|
||||
::testing::Values(false, true)));
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -28,6 +28,27 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "variable",
|
||||
srcs = [
|
||||
"variable.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"variable.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorhandle_convertible",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core/ops:variable_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle_convertible",
|
||||
hdrs = [
|
||||
|
@ -0,0 +1,78 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Variable::Variable(ImmediateExecutionContext* ctx, DataType dtype,
|
||||
TensorShape shape, absl::optional<std::string> name,
|
||||
ImmediateTensorHandlePtr handle)
|
||||
: TensorHandleConvertible(std::move(handle)),
|
||||
name_(name.has_value() ? *name : "Variable"),
|
||||
dtype_(dtype),
|
||||
shape_(shape),
|
||||
ctx_(ctx) {}
|
||||
|
||||
Variable::~Variable() {
|
||||
// If the handle is null (perhaps because variable was std::moved from), then
|
||||
// we don't have to do anything.
|
||||
if (handle_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
Status status = internal::DestroyResource(ctx_, handle_.get());
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Error destroying variable: " << name_
|
||||
<< "due to: " << status;
|
||||
}
|
||||
}
|
||||
|
||||
DataType Variable::dtype() { return dtype_; }
|
||||
|
||||
TensorShape Variable::shape() { return shape_; }
|
||||
|
||||
Status Variable::Assign(ImmediateExecutionTensorHandle* handle) {
|
||||
return internal::AssignVariable(ctx_, handle_.get(), dtype_, handle);
|
||||
}
|
||||
|
||||
Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
|
||||
return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
|
||||
}
|
||||
|
||||
Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
absl::optional<std::string> name,
|
||||
std::unique_ptr<Variable>* output) {
|
||||
ImmediateTensorHandlePtr handle;
|
||||
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
|
||||
ctx, dtype, shape, &handle));
|
||||
|
||||
output->reset(
|
||||
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,76 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Variable : public TensorHandleConvertible {
|
||||
public:
|
||||
// Creates an uninitialized resource variable. Note that a caller must
|
||||
// call "assign" to associate a value with the variable.
|
||||
static Status CreateUninitialized(ImmediateExecutionContext* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
absl::optional<std::string> name,
|
||||
std::unique_ptr<Variable>* output);
|
||||
|
||||
// The dtype of the underlying variable.
|
||||
DataType dtype();
|
||||
|
||||
// The shape of the underlying variable.
|
||||
TensorShape shape();
|
||||
|
||||
// Updates the variable's contents with `handle`.
|
||||
Status Assign(ImmediateExecutionTensorHandle* handle);
|
||||
|
||||
// Reads the value of the variable, and stores it in `out`
|
||||
Status ReadValue(ImmediateTensorHandlePtr* out);
|
||||
|
||||
// Variable is movable, but not copyable.
|
||||
Variable(Variable&& other) = default;
|
||||
Variable& operator=(Variable&& other) = default;
|
||||
|
||||
~Variable() override;
|
||||
|
||||
private:
|
||||
Variable(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
|
||||
absl::optional<std::string> name, ImmediateTensorHandlePtr handle);
|
||||
Variable(const Variable& variable) = delete;
|
||||
Variable& operator=(const Variable&) = delete;
|
||||
|
||||
std::string name_;
|
||||
DataType dtype_;
|
||||
TensorShape shape_;
|
||||
|
||||
// ctx_ must outlive Variable.
|
||||
ImmediateExecutionContext* ctx_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
|
@ -15,8 +15,13 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
@ -34,5 +39,20 @@ Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
|
||||
return Constant::Create(ctx, &tensor_interface, output);
|
||||
}
|
||||
|
||||
// This follows the python variable restoration logic:
|
||||
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407
|
||||
Status LoadSavedVariable(ImmediateExecutionContext* ctx,
|
||||
const SavedVariable& variable,
|
||||
std::unique_ptr<Variable>* output) {
|
||||
const std::string& name = variable.name();
|
||||
tensorflow::TensorShape shape(variable.shape());
|
||||
tensorflow::DataType dtype = variable.dtype();
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
Variable::CreateUninitialized(ctx, dtype, shape, name, output));
|
||||
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
@ -21,7 +21,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
@ -33,6 +35,14 @@ Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
|
||||
const TensorProto& proto,
|
||||
std::unique_ptr<Constant>* output);
|
||||
|
||||
// Creates a tensorflow::Variable from a SavedVariable. This is similar to the
|
||||
// logic in:
|
||||
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407
|
||||
// Note that the caller **must assign a value** to the loaded variable.
|
||||
Status LoadSavedVariable(ImmediateExecutionContext* ctx,
|
||||
const SavedVariable& variable,
|
||||
std::unique_ptr<Variable>* output);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -1,199 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Converts a tensorflow::DatatypeSet to std::vector<DataType>.
|
||||
// This is needed for GTest's ::testing::ValuesIn, since
|
||||
// DataTypeSet doesn't fullfill all the constraints of an STL-like iterable.
|
||||
std::vector<DataType> DataTypeSetToVector(DataTypeSet set) {
|
||||
std::vector<DataType> result;
|
||||
result.reserve(set.size());
|
||||
for (DataType dt : set) {
|
||||
result.push_back(dt);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns a vector of shapes intended to be "interesting" test cases.
|
||||
std::vector<std::vector<int64>> InterestingShapes() {
|
||||
std::vector<std::vector<int64>> interesting_shapes;
|
||||
interesting_shapes.push_back({}); // Scalar
|
||||
interesting_shapes.push_back({10}); // 1D Vector
|
||||
interesting_shapes.push_back({3, 3}); // 2D Matrix
|
||||
interesting_shapes.push_back({1, 4, 6, 10}); // Higher Dimension Tensor
|
||||
return interesting_shapes;
|
||||
}
|
||||
|
||||
// Fills a numeric tensor with `value`.
|
||||
void FillNumericTensor(Tensor* tensor, int8 value) {
|
||||
switch (tensor->dtype()) {
|
||||
#define CASE(type) \
|
||||
case DataTypeToEnum<type>::value: { \
|
||||
const auto& flattened = tensor->flat<type>(); \
|
||||
for (int i = 0; i < tensor->NumElements(); ++i) { \
|
||||
flattened(i) = value; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
TF_CALL_INTEGRAL_TYPES(CASE);
|
||||
TF_CALL_double(CASE);
|
||||
TF_CALL_float(CASE);
|
||||
#undef CASE
|
||||
default:
|
||||
CHECK(false) << "Unsupported data type: "
|
||||
<< DataTypeString(tensor->dtype());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Checks the underlying data is equal for the buffers for two numeric tensors.
|
||||
// Note: The caller must ensure to check that the dtypes and sizes of the
|
||||
// underlying buffers are the same before calling this.
|
||||
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
|
||||
void* b) {
|
||||
switch (dtype) {
|
||||
#define CASE(type) \
|
||||
case DataTypeToEnum<type>::value: { \
|
||||
type* typed_a = static_cast<type*>(a); \
|
||||
type* typed_b = static_cast<type*>(b); \
|
||||
for (int64 i = 0; i < num_elements; ++i) { \
|
||||
if (DataTypeIsFloating(dtype)) { \
|
||||
EXPECT_FLOAT_EQ(typed_a[i], typed_b[i]); \
|
||||
} else { \
|
||||
EXPECT_EQ(typed_a[i], typed_b[i]); \
|
||||
} \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
TF_CALL_INTEGRAL_TYPES(CASE);
|
||||
TF_CALL_double(CASE);
|
||||
TF_CALL_float(CASE);
|
||||
#undef CASE
|
||||
default:
|
||||
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
|
||||
}
|
||||
}
|
||||
|
||||
class ConstantTest : public ::testing::TestWithParam<
|
||||
std::tuple<DataType, std::vector<int64>, bool>> {
|
||||
public:
|
||||
ConstantTest()
|
||||
: device_mgr_(std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
|
||||
"CPU", {}, "/job:localhost/replica:0/task:0"))),
|
||||
ctx_(new EagerContext(
|
||||
SessionOptions(),
|
||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
|
||||
/* async= */ false,
|
||||
/* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(),
|
||||
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
|
||||
/* custom_kernel_creator= */ nullptr,
|
||||
/* cluster_flr= */ nullptr)) {}
|
||||
|
||||
EagerContext* context() { return ctx_.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<StaticDeviceMgr> device_mgr_;
|
||||
EagerContextPtr ctx_;
|
||||
};
|
||||
|
||||
// Basic sanity check that roundtripping a Tensor->Tensorproto->Constant
|
||||
// preserves values.
|
||||
TEST_P(ConstantTest, CreateConstantSuccessful) {
|
||||
// Get test parameters
|
||||
auto& test_params = GetParam();
|
||||
DataType dtype = std::get<0>(test_params);
|
||||
TensorShape shape(std::get<1>(test_params));
|
||||
bool tensorproto_use_tensor_content = std::get<2>(test_params);
|
||||
|
||||
// Construct a Tensor with the given dtype + shape
|
||||
Tensor expected(dtype, shape);
|
||||
FillNumericTensor(&expected, 42);
|
||||
|
||||
// Serialize it to a Tensorproto
|
||||
TensorProto proto;
|
||||
if (tensorproto_use_tensor_content) {
|
||||
expected.AsProtoTensorContent(&proto);
|
||||
} else {
|
||||
expected.AsProtoField(&proto);
|
||||
}
|
||||
|
||||
// Revival should succeed w/o errors
|
||||
std::unique_ptr<Constant> revived;
|
||||
TF_EXPECT_OK(internal::TensorProtoToConstant(context(), proto, &revived));
|
||||
|
||||
// The revived tensorhandle should have the exact same dtype, shape, +
|
||||
// approx equivalent data to the original.
|
||||
ImmediateExecutionTensorHandle* handle = revived->handle();
|
||||
Status status;
|
||||
AbstractTensorPtr revived_tensor(handle->Resolve(&status));
|
||||
TF_EXPECT_OK(status) << "Failed to convert tensorhandle to tensor";
|
||||
EXPECT_EQ(revived_tensor->Type(), expected.dtype());
|
||||
EXPECT_EQ(revived_tensor->NumElements(), expected.NumElements());
|
||||
EXPECT_EQ(revived_tensor->NumDims(), expected.dims());
|
||||
for (int i = 0; i < expected.dims(); ++i) {
|
||||
EXPECT_EQ(revived_tensor->Dim(i), expected.dim_size(i));
|
||||
}
|
||||
|
||||
CheckBufferDataIsEqual(expected.dtype(), expected.NumElements(),
|
||||
revived_tensor->Data(), expected.data());
|
||||
}
|
||||
|
||||
// Test against combinations of tensors that are
|
||||
// 1. Varying dtypes
|
||||
// 2. Varying shapes
|
||||
// 3. TensorProto serialized using tensor_content vs repeated type
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ConstantIntegerDtypesTest, ConstantTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(DataTypeSetToVector(kDataTypeIsInteger)),
|
||||
::testing::ValuesIn(InterestingShapes()),
|
||||
::testing::Values(false, true)));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ConstantFloatingDtypesTest, ConstantTest,
|
||||
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
|
||||
::testing::ValuesIn(InterestingShapes()),
|
||||
::testing::Values(false, true)));
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -0,0 +1,122 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class SavedVariableLoadingTest : public ::testing::TestWithParam<
|
||||
std::tuple<DataType, std::vector<int64>>> {
|
||||
public:
|
||||
SavedVariableLoadingTest()
|
||||
: device_mgr_(testing::CreateTestingDeviceMgr()),
|
||||
ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
|
||||
|
||||
EagerContext* context() { return ctx_.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<StaticDeviceMgr> device_mgr_;
|
||||
EagerContextPtr ctx_;
|
||||
};
|
||||
|
||||
// Sanity check that constructing a tensorflow::Variable from a SavedVariable
|
||||
// 1. does not cause an error
|
||||
// 2. preserves dtype and shape.
|
||||
TEST_P(SavedVariableLoadingTest, LoadSavedVariableSuccessful) {
|
||||
auto& test_params = GetParam();
|
||||
DataType dtype = std::get<0>(test_params);
|
||||
TensorShape shape(std::get<1>(test_params));
|
||||
|
||||
SavedVariable saved_variable;
|
||||
saved_variable.set_dtype(dtype);
|
||||
shape.AsProto(saved_variable.mutable_shape());
|
||||
|
||||
std::unique_ptr<Variable> var;
|
||||
TF_EXPECT_OK(internal::LoadSavedVariable(context(), saved_variable, &var));
|
||||
EXPECT_EQ(var->dtype(), dtype);
|
||||
EXPECT_EQ(var->shape(), shape);
|
||||
}
|
||||
|
||||
// Assigning and reading values should yield
|
||||
// consistent results.
|
||||
TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
|
||||
auto& test_params = GetParam();
|
||||
DataType dtype = std::get<0>(test_params);
|
||||
std::vector<int64> shape_vector = std::get<1>(test_params);
|
||||
TensorShape shape(shape_vector);
|
||||
|
||||
// Create the variable.
|
||||
Status status;
|
||||
std::unique_ptr<Variable> var;
|
||||
TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
|
||||
absl::nullopt, &var));
|
||||
|
||||
// Create a TensorHandle
|
||||
ImmediateTensorHandlePtr expected_handle =
|
||||
testing::CreateTensorHandle(context(), dtype, shape_vector, 42);
|
||||
AbstractTensorPtr expected_tensor(expected_handle->Resolve(&status));
|
||||
TF_EXPECT_OK(status) << status.error_message();
|
||||
|
||||
// Assign the tensorhandle to the variable.
|
||||
TF_EXPECT_OK(var->Assign(expected_handle.get()));
|
||||
|
||||
// Read back the value from the variable
|
||||
ImmediateTensorHandlePtr output_handle;
|
||||
TF_EXPECT_OK(var->ReadValue(&output_handle));
|
||||
AbstractTensorPtr output_tensor(output_handle->Resolve(&status));
|
||||
TF_EXPECT_OK(status) << status.error_message();
|
||||
|
||||
// Check that output_tensor == expected_tensor
|
||||
EXPECT_EQ(output_tensor->Type(), expected_tensor->Type());
|
||||
EXPECT_EQ(output_tensor->NumElements(), expected_tensor->NumElements());
|
||||
testing::CheckBufferDataIsEqual(
|
||||
output_tensor->Type(), output_tensor->NumElements(),
|
||||
output_tensor->Data(), expected_tensor->Data());
|
||||
}
|
||||
|
||||
// Test against combinations of SavedVariables of
|
||||
// 1. Varying dtypes
|
||||
// 2. Varying shapes
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
SavedVariableIntegerDtypesTest, SavedVariableLoadingTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(testing::DataTypeSetToVector(kDataTypeIsInteger)),
|
||||
::testing::ValuesIn(testing::InterestingShapes())));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
SavedVariableFloatingDtypesTest, SavedVariableLoadingTest,
|
||||
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
|
||||
::testing::ValuesIn(testing::InterestingShapes())));
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
143
tensorflow/c/experimental/saved_model/core/test_utils.cc
Normal file
143
tensorflow/c/experimental/saved_model/core/test_utils.cc
Normal file
@ -0,0 +1,143 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace testing {
|
||||
|
||||
std::unique_ptr<StaticDeviceMgr> CreateTestingDeviceMgr() {
|
||||
return std::make_unique<StaticDeviceMgr>(
|
||||
DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
|
||||
}
|
||||
|
||||
EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr) {
|
||||
return EagerContextPtr(new EagerContext(
|
||||
SessionOptions(),
|
||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
|
||||
/* async= */ false,
|
||||
/* lazy_copy_function_remote_inputs= */ false, device_mgr,
|
||||
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
|
||||
/* custom_kernel_creator= */ nullptr,
|
||||
/* cluster_flr= */ nullptr));
|
||||
}
|
||||
|
||||
std::vector<DataType> DataTypeSetToVector(DataTypeSet set) {
|
||||
std::vector<DataType> result;
|
||||
result.reserve(set.size());
|
||||
for (DataType dt : set) {
|
||||
result.push_back(dt);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64>> InterestingShapes() {
|
||||
std::vector<std::vector<int64>> interesting_shapes;
|
||||
interesting_shapes.push_back({}); // Scalar
|
||||
interesting_shapes.push_back({10}); // 1D Vector
|
||||
interesting_shapes.push_back({3, 3}); // 2D Matrix
|
||||
interesting_shapes.push_back({1, 4, 6, 10}); // Higher Dimension Tensor
|
||||
return interesting_shapes;
|
||||
}
|
||||
|
||||
ImmediateTensorHandlePtr CreateTensorHandle(ImmediateExecutionContext* ctx,
|
||||
DataType dtype,
|
||||
absl::Span<const int64> shape,
|
||||
int8 value) {
|
||||
AbstractTensorPtr tensor(ctx->CreateTensor(dtype, shape));
|
||||
CHECK_NE(tensor.get(), nullptr)
|
||||
<< "Tensor creation failed for tensor of dtype: "
|
||||
<< DataTypeString(dtype);
|
||||
CHECK_EQ(tensor->Type(), dtype);
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
CHECK_EQ(tensor->Dim(i), shape[i]);
|
||||
}
|
||||
FillNumericTensorBuffer(tensor->Type(), tensor->NumElements(), tensor->Data(),
|
||||
value);
|
||||
ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get()));
|
||||
CHECK_NE(handle.get(), nullptr);
|
||||
return handle;
|
||||
}
|
||||
|
||||
void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
|
||||
int8 value) {
|
||||
switch (dtype) {
|
||||
#define CASE(type) \
|
||||
case DataTypeToEnum<type>::value: { \
|
||||
type* typed_buffer = static_cast<type*>(buffer); \
|
||||
for (size_t i = 0; i < num_elements; ++i) { \
|
||||
typed_buffer[i] = value; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
TF_CALL_INTEGRAL_TYPES(CASE);
|
||||
TF_CALL_double(CASE);
|
||||
TF_CALL_float(CASE);
|
||||
#undef CASE
|
||||
default:
|
||||
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Checks the underlying data is equal for the buffers for two numeric tensors.
|
||||
// Note: The caller must ensure to check that the dtypes and sizes of the
|
||||
// underlying buffers are the same before calling this.
|
||||
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
|
||||
void* b) {
|
||||
switch (dtype) {
|
||||
#define CASE(type) \
|
||||
case DataTypeToEnum<type>::value: { \
|
||||
type* typed_a = static_cast<type*>(a); \
|
||||
type* typed_b = static_cast<type*>(b); \
|
||||
for (int64 i = 0; i < num_elements; ++i) { \
|
||||
if (DataTypeIsFloating(dtype)) { \
|
||||
EXPECT_FLOAT_EQ(typed_a[i], typed_b[i]); \
|
||||
} else { \
|
||||
EXPECT_EQ(typed_a[i], typed_b[i]); \
|
||||
} \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
TF_CALL_INTEGRAL_TYPES(CASE);
|
||||
TF_CALL_double(CASE);
|
||||
TF_CALL_float(CASE);
|
||||
#undef CASE
|
||||
default:
|
||||
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace testing
|
||||
} // namespace tensorflow
|
75
tensorflow/c/experimental/saved_model/core/test_utils.h
Normal file
75
tensorflow/c/experimental/saved_model/core/test_utils.h
Normal file
@ -0,0 +1,75 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace testing {
|
||||
|
||||
// Creates a DeviceMgr suitable for local tests.
|
||||
std::unique_ptr<StaticDeviceMgr> CreateTestingDeviceMgr();
|
||||
|
||||
// Creates an EagerContext suitable for local tests. Does not take ownership
|
||||
// of `device_mgr`.
|
||||
EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr);
|
||||
|
||||
// Converts a tensorflow::DatatypeSet to std::vector<DataType>.
|
||||
// This is useful for tests using GTest's ::testing::ValuesIn, since
|
||||
// DataTypeSet doesn't fullfill all the constraints of an STL-like iterable.
|
||||
std::vector<DataType> DataTypeSetToVector(DataTypeSet set);
|
||||
|
||||
// Returns a vector of shapes intended to be "interesting" test cases.
|
||||
// Currently, this returns scalar, 1D vector, 2D matrix, and a 4D tensor shapes
|
||||
std::vector<std::vector<int64>> InterestingShapes();
|
||||
|
||||
// Returns a TensorHandle of `dtype` and `shape`, filled with `value`.
|
||||
// `dtype` must be an integer dtype, float, or double.
|
||||
// If a TensorHandle cannot be created successfully, this function will
|
||||
// CHECK fail. This should only be used for testing purposes.
|
||||
ImmediateTensorHandlePtr CreateTensorHandle(ImmediateExecutionContext* ctx,
|
||||
DataType dtype,
|
||||
absl::Span<const int64> shape,
|
||||
int8 value);
|
||||
|
||||
// Fills a numeric tensor's buffer with `value`.
|
||||
// dtype must be any integer dtype, float or double.
|
||||
void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
|
||||
int8 value);
|
||||
|
||||
// Checks the underlying data is equal for the buffers for two numeric tensors.
|
||||
// Note: The caller must ensure to check that the dtypes and sizes of the
|
||||
// underlying buffers are the same before calling this.
|
||||
// dtype must be any integer dtype, float, or double.
|
||||
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
|
||||
void* b);
|
||||
|
||||
} // namespace testing
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
|
Loading…
Reference in New Issue
Block a user