Initial checkin of C++ header-only TensorHandle as part of RFC https://github.com/tensorflow/community/pull/207.

PiperOrigin-RevId: 311179503
Change-Id: Ib3cfb2547150d09ee655db6ca6bc72ef3ef7adde
This commit is contained in:
Brian Zhao 2020-05-12 12:29:59 -07:00 committed by TensorFlower Gardener
parent 88dfd8ce6d
commit 5100abc4af
17 changed files with 462 additions and 89 deletions

View File

@ -924,7 +924,7 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
context->GetDevicePlacementPolicy());
}
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor;
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
if (!status->status.ok()) return nullptr;

View File

@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
// placed in memory of different devices or remote address spaces.
typedef struct TFE_TensorHandle TFE_TensorHandle;
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t,
TF_Status* status);
// Indicates that the caller will not be using `h` any more.
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);

View File

@ -62,3 +62,17 @@ cc_library(
"//tensorflow/c:tf_tensor",
],
)
cc_library(
name = "tensorhandle",
hdrs = [
"tensorhandle.h",
],
deps = [
":runtime",
":status",
":tensor",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
],
)

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// Runtime represents an opaque instance of a Tensorflow runtime, with its own
@ -40,6 +41,7 @@ class Runtime {
private:
friend class RuntimeBuilder;
friend class SavedModelAPI;
friend class TensorHandle;
// Wraps a TFE_Context. Takes ownership of ctx.
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
@ -63,6 +65,7 @@ class Runtime {
};
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/cc/experimental/base/public/status.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
@ -79,6 +80,7 @@ inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/c/tf_status.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// Status is a wrapper around an error code and an optional error message.
@ -57,6 +58,7 @@ class Status {
friend class RuntimeBuilder;
friend class Runtime;
friend class SavedModelAPI;
friend class TensorHandle;
// Wraps a TF_Status*, and takes ownership of it.
explicit Status(TF_Status* status) : status_(status) {}
@ -88,6 +90,7 @@ inline void Status::SetStatus(TF_Code code, const std::string& msg) {
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/cc/experimental/base/public/status.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// Tensor represents an n-dimensional array of values.
@ -168,6 +169,7 @@ inline Tensor Tensor::FromBuffer(TF_DataType dtype,
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_

View File

@ -0,0 +1,98 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
#include <memory>
#include <vector>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// An opaque representation of a tensor computed/managed by the Tensorflow
// runtime (tensorflow:cc::Runtime). Unlike a tensor, a Tensorhandle may refer
// to tensors placed in memory of different devices or remote address spaces.
// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created
// from it.
class TensorHandle {
public:
// Unwraps a Tensor from the given TensorHandle. If an error occurred,
// status->ok() will be false, and the returned Tensor must not be used.
Tensor Resolve(Status* status);
// Constructs a TensorHandle from a Tensor. If an error occurred,
// status->ok() will be false, and the returned TensorHandle must not be used.
static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime,
Status* status);
// TensorHandle is movable, and not copyable
TensorHandle(TensorHandle&&) = default;
TensorHandle& operator=(TensorHandle&&) = default;
private:
// Wraps a TFE_TensorHandle. Takes ownership of handle.
explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {}
// TensorHandle is not copyable
TensorHandle(const TensorHandle&) = delete;
TensorHandle& operator=(const TensorHandle&) = delete;
// Returns the underlying TFE_TensorHandle that this object wraps.
// This object retains ownership of the pointer.
TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); }
// Deletes the currently wrapped TFE_TensorHandle, and swaps it with handle,
// and takes ownership of handle.
void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); }
struct TFETensorHandleDeleter {
void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
};
std::unique_ptr<TFE_TensorHandle, TFETensorHandleDeleter> handle_;
};
inline Tensor TensorHandle::Resolve(Status* status) {
TF_Tensor* tensor =
TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus());
if (!status->ok()) {
return Tensor(nullptr);
}
return Tensor(tensor);
}
inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor,
const Runtime& runtime,
Status* status) {
TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor(
runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus());
if (!status->ok()) {
return TensorHandle(nullptr);
}
return TensorHandle(tensor_handle);
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_

View File

@ -5,12 +5,22 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "tensor_types_test_util",
testonly = True,
hdrs = ["tensor_types_test_util.h"],
deps = [
"//tensorflow/c:tf_datatype",
],
)
tf_cc_test(
name = "tensor_test",
srcs = [
"tensor_test.cc",
],
deps = [
":tensor_types_test_util",
"//tensorflow/c:tf_datatype",
"//tensorflow/cc/experimental/base/public:status",
"//tensorflow/cc/experimental/base/public:tensor",
@ -19,3 +29,22 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "tensorhandle_test",
srcs = [
"tensorhandle_test.cc",
],
deps = [
":tensor_types_test_util",
"//tensorflow/c:tf_datatype",
"//tensorflow/cc/experimental/base/public:runtime",
"//tensorflow/cc/experimental/base/public:runtime_builder",
"//tensorflow/cc/experimental/base/public:status",
"//tensorflow/cc/experimental/base/public:tensor",
"//tensorflow/cc/experimental/base/public:tensorhandle",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -16,69 +16,22 @@ limitations under the License.
#include "tensorflow/cc/experimental/base/public/tensor.h"
#include <stddef.h>
#include <cstdint>
#include <stdint.h>
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
// Each of the following struct types have two members: a kDType that
// corresponds to a TF_Datatype enum value, and a typedef "type"
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
// tests via GoogleTest's TypedTests:
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
struct FloatType {
using type = float;
static constexpr TF_DataType kDType = TF_FLOAT;
};
using tensorflow::experimental::cc::Status;
using tensorflow::experimental::cc::Tensor;
struct DoubleType {
using type = double;
static constexpr TF_DataType kDType = TF_DOUBLE;
};
struct Int32Type {
using type = int32_t;
static constexpr TF_DataType kDType = TF_INT32;
};
struct UINT8Type {
using type = uint8_t;
static constexpr TF_DataType kDType = TF_UINT8;
};
struct INT8Type {
using type = int8_t;
static constexpr TF_DataType kDType = TF_INT8;
};
struct INT64Type {
using type = int64_t;
static constexpr TF_DataType kDType = TF_INT64;
};
struct UINT16Type {
using type = uint16_t;
static constexpr TF_DataType kDType = TF_UINT16;
};
struct UINT32Type {
using type = uint32_t;
static constexpr TF_DataType kDType = TF_UINT32;
};
struct UINT64Type {
using type = uint64_t;
static constexpr TF_DataType kDType = TF_UINT64;
};
using SimpleTypes =
::testing::Types<FloatType, DoubleType, Int32Type, UINT8Type, INT8Type,
INT64Type, UINT16Type, UINT32Type, UINT64Type>;
using SimpleTypes = ::testing::Types<
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
template <typename T>
class ConstructScalarTensorTest : public ::testing::Test {};
@ -88,14 +41,13 @@ TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes);
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
cc::Status status;
Status status;
TF_DataType dtype = TypeParam::kDType;
typename TypeParam::type value = 42;
cc::Tensor tensor =
cc::Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
/*data=*/&value,
/*len=*/sizeof(value),
/*deleter=*/[](void*, size_t) {}, &status);
Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
/*data=*/&value,
/*len=*/sizeof(value),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 0);
@ -113,7 +65,7 @@ TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes);
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
cc::Status status;
Status status;
TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
@ -121,7 +73,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
std::vector<int64_t> shape;
shape.push_back(value.size());
cc::Tensor tensor = cc::Tensor::FromBuffer(
Tensor tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
@ -130,7 +82,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
EXPECT_EQ(tensor.dims(), 1);
EXPECT_EQ(tensor.dtype(), dtype);
gtl::ArraySlice<typename TypeParam::type> tensor_view(
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
@ -152,14 +104,14 @@ TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes);
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
cc::Status status;
Status status;
TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
// Shape is Rank 2 vector with shape 2 x 3.
std::vector<int64_t> shape({2, 3});
cc::Tensor tensor = cc::Tensor::FromBuffer(
Tensor tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
@ -169,7 +121,7 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
EXPECT_EQ(tensor.dims(), 2);
EXPECT_EQ(tensor.dtype(), dtype);
gtl::ArraySlice<typename TypeParam::type> tensor_view(
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
@ -185,22 +137,22 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
bool done = false;
cc::Status status;
Status status;
std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
{
// data_vector is a rank 1 tensor.
std::vector<int64_t> shape;
shape.push_back(data_vector.size());
cc::Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
done = true;
};
cc::Tensor tensor =
cc::Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
/*data=*/data_vector.data(),
/*len=*/data_vector.size() * sizeof(int32_t),
/*deleter=*/callback, &status);
Tensor tensor =
Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
/*data=*/data_vector.data(),
/*len=*/data_vector.size() * sizeof(int32_t),
/*deleter=*/callback, &status);
ASSERT_TRUE(status.ok()) << status.message();
}
// At this point, tensor has been destroyed, and the deleter callback should
@ -209,4 +161,3 @@ TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
#include <stdint.h>
#include "tensorflow/c/tf_datatype.h"
namespace tensorflow {
// Each of the following struct types have two members: a kDType that
// corresponds to a TF_Datatype enum value, and a typedef "type"
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
// tests via GoogleTest's TypedTests:
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
struct FloatType {
using type = float;
static constexpr TF_DataType kDType = TF_FLOAT;
};
struct DoubleType {
using type = double;
static constexpr TF_DataType kDType = TF_DOUBLE;
};
struct Int32Type {
using type = int32_t;
static constexpr TF_DataType kDType = TF_INT32;
};
struct UINT8Type {
using type = uint8_t;
static constexpr TF_DataType kDType = TF_UINT8;
};
struct INT8Type {
using type = int8_t;
static constexpr TF_DataType kDType = TF_INT8;
};
struct INT64Type {
using type = int64_t;
static constexpr TF_DataType kDType = TF_INT64;
};
struct UINT16Type {
using type = uint16_t;
static constexpr TF_DataType kDType = TF_UINT16;
};
struct UINT32Type {
using type = uint32_t;
static constexpr TF_DataType kDType = TF_UINT32;
};
struct UINT64Type {
using type = uint64_t;
static constexpr TF_DataType kDType = TF_UINT64;
};
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_

View File

@ -0,0 +1,184 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/experimental/base/public/tensorhandle.h"
#include <stddef.h>
#include <stdint.h>
#include <memory>
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
using tensorflow::experimental::cc::Runtime;
using tensorflow::experimental::cc::RuntimeBuilder;
using tensorflow::experimental::cc::Status;
using tensorflow::experimental::cc::Tensor;
using tensorflow::experimental::cc::TensorHandle;
using SimpleTypes = ::testing::Types<
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
template <typename T>
class ConstructScalarTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);
// This test constructs a scalar tensor for each of the types in "SimpleTypes",
// then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and
// verify the expected dims, dtype, value, num bytes, and num elements.
TYPED_TEST(ConstructScalarTensorHandleTest,
ValidTensorAttributesAfterConstruction) {
Status status;
RuntimeBuilder runtime_builder;
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
TF_DataType dtype = TypeParam::kDType;
typename TypeParam::type value = 42;
Tensor original_tensor =
Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
/*data=*/&value,
/*len=*/sizeof(value),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
TensorHandle handle =
TensorHandle::FromTensor(original_tensor, *runtime, &status);
ASSERT_TRUE(status.ok()) << status.message();
Tensor tensor = handle.Resolve(&status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 0);
EXPECT_EQ(tensor.dtype(), dtype);
EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
EXPECT_EQ(tensor.num_elements(), 1);
}
template <typename T>
class Construct1DTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);
// This test constructs a 1D tensor for each of the types in "SimpleTypes",
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct1DTensorHandleTest,
ValidTensorAttributesAfterConstruction) {
Status status;
RuntimeBuilder runtime_builder;
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
// Shape is Rank 1 vector.
std::vector<int64_t> shape;
shape.push_back(value.size());
Tensor original_tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
TensorHandle handle =
TensorHandle::FromTensor(original_tensor, *runtime, &status);
ASSERT_TRUE(status.ok()) << status.message();
Tensor tensor = handle.Resolve(&status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 1);
EXPECT_EQ(tensor.dtype(), dtype);
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
EXPECT_EQ(tensor_view[2], 0);
EXPECT_EQ(tensor_view[3], 1);
EXPECT_EQ(tensor_view[4], 4);
EXPECT_EQ(tensor_view[5], 29);
EXPECT_EQ(tensor.num_bytes(),
value.size() * sizeof(typename TypeParam::type));
EXPECT_EQ(tensor.num_elements(), value.size());
}
template <typename T>
class Construct2DTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);
// This test constructs a 2D tensor for each of the types in "SimpleTypes",
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct2DTensorHandleTest,
ValidTensorAttributesAfterConstruction) {
Status status;
RuntimeBuilder runtime_builder;
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
// Shape is Rank 2 vector with shape 2 x 3.
std::vector<int64_t> shape({2, 3});
Tensor original_tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
TensorHandle handle =
TensorHandle::FromTensor(original_tensor, *runtime, &status);
ASSERT_TRUE(status.ok()) << status.message();
Tensor tensor = handle.Resolve(&status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 2);
EXPECT_EQ(tensor.dtype(), dtype);
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
EXPECT_EQ(tensor_view[2], 0);
EXPECT_EQ(tensor_view[3], 1);
EXPECT_EQ(tensor_view[4], 4);
EXPECT_EQ(tensor_view[5], 29);
EXPECT_EQ(tensor.num_bytes(),
value.size() * sizeof(typename TypeParam::type));
EXPECT_EQ(tensor.num_elements(), value.size());
}
} // namespace
} // namespace tensorflow

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
@ -54,6 +55,7 @@ inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() {
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// ConcreteFunctionList helps convert an opaque pointer to an array of
@ -56,6 +57,7 @@ inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// FunctionMetadata stores additional function information, including
@ -40,6 +41,7 @@ class FunctionMetadata final {
};
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// SavedModelAPI offers a way to load Tensorflow Saved Models
@ -155,6 +156,7 @@ inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_

View File

@ -26,10 +26,14 @@ limitations under the License.
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
using tensorflow::experimental::cc::Runtime;
using tensorflow::experimental::cc::RuntimeBuilder;
using tensorflow::experimental::cc::SavedModelAPI;
using tensorflow::experimental::cc::Status;
constexpr char kTestData[] = "cc/saved_model/testdata";
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
@ -43,21 +47,21 @@ std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
cc::Status status;
cc::RuntimeBuilder builder;
Status status;
RuntimeBuilder builder;
bool use_tfrt = GetParam();
if (use_tfrt) {
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
builder.SetUseTFRT(use_tfrt);
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
std::unique_ptr<Runtime> runtime = builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
std::unordered_set<std::string> tags = {"serve"};
std::unique_ptr<cc::SavedModelAPI> model =
cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
std::unique_ptr<SavedModelAPI> model =
SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
@ -67,20 +71,20 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
}
TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
cc::Status status;
cc::RuntimeBuilder builder;
Status status;
RuntimeBuilder builder;
bool use_tfrt = GetParam();
if (use_tfrt) {
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
builder.SetUseTFRT(use_tfrt);
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
std::unique_ptr<Runtime> runtime = builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
std::unique_ptr<cc::SavedModelAPI> model =
cc::SavedModelAPI::Load(model_dir, *runtime, &status);
std::unique_ptr<SavedModelAPI> model =
SavedModelAPI::Load(model_dir, *runtime, &status);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
@ -94,4 +98,3 @@ INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
} // namespace
} // namespace tensorflow