Adding RevivedConstant class for Constant reloading in the SavedModelAPI C API.
PiperOrigin-RevId: 317920112 Change-Id: I2dc84de102c1edc5513df319e66ee20351bdb725
This commit is contained in:
parent
7d025c63c5
commit
95428e83f5
@ -3,6 +3,10 @@
|
||||
# Targets in this directory are pure C++ "Classes" underlying the C API types
|
||||
# under tf/c/experimental/saved_model/public/. They are subject to change and
|
||||
# have visibility limited to Tensorflow's implementation only.
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -47,6 +51,22 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "saved_model_utils",
|
||||
srcs = [
|
||||
"saved_model_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"saved_model_utils.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_saved_model_impl",
|
||||
srcs = [
|
||||
@ -84,3 +104,26 @@ filegroup(
|
||||
],
|
||||
visibility = ["//tensorflow/core:__pkg__"],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "saved_model_utils_test",
|
||||
srcs = [
|
||||
"saved_model_utils_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":saved_model_utils",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime:core_cpu_lib",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:core",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,39 @@
|
||||
# This package contains classes corresponding to Revived SavedObjectGraph types
|
||||
# used by SavedModel. See https://cs.opensource.google/tensorflow/tensorflow/+/c575e2ba93c442121d98d3f125d83fed1339924d:tensorflow/core/protobuf/saved_object_graph.proto;l=56-62
|
||||
package(
|
||||
default_visibility = [
|
||||
# Restricting visibility for now
|
||||
"//tensorflow/c/experimental/saved_model/core:__pkg__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "constant",
|
||||
srcs = [
|
||||
"constant.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"constant.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorhandle_convertible",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle_convertible",
|
||||
hdrs = [
|
||||
"tensorhandle_convertible.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
],
|
||||
)
|
@ -0,0 +1,46 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Constant::Constant(ImmediateTensorHandlePtr handle)
|
||||
: TensorHandleConvertible(std::move(handle)) {}
|
||||
|
||||
Status Constant::Create(ImmediateExecutionContext* ctx,
|
||||
AbstractTensorInterface* tensor,
|
||||
std::unique_ptr<Constant>* output) {
|
||||
ImmediateExecutionTensorHandle* handle = ctx->CreateLocalHandle(tensor);
|
||||
if (handle == nullptr) {
|
||||
return errors::Internal("Failed to convert tensor to tensorhandle");
|
||||
}
|
||||
output->reset(new Constant(ImmediateTensorHandlePtr(handle)));
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// This class corresponds to python's tf.constant, which is effectively a
|
||||
// TensorHandle explicitly initialized to some value.
|
||||
// For now this doesn't do much beyond wrap Context's CreateLocalHandle method,
|
||||
// and offer a subclass of TensorHandleConvertible. Note that similar to
|
||||
// the python's eager mode logic, we bypass calling the "Const" op:
|
||||
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/framework/constant_op.py#L301
|
||||
class Constant : public TensorHandleConvertible {
|
||||
public:
|
||||
static Status Create(ImmediateExecutionContext* ctx,
|
||||
AbstractTensorInterface* tensor,
|
||||
std::unique_ptr<Constant>* output);
|
||||
|
||||
// RevivedConstant is movable, but not copyable.
|
||||
Constant(Constant&& other) = default;
|
||||
Constant& operator=(Constant&& other) = default;
|
||||
|
||||
~Constant() override = default;
|
||||
|
||||
private:
|
||||
explicit Constant(ImmediateTensorHandlePtr handle);
|
||||
Constant(const Constant&) = delete;
|
||||
Constant& operator=(const Constant&) = delete;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
|
@ -0,0 +1,49 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A common interface for objects that can be converted to a TensorHandle.
|
||||
// Examples of objects that implement this include Variables, Constants, Assets,
|
||||
// etc. This is used to convert captured objects into a ConcreteFunction's
|
||||
// captured TensorHandles:
|
||||
// https://github.com/tensorflow/tensorflow/blob/676a68963ea4b64fe479b9cede06aa8f5b290ab8/tensorflow/python/saved_model/load.py#L229-L240
|
||||
class TensorHandleConvertible {
|
||||
public:
|
||||
explicit TensorHandleConvertible(ImmediateTensorHandlePtr handle)
|
||||
: handle_(std::move(handle)) {}
|
||||
|
||||
ImmediateExecutionTensorHandle* handle() { return handle_.get(); }
|
||||
|
||||
// TensorHandleConvertible is movable, but not copyable.
|
||||
TensorHandleConvertible(TensorHandleConvertible&& other) = default;
|
||||
TensorHandleConvertible& operator=(TensorHandleConvertible&& other) = default;
|
||||
|
||||
virtual ~TensorHandleConvertible() = default;
|
||||
|
||||
protected:
|
||||
TensorHandleConvertible(const TensorHandleConvertible&) = delete;
|
||||
TensorHandleConvertible& operator=(const TensorHandleConvertible&) = delete;
|
||||
ImmediateTensorHandlePtr handle_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
|
@ -0,0 +1,38 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
|
||||
const TensorProto& proto,
|
||||
std::unique_ptr<Constant>* output) {
|
||||
tensorflow::Tensor tensor;
|
||||
bool parse_result = tensor.FromProto(proto);
|
||||
if (!parse_result) {
|
||||
return errors::Internal("Failed to parse tensor from tensorproto");
|
||||
}
|
||||
|
||||
TensorInterface tensor_interface(std::move(tensor));
|
||||
return Constant::Create(ctx, &tensor_interface, output);
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
|
||||
|
||||
// Some internal utility functions for the SavedModelAPI, factored out into a
|
||||
// separately unit-testable header.
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
// Load a TensorProto into a tensorflow::Constant. This is similar to the
|
||||
// constant loading logic in python:
|
||||
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L437
|
||||
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
|
||||
const TensorProto& proto,
|
||||
std::unique_ptr<Constant>* output);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
|
@ -0,0 +1,199 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Converts a tensorflow::DatatypeSet to std::vector<DataType>.
|
||||
// This is needed for GTest's ::testing::ValuesIn, since
|
||||
// DataTypeSet doesn't fullfill all the constraints of an STL-like iterable.
|
||||
std::vector<DataType> DataTypeSetToVector(DataTypeSet set) {
|
||||
std::vector<DataType> result;
|
||||
result.reserve(set.size());
|
||||
for (DataType dt : set) {
|
||||
result.push_back(dt);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns a vector of shapes intended to be "interesting" test cases.
|
||||
std::vector<std::vector<int64>> InterestingShapes() {
|
||||
std::vector<std::vector<int64>> interesting_shapes;
|
||||
interesting_shapes.push_back({}); // Scalar
|
||||
interesting_shapes.push_back({10}); // 1D Vector
|
||||
interesting_shapes.push_back({3, 3}); // 2D Matrix
|
||||
interesting_shapes.push_back({1, 4, 6, 10}); // Higher Dimension Tensor
|
||||
return interesting_shapes;
|
||||
}
|
||||
|
||||
// Fills a numeric tensor with `value`.
|
||||
void FillNumericTensor(Tensor* tensor, int8 value) {
|
||||
switch (tensor->dtype()) {
|
||||
#define CASE(type) \
|
||||
case DataTypeToEnum<type>::value: { \
|
||||
const auto& flattened = tensor->flat<type>(); \
|
||||
for (int i = 0; i < tensor->NumElements(); ++i) { \
|
||||
flattened(i) = value; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
TF_CALL_INTEGRAL_TYPES(CASE);
|
||||
TF_CALL_double(CASE);
|
||||
TF_CALL_float(CASE);
|
||||
#undef CASE
|
||||
default:
|
||||
CHECK(false) << "Unsupported data type: "
|
||||
<< DataTypeString(tensor->dtype());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Checks the underlying data is equal for the buffers for two numeric tensors.
|
||||
// Note: The caller must ensure to check that the dtypes and sizes of the
|
||||
// underlying buffers are the same before calling this.
|
||||
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
|
||||
void* b) {
|
||||
switch (dtype) {
|
||||
#define CASE(type) \
|
||||
case DataTypeToEnum<type>::value: { \
|
||||
type* typed_a = static_cast<type*>(a); \
|
||||
type* typed_b = static_cast<type*>(b); \
|
||||
for (int64 i = 0; i < num_elements; ++i) { \
|
||||
if (DataTypeIsFloating(dtype)) { \
|
||||
EXPECT_FLOAT_EQ(typed_a[i], typed_b[i]); \
|
||||
} else { \
|
||||
EXPECT_EQ(typed_a[i], typed_b[i]); \
|
||||
} \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
TF_CALL_INTEGRAL_TYPES(CASE);
|
||||
TF_CALL_double(CASE);
|
||||
TF_CALL_float(CASE);
|
||||
#undef CASE
|
||||
default:
|
||||
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
|
||||
}
|
||||
}
|
||||
|
||||
class ConstantTest : public ::testing::TestWithParam<
|
||||
std::tuple<DataType, std::vector<int64>, bool>> {
|
||||
public:
|
||||
ConstantTest()
|
||||
: device_mgr_(std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
|
||||
"CPU", {}, "/job:localhost/replica:0/task:0"))),
|
||||
ctx_(new EagerContext(
|
||||
SessionOptions(),
|
||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
|
||||
/* async= */ false,
|
||||
/* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(),
|
||||
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
|
||||
/* custom_kernel_creator= */ nullptr,
|
||||
/* cluster_flr= */ nullptr)) {}
|
||||
|
||||
EagerContext* context() { return ctx_.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<StaticDeviceMgr> device_mgr_;
|
||||
EagerContextPtr ctx_;
|
||||
};
|
||||
|
||||
// Basic sanity check that roundtripping a Tensor->Tensorproto->Constant
|
||||
// preserves values.
|
||||
TEST_P(ConstantTest, CreateConstantSuccessful) {
|
||||
// Get test parameters
|
||||
auto& test_params = GetParam();
|
||||
DataType dtype = std::get<0>(test_params);
|
||||
TensorShape shape(std::get<1>(test_params));
|
||||
bool tensorproto_use_tensor_content = std::get<2>(test_params);
|
||||
|
||||
// Construct a Tensor with the given dtype + shape
|
||||
Tensor expected(dtype, shape);
|
||||
FillNumericTensor(&expected, 42);
|
||||
|
||||
// Serialize it to a Tensorproto
|
||||
TensorProto proto;
|
||||
if (tensorproto_use_tensor_content) {
|
||||
expected.AsProtoTensorContent(&proto);
|
||||
} else {
|
||||
expected.AsProtoField(&proto);
|
||||
}
|
||||
|
||||
// Revival should succeed w/o errors
|
||||
std::unique_ptr<Constant> revived;
|
||||
TF_EXPECT_OK(internal::TensorProtoToConstant(context(), proto, &revived));
|
||||
|
||||
// The revived tensorhandle should have the exact same dtype, shape, +
|
||||
// approx equivalent data to the original.
|
||||
ImmediateExecutionTensorHandle* handle = revived->handle();
|
||||
Status status;
|
||||
AbstractTensorPtr revived_tensor(handle->Resolve(&status));
|
||||
TF_EXPECT_OK(status) << "Failed to convert tensorhandle to tensor";
|
||||
EXPECT_EQ(revived_tensor->Type(), expected.dtype());
|
||||
EXPECT_EQ(revived_tensor->NumElements(), expected.NumElements());
|
||||
EXPECT_EQ(revived_tensor->NumDims(), expected.dims());
|
||||
for (int i = 0; i < expected.dims(); ++i) {
|
||||
EXPECT_EQ(revived_tensor->Dim(i), expected.dim_size(i));
|
||||
}
|
||||
|
||||
CheckBufferDataIsEqual(expected.dtype(), expected.NumElements(),
|
||||
revived_tensor->Data(), expected.data());
|
||||
}
|
||||
|
||||
// Test against combinations of tensors that are
|
||||
// 1. Varying dtypes
|
||||
// 2. Varying shapes
|
||||
// 3. TensorProto serialized using tensor_content vs repeated type
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ConstantIntegerDtypesTest, ConstantTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(DataTypeSetToVector(kDataTypeIsInteger)),
|
||||
::testing::ValuesIn(InterestingShapes()),
|
||||
::testing::Values(false, true)));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ConstantFloatingDtypesTest, ConstantTest,
|
||||
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
|
||||
::testing::ValuesIn(InterestingShapes()),
|
||||
::testing::Values(false, true)));
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
Loading…
x
Reference in New Issue
Block a user