Adding Variable class for Variable Reloading in the SavedModel C API.
PiperOrigin-RevId: 318199820 Change-Id: I901124780f8687d0f572cff4546f8792f1120e47
This commit is contained in:
parent
772b836fdb
commit
86fc04ef9b
@ -63,10 +63,34 @@ cc_library(
|
|||||||
"//tensorflow/c:tf_tensor_internal",
|
"//tensorflow/c:tf_tensor_internal",
|
||||||
"//tensorflow/c/eager:immediate_execution_context",
|
"//tensorflow/c/eager:immediate_execution_context",
|
||||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "test_utils",
|
||||||
|
testonly = True,
|
||||||
|
srcs = [
|
||||||
|
"test_utils.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"test_utils.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/c:tensor_interface",
|
||||||
|
"//tensorflow/c/eager:immediate_execution_context",
|
||||||
|
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core/common_runtime:core_cpu_lib",
|
||||||
|
"//tensorflow/core/common_runtime/eager:context",
|
||||||
|
"@com_google_absl//absl/types:span",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tf_saved_model_impl",
|
name = "tf_saved_model_impl",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -106,15 +130,35 @@ filegroup(
|
|||||||
)
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "saved_model_utils_test",
|
name = "constant_loading_test",
|
||||||
srcs = [
|
srcs = [
|
||||||
"saved_model_utils_test.cc",
|
"constant_loading_test.cc",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":saved_model_utils",
|
":saved_model_utils",
|
||||||
"//tensorflow/c:tensor_interface",
|
":test_utils",
|
||||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||||
"//tensorflow/c/eager:immediate_execution_context",
|
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/common_runtime:core_cpu_lib",
|
||||||
|
"//tensorflow/core/common_runtime/eager:context",
|
||||||
|
"//tensorflow/core/common_runtime/eager:core",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "saved_variable_loading_test",
|
||||||
|
srcs = [
|
||||||
|
"saved_variable_loading_test.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":saved_model_utils",
|
||||||
|
":test_utils",
|
||||||
|
"//tensorflow/c:tensor_interface",
|
||||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
@ -0,0 +1,111 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class ConstantTest : public ::testing::TestWithParam<
|
||||||
|
std::tuple<DataType, std::vector<int64>, bool>> {
|
||||||
|
public:
|
||||||
|
ConstantTest()
|
||||||
|
: device_mgr_(testing::CreateTestingDeviceMgr()),
|
||||||
|
ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
|
||||||
|
|
||||||
|
EagerContext* context() { return ctx_.get(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<StaticDeviceMgr> device_mgr_;
|
||||||
|
EagerContextPtr ctx_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Basic sanity check that roundtripping a Tensor->Tensorproto->Constant
|
||||||
|
// preserves values.
|
||||||
|
TEST_P(ConstantTest, CreateConstantSuccessful) {
|
||||||
|
// Get test parameters
|
||||||
|
auto& test_params = GetParam();
|
||||||
|
DataType dtype = std::get<0>(test_params);
|
||||||
|
TensorShape shape(std::get<1>(test_params));
|
||||||
|
bool tensorproto_use_tensor_content = std::get<2>(test_params);
|
||||||
|
|
||||||
|
// Construct a Tensor with the given dtype + shape
|
||||||
|
Tensor expected(dtype, shape);
|
||||||
|
testing::FillNumericTensorBuffer(expected.dtype(), expected.NumElements(),
|
||||||
|
expected.data(), 42);
|
||||||
|
|
||||||
|
// Serialize it to a Tensorproto
|
||||||
|
TensorProto proto;
|
||||||
|
if (tensorproto_use_tensor_content) {
|
||||||
|
expected.AsProtoTensorContent(&proto);
|
||||||
|
} else {
|
||||||
|
expected.AsProtoField(&proto);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revival should succeed w/o errors
|
||||||
|
std::unique_ptr<Constant> revived;
|
||||||
|
TF_EXPECT_OK(internal::TensorProtoToConstant(context(), proto, &revived));
|
||||||
|
|
||||||
|
// The revived tensorhandle should have the exact same dtype, shape, +
|
||||||
|
// approx equivalent data to the original.
|
||||||
|
ImmediateExecutionTensorHandle* handle = revived->handle();
|
||||||
|
Status status;
|
||||||
|
AbstractTensorPtr revived_tensor(handle->Resolve(&status));
|
||||||
|
TF_EXPECT_OK(status) << "Failed to convert tensorhandle to tensor";
|
||||||
|
EXPECT_EQ(revived_tensor->Type(), expected.dtype());
|
||||||
|
EXPECT_EQ(revived_tensor->NumElements(), expected.NumElements());
|
||||||
|
EXPECT_EQ(revived_tensor->NumDims(), expected.dims());
|
||||||
|
for (int i = 0; i < expected.dims(); ++i) {
|
||||||
|
EXPECT_EQ(revived_tensor->Dim(i), expected.dim_size(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
testing::CheckBufferDataIsEqual(expected.dtype(), expected.NumElements(),
|
||||||
|
revived_tensor->Data(), expected.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test against combinations of tensors that are
|
||||||
|
// 1. Varying dtypes
|
||||||
|
// 2. Varying shapes
|
||||||
|
// 3. TensorProto serialized using tensor_content vs repeated type
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
ConstantIntegerDtypesTest, ConstantTest,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::ValuesIn(testing::DataTypeSetToVector(kDataTypeIsInteger)),
|
||||||
|
::testing::ValuesIn(testing::InterestingShapes()),
|
||||||
|
::testing::Values(false, true)));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
ConstantFloatingDtypesTest, ConstantTest,
|
||||||
|
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
|
||||||
|
::testing::ValuesIn(testing::InterestingShapes()),
|
||||||
|
::testing::Values(false, true)));
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
@ -28,6 +28,27 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "variable",
|
||||||
|
srcs = [
|
||||||
|
"variable.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"variable.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":tensorhandle_convertible",
|
||||||
|
"//tensorflow/c/eager:immediate_execution_context",
|
||||||
|
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core/ops:variable_ops",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/common_runtime/eager:context",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tensorhandle_convertible",
|
name = "tensorhandle_convertible",
|
||||||
hdrs = [
|
hdrs = [
|
||||||
|
@ -0,0 +1,78 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
Variable::Variable(ImmediateExecutionContext* ctx, DataType dtype,
|
||||||
|
TensorShape shape, absl::optional<std::string> name,
|
||||||
|
ImmediateTensorHandlePtr handle)
|
||||||
|
: TensorHandleConvertible(std::move(handle)),
|
||||||
|
name_(name.has_value() ? *name : "Variable"),
|
||||||
|
dtype_(dtype),
|
||||||
|
shape_(shape),
|
||||||
|
ctx_(ctx) {}
|
||||||
|
|
||||||
|
Variable::~Variable() {
|
||||||
|
// If the handle is null (perhaps because variable was std::moved from), then
|
||||||
|
// we don't have to do anything.
|
||||||
|
if (handle_ == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status status = internal::DestroyResource(ctx_, handle_.get());
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(ERROR) << "Error destroying variable: " << name_
|
||||||
|
<< "due to: " << status;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DataType Variable::dtype() { return dtype_; }
|
||||||
|
|
||||||
|
TensorShape Variable::shape() { return shape_; }
|
||||||
|
|
||||||
|
Status Variable::Assign(ImmediateExecutionTensorHandle* handle) {
|
||||||
|
return internal::AssignVariable(ctx_, handle_.get(), dtype_, handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
|
||||||
|
return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
|
||||||
|
DataType dtype, TensorShape shape,
|
||||||
|
absl::optional<std::string> name,
|
||||||
|
std::unique_ptr<Variable>* output) {
|
||||||
|
ImmediateTensorHandlePtr handle;
|
||||||
|
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
|
||||||
|
ctx, dtype, shape, &handle));
|
||||||
|
|
||||||
|
output->reset(
|
||||||
|
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
|
||||||
|
return Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -0,0 +1,76 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class Variable : public TensorHandleConvertible {
|
||||||
|
public:
|
||||||
|
// Creates an uninitialized resource variable. Note that a caller must
|
||||||
|
// call "assign" to associate a value with the variable.
|
||||||
|
static Status CreateUninitialized(ImmediateExecutionContext* ctx,
|
||||||
|
DataType dtype, TensorShape shape,
|
||||||
|
absl::optional<std::string> name,
|
||||||
|
std::unique_ptr<Variable>* output);
|
||||||
|
|
||||||
|
// The dtype of the underlying variable.
|
||||||
|
DataType dtype();
|
||||||
|
|
||||||
|
// The shape of the underlying variable.
|
||||||
|
TensorShape shape();
|
||||||
|
|
||||||
|
// Updates the variable's contents with `handle`.
|
||||||
|
Status Assign(ImmediateExecutionTensorHandle* handle);
|
||||||
|
|
||||||
|
// Reads the value of the variable, and stores it in `out`
|
||||||
|
Status ReadValue(ImmediateTensorHandlePtr* out);
|
||||||
|
|
||||||
|
// Variable is movable, but not copyable.
|
||||||
|
Variable(Variable&& other) = default;
|
||||||
|
Variable& operator=(Variable&& other) = default;
|
||||||
|
|
||||||
|
~Variable() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
Variable(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
|
||||||
|
absl::optional<std::string> name, ImmediateTensorHandlePtr handle);
|
||||||
|
Variable(const Variable& variable) = delete;
|
||||||
|
Variable& operator=(const Variable&) = delete;
|
||||||
|
|
||||||
|
std::string name_;
|
||||||
|
DataType dtype_;
|
||||||
|
TensorShape shape_;
|
||||||
|
|
||||||
|
// ctx_ must outlive Variable.
|
||||||
|
ImmediateExecutionContext* ctx_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
|
@ -15,8 +15,13 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||||
#include "tensorflow/c/tf_tensor_internal.h"
|
#include "tensorflow/c/tf_tensor_internal.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace internal {
|
namespace internal {
|
||||||
@ -34,5 +39,20 @@ Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
|
|||||||
return Constant::Create(ctx, &tensor_interface, output);
|
return Constant::Create(ctx, &tensor_interface, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This follows the python variable restoration logic:
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407
|
||||||
|
Status LoadSavedVariable(ImmediateExecutionContext* ctx,
|
||||||
|
const SavedVariable& variable,
|
||||||
|
std::unique_ptr<Variable>* output) {
|
||||||
|
const std::string& name = variable.name();
|
||||||
|
tensorflow::TensorShape shape(variable.shape());
|
||||||
|
tensorflow::DataType dtype = variable.dtype();
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
Variable::CreateUninitialized(ctx, dtype, shape, name, output));
|
||||||
|
|
||||||
|
return Status();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -21,7 +21,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace internal {
|
namespace internal {
|
||||||
@ -33,6 +35,14 @@ Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
|
|||||||
const TensorProto& proto,
|
const TensorProto& proto,
|
||||||
std::unique_ptr<Constant>* output);
|
std::unique_ptr<Constant>* output);
|
||||||
|
|
||||||
|
// Creates a tensorflow::Variable from a SavedVariable. This is similar to the
|
||||||
|
// logic in:
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407
|
||||||
|
// Note that the caller **must assign a value** to the loaded variable.
|
||||||
|
Status LoadSavedVariable(ImmediateExecutionContext* ctx,
|
||||||
|
const SavedVariable& variable,
|
||||||
|
std::unique_ptr<Variable>* output);
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -1,199 +0,0 @@
|
|||||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
|
||||||
|
|
||||||
#include <string.h>
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
|
||||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
|
||||||
#include "tensorflow/c/tensor_interface.h"
|
|
||||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
|
||||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
|
||||||
#include "tensorflow/core/framework/numeric_types.h"
|
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
|
||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
|
||||||
#include "tensorflow/core/framework/types.h"
|
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
|
||||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
|
||||||
#include "tensorflow/core/platform/logging.h"
|
|
||||||
#include "tensorflow/core/platform/test.h"
|
|
||||||
#include "tensorflow/core/platform/types.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// Converts a tensorflow::DatatypeSet to std::vector<DataType>.
|
|
||||||
// This is needed for GTest's ::testing::ValuesIn, since
|
|
||||||
// DataTypeSet doesn't fullfill all the constraints of an STL-like iterable.
|
|
||||||
std::vector<DataType> DataTypeSetToVector(DataTypeSet set) {
|
|
||||||
std::vector<DataType> result;
|
|
||||||
result.reserve(set.size());
|
|
||||||
for (DataType dt : set) {
|
|
||||||
result.push_back(dt);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a vector of shapes intended to be "interesting" test cases.
|
|
||||||
std::vector<std::vector<int64>> InterestingShapes() {
|
|
||||||
std::vector<std::vector<int64>> interesting_shapes;
|
|
||||||
interesting_shapes.push_back({}); // Scalar
|
|
||||||
interesting_shapes.push_back({10}); // 1D Vector
|
|
||||||
interesting_shapes.push_back({3, 3}); // 2D Matrix
|
|
||||||
interesting_shapes.push_back({1, 4, 6, 10}); // Higher Dimension Tensor
|
|
||||||
return interesting_shapes;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fills a numeric tensor with `value`.
|
|
||||||
void FillNumericTensor(Tensor* tensor, int8 value) {
|
|
||||||
switch (tensor->dtype()) {
|
|
||||||
#define CASE(type) \
|
|
||||||
case DataTypeToEnum<type>::value: { \
|
|
||||||
const auto& flattened = tensor->flat<type>(); \
|
|
||||||
for (int i = 0; i < tensor->NumElements(); ++i) { \
|
|
||||||
flattened(i) = value; \
|
|
||||||
} \
|
|
||||||
break; \
|
|
||||||
}
|
|
||||||
TF_CALL_INTEGRAL_TYPES(CASE);
|
|
||||||
TF_CALL_double(CASE);
|
|
||||||
TF_CALL_float(CASE);
|
|
||||||
#undef CASE
|
|
||||||
default:
|
|
||||||
CHECK(false) << "Unsupported data type: "
|
|
||||||
<< DataTypeString(tensor->dtype());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checks the underlying data is equal for the buffers for two numeric tensors.
|
|
||||||
// Note: The caller must ensure to check that the dtypes and sizes of the
|
|
||||||
// underlying buffers are the same before calling this.
|
|
||||||
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
|
|
||||||
void* b) {
|
|
||||||
switch (dtype) {
|
|
||||||
#define CASE(type) \
|
|
||||||
case DataTypeToEnum<type>::value: { \
|
|
||||||
type* typed_a = static_cast<type*>(a); \
|
|
||||||
type* typed_b = static_cast<type*>(b); \
|
|
||||||
for (int64 i = 0; i < num_elements; ++i) { \
|
|
||||||
if (DataTypeIsFloating(dtype)) { \
|
|
||||||
EXPECT_FLOAT_EQ(typed_a[i], typed_b[i]); \
|
|
||||||
} else { \
|
|
||||||
EXPECT_EQ(typed_a[i], typed_b[i]); \
|
|
||||||
} \
|
|
||||||
} \
|
|
||||||
break; \
|
|
||||||
}
|
|
||||||
TF_CALL_INTEGRAL_TYPES(CASE);
|
|
||||||
TF_CALL_double(CASE);
|
|
||||||
TF_CALL_float(CASE);
|
|
||||||
#undef CASE
|
|
||||||
default:
|
|
||||||
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class ConstantTest : public ::testing::TestWithParam<
|
|
||||||
std::tuple<DataType, std::vector<int64>, bool>> {
|
|
||||||
public:
|
|
||||||
ConstantTest()
|
|
||||||
: device_mgr_(std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
|
|
||||||
"CPU", {}, "/job:localhost/replica:0/task:0"))),
|
|
||||||
ctx_(new EagerContext(
|
|
||||||
SessionOptions(),
|
|
||||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
|
||||||
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
|
|
||||||
/* async= */ false,
|
|
||||||
/* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(),
|
|
||||||
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
|
|
||||||
/* custom_kernel_creator= */ nullptr,
|
|
||||||
/* cluster_flr= */ nullptr)) {}
|
|
||||||
|
|
||||||
EagerContext* context() { return ctx_.get(); }
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::unique_ptr<StaticDeviceMgr> device_mgr_;
|
|
||||||
EagerContextPtr ctx_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Basic sanity check that roundtripping a Tensor->Tensorproto->Constant
|
|
||||||
// preserves values.
|
|
||||||
TEST_P(ConstantTest, CreateConstantSuccessful) {
|
|
||||||
// Get test parameters
|
|
||||||
auto& test_params = GetParam();
|
|
||||||
DataType dtype = std::get<0>(test_params);
|
|
||||||
TensorShape shape(std::get<1>(test_params));
|
|
||||||
bool tensorproto_use_tensor_content = std::get<2>(test_params);
|
|
||||||
|
|
||||||
// Construct a Tensor with the given dtype + shape
|
|
||||||
Tensor expected(dtype, shape);
|
|
||||||
FillNumericTensor(&expected, 42);
|
|
||||||
|
|
||||||
// Serialize it to a Tensorproto
|
|
||||||
TensorProto proto;
|
|
||||||
if (tensorproto_use_tensor_content) {
|
|
||||||
expected.AsProtoTensorContent(&proto);
|
|
||||||
} else {
|
|
||||||
expected.AsProtoField(&proto);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Revival should succeed w/o errors
|
|
||||||
std::unique_ptr<Constant> revived;
|
|
||||||
TF_EXPECT_OK(internal::TensorProtoToConstant(context(), proto, &revived));
|
|
||||||
|
|
||||||
// The revived tensorhandle should have the exact same dtype, shape, +
|
|
||||||
// approx equivalent data to the original.
|
|
||||||
ImmediateExecutionTensorHandle* handle = revived->handle();
|
|
||||||
Status status;
|
|
||||||
AbstractTensorPtr revived_tensor(handle->Resolve(&status));
|
|
||||||
TF_EXPECT_OK(status) << "Failed to convert tensorhandle to tensor";
|
|
||||||
EXPECT_EQ(revived_tensor->Type(), expected.dtype());
|
|
||||||
EXPECT_EQ(revived_tensor->NumElements(), expected.NumElements());
|
|
||||||
EXPECT_EQ(revived_tensor->NumDims(), expected.dims());
|
|
||||||
for (int i = 0; i < expected.dims(); ++i) {
|
|
||||||
EXPECT_EQ(revived_tensor->Dim(i), expected.dim_size(i));
|
|
||||||
}
|
|
||||||
|
|
||||||
CheckBufferDataIsEqual(expected.dtype(), expected.NumElements(),
|
|
||||||
revived_tensor->Data(), expected.data());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test against combinations of tensors that are
|
|
||||||
// 1. Varying dtypes
|
|
||||||
// 2. Varying shapes
|
|
||||||
// 3. TensorProto serialized using tensor_content vs repeated type
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
|
||||||
ConstantIntegerDtypesTest, ConstantTest,
|
|
||||||
::testing::Combine(
|
|
||||||
::testing::ValuesIn(DataTypeSetToVector(kDataTypeIsInteger)),
|
|
||||||
::testing::ValuesIn(InterestingShapes()),
|
|
||||||
::testing::Values(false, true)));
|
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
|
||||||
ConstantFloatingDtypesTest, ConstantTest,
|
|
||||||
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
|
|
||||||
::testing::ValuesIn(InterestingShapes()),
|
|
||||||
::testing::Values(false, true)));
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace tensorflow
|
|
@ -0,0 +1,122 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
|
||||||
|
#include "tensorflow/c/tensor_interface.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class SavedVariableLoadingTest : public ::testing::TestWithParam<
|
||||||
|
std::tuple<DataType, std::vector<int64>>> {
|
||||||
|
public:
|
||||||
|
SavedVariableLoadingTest()
|
||||||
|
: device_mgr_(testing::CreateTestingDeviceMgr()),
|
||||||
|
ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
|
||||||
|
|
||||||
|
EagerContext* context() { return ctx_.get(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<StaticDeviceMgr> device_mgr_;
|
||||||
|
EagerContextPtr ctx_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Sanity check that constructing a tensorflow::Variable from a SavedVariable
|
||||||
|
// 1. does not cause an error
|
||||||
|
// 2. preserves dtype and shape.
|
||||||
|
TEST_P(SavedVariableLoadingTest, LoadSavedVariableSuccessful) {
|
||||||
|
auto& test_params = GetParam();
|
||||||
|
DataType dtype = std::get<0>(test_params);
|
||||||
|
TensorShape shape(std::get<1>(test_params));
|
||||||
|
|
||||||
|
SavedVariable saved_variable;
|
||||||
|
saved_variable.set_dtype(dtype);
|
||||||
|
shape.AsProto(saved_variable.mutable_shape());
|
||||||
|
|
||||||
|
std::unique_ptr<Variable> var;
|
||||||
|
TF_EXPECT_OK(internal::LoadSavedVariable(context(), saved_variable, &var));
|
||||||
|
EXPECT_EQ(var->dtype(), dtype);
|
||||||
|
EXPECT_EQ(var->shape(), shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assigning and reading values should yield
|
||||||
|
// consistent results.
|
||||||
|
TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
|
||||||
|
auto& test_params = GetParam();
|
||||||
|
DataType dtype = std::get<0>(test_params);
|
||||||
|
std::vector<int64> shape_vector = std::get<1>(test_params);
|
||||||
|
TensorShape shape(shape_vector);
|
||||||
|
|
||||||
|
// Create the variable.
|
||||||
|
Status status;
|
||||||
|
std::unique_ptr<Variable> var;
|
||||||
|
TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
|
||||||
|
absl::nullopt, &var));
|
||||||
|
|
||||||
|
// Create a TensorHandle
|
||||||
|
ImmediateTensorHandlePtr expected_handle =
|
||||||
|
testing::CreateTensorHandle(context(), dtype, shape_vector, 42);
|
||||||
|
AbstractTensorPtr expected_tensor(expected_handle->Resolve(&status));
|
||||||
|
TF_EXPECT_OK(status) << status.error_message();
|
||||||
|
|
||||||
|
// Assign the tensorhandle to the variable.
|
||||||
|
TF_EXPECT_OK(var->Assign(expected_handle.get()));
|
||||||
|
|
||||||
|
// Read back the value from the variable
|
||||||
|
ImmediateTensorHandlePtr output_handle;
|
||||||
|
TF_EXPECT_OK(var->ReadValue(&output_handle));
|
||||||
|
AbstractTensorPtr output_tensor(output_handle->Resolve(&status));
|
||||||
|
TF_EXPECT_OK(status) << status.error_message();
|
||||||
|
|
||||||
|
// Check that output_tensor == expected_tensor
|
||||||
|
EXPECT_EQ(output_tensor->Type(), expected_tensor->Type());
|
||||||
|
EXPECT_EQ(output_tensor->NumElements(), expected_tensor->NumElements());
|
||||||
|
testing::CheckBufferDataIsEqual(
|
||||||
|
output_tensor->Type(), output_tensor->NumElements(),
|
||||||
|
output_tensor->Data(), expected_tensor->Data());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test against combinations of SavedVariables of
|
||||||
|
// 1. Varying dtypes
|
||||||
|
// 2. Varying shapes
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
SavedVariableIntegerDtypesTest, SavedVariableLoadingTest,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::ValuesIn(testing::DataTypeSetToVector(kDataTypeIsInteger)),
|
||||||
|
::testing::ValuesIn(testing::InterestingShapes())));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
SavedVariableFloatingDtypesTest, SavedVariableLoadingTest,
|
||||||
|
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
|
||||||
|
::testing::ValuesIn(testing::InterestingShapes())));
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
143
tensorflow/c/experimental/saved_model/core/test_utils.cc
Normal file
143
tensorflow/c/experimental/saved_model/core/test_utils.cc
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/types/span.h"
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
|
#include "tensorflow/c/tensor_interface.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
|
#include "tensorflow/core/framework/numeric_types.h"
|
||||||
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace testing {
|
||||||
|
|
||||||
|
std::unique_ptr<StaticDeviceMgr> CreateTestingDeviceMgr() {
|
||||||
|
return std::make_unique<StaticDeviceMgr>(
|
||||||
|
DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
|
||||||
|
}
|
||||||
|
|
||||||
|
EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr) {
|
||||||
|
return EagerContextPtr(new EagerContext(
|
||||||
|
SessionOptions(),
|
||||||
|
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||||
|
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
|
||||||
|
/* async= */ false,
|
||||||
|
/* lazy_copy_function_remote_inputs= */ false, device_mgr,
|
||||||
|
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
|
||||||
|
/* custom_kernel_creator= */ nullptr,
|
||||||
|
/* cluster_flr= */ nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<DataType> DataTypeSetToVector(DataTypeSet set) {
|
||||||
|
std::vector<DataType> result;
|
||||||
|
result.reserve(set.size());
|
||||||
|
for (DataType dt : set) {
|
||||||
|
result.push_back(dt);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<int64>> InterestingShapes() {
|
||||||
|
std::vector<std::vector<int64>> interesting_shapes;
|
||||||
|
interesting_shapes.push_back({}); // Scalar
|
||||||
|
interesting_shapes.push_back({10}); // 1D Vector
|
||||||
|
interesting_shapes.push_back({3, 3}); // 2D Matrix
|
||||||
|
interesting_shapes.push_back({1, 4, 6, 10}); // Higher Dimension Tensor
|
||||||
|
return interesting_shapes;
|
||||||
|
}
|
||||||
|
|
||||||
|
ImmediateTensorHandlePtr CreateTensorHandle(ImmediateExecutionContext* ctx,
|
||||||
|
DataType dtype,
|
||||||
|
absl::Span<const int64> shape,
|
||||||
|
int8 value) {
|
||||||
|
AbstractTensorPtr tensor(ctx->CreateTensor(dtype, shape));
|
||||||
|
CHECK_NE(tensor.get(), nullptr)
|
||||||
|
<< "Tensor creation failed for tensor of dtype: "
|
||||||
|
<< DataTypeString(dtype);
|
||||||
|
CHECK_EQ(tensor->Type(), dtype);
|
||||||
|
for (int i = 0; i < shape.size(); ++i) {
|
||||||
|
CHECK_EQ(tensor->Dim(i), shape[i]);
|
||||||
|
}
|
||||||
|
FillNumericTensorBuffer(tensor->Type(), tensor->NumElements(), tensor->Data(),
|
||||||
|
value);
|
||||||
|
ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get()));
|
||||||
|
CHECK_NE(handle.get(), nullptr);
|
||||||
|
return handle;
|
||||||
|
}
|
||||||
|
|
||||||
|
void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
|
||||||
|
int8 value) {
|
||||||
|
switch (dtype) {
|
||||||
|
#define CASE(type) \
|
||||||
|
case DataTypeToEnum<type>::value: { \
|
||||||
|
type* typed_buffer = static_cast<type*>(buffer); \
|
||||||
|
for (size_t i = 0; i < num_elements; ++i) { \
|
||||||
|
typed_buffer[i] = value; \
|
||||||
|
} \
|
||||||
|
break; \
|
||||||
|
}
|
||||||
|
TF_CALL_INTEGRAL_TYPES(CASE);
|
||||||
|
TF_CALL_double(CASE);
|
||||||
|
TF_CALL_float(CASE);
|
||||||
|
#undef CASE
|
||||||
|
default:
|
||||||
|
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks the underlying data is equal for the buffers for two numeric tensors.
|
||||||
|
// Note: The caller must ensure to check that the dtypes and sizes of the
|
||||||
|
// underlying buffers are the same before calling this.
|
||||||
|
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
|
||||||
|
void* b) {
|
||||||
|
switch (dtype) {
|
||||||
|
#define CASE(type) \
|
||||||
|
case DataTypeToEnum<type>::value: { \
|
||||||
|
type* typed_a = static_cast<type*>(a); \
|
||||||
|
type* typed_b = static_cast<type*>(b); \
|
||||||
|
for (int64 i = 0; i < num_elements; ++i) { \
|
||||||
|
if (DataTypeIsFloating(dtype)) { \
|
||||||
|
EXPECT_FLOAT_EQ(typed_a[i], typed_b[i]); \
|
||||||
|
} else { \
|
||||||
|
EXPECT_EQ(typed_a[i], typed_b[i]); \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
break; \
|
||||||
|
}
|
||||||
|
TF_CALL_INTEGRAL_TYPES(CASE);
|
||||||
|
TF_CALL_double(CASE);
|
||||||
|
TF_CALL_float(CASE);
|
||||||
|
#undef CASE
|
||||||
|
default:
|
||||||
|
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace testing
|
||||||
|
} // namespace tensorflow
|
75
tensorflow/c/experimental/saved_model/core/test_utils.h
Normal file
75
tensorflow/c/experimental/saved_model/core/test_utils.h
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/types/span.h"
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace testing {
|
||||||
|
|
||||||
|
// Creates a DeviceMgr suitable for local tests.
|
||||||
|
std::unique_ptr<StaticDeviceMgr> CreateTestingDeviceMgr();
|
||||||
|
|
||||||
|
// Creates an EagerContext suitable for local tests. Does not take ownership
|
||||||
|
// of `device_mgr`.
|
||||||
|
EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr);
|
||||||
|
|
||||||
|
// Converts a tensorflow::DatatypeSet to std::vector<DataType>.
|
||||||
|
// This is useful for tests using GTest's ::testing::ValuesIn, since
|
||||||
|
// DataTypeSet doesn't fullfill all the constraints of an STL-like iterable.
|
||||||
|
std::vector<DataType> DataTypeSetToVector(DataTypeSet set);
|
||||||
|
|
||||||
|
// Returns a vector of shapes intended to be "interesting" test cases.
|
||||||
|
// Currently, this returns scalar, 1D vector, 2D matrix, and a 4D tensor shapes
|
||||||
|
std::vector<std::vector<int64>> InterestingShapes();
|
||||||
|
|
||||||
|
// Returns a TensorHandle of `dtype` and `shape`, filled with `value`.
|
||||||
|
// `dtype` must be an integer dtype, float, or double.
|
||||||
|
// If a TensorHandle cannot be created successfully, this function will
|
||||||
|
// CHECK fail. This should only be used for testing purposes.
|
||||||
|
ImmediateTensorHandlePtr CreateTensorHandle(ImmediateExecutionContext* ctx,
|
||||||
|
DataType dtype,
|
||||||
|
absl::Span<const int64> shape,
|
||||||
|
int8 value);
|
||||||
|
|
||||||
|
// Fills a numeric tensor's buffer with `value`.
|
||||||
|
// dtype must be any integer dtype, float or double.
|
||||||
|
void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
|
||||||
|
int8 value);
|
||||||
|
|
||||||
|
// Checks the underlying data is equal for the buffers for two numeric tensors.
|
||||||
|
// Note: The caller must ensure to check that the dtypes and sizes of the
|
||||||
|
// underlying buffers are the same before calling this.
|
||||||
|
// dtype must be any integer dtype, float, or double.
|
||||||
|
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
|
||||||
|
void* b);
|
||||||
|
|
||||||
|
} // namespace testing
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
|
Loading…
x
Reference in New Issue
Block a user