Initial checkin of C++ header-only TensorHandle as part of RFC https://github.com/tensorflow/community/pull/207.

PiperOrigin-RevId: 311179503
Change-Id: Ib3cfb2547150d09ee655db6ca6bc72ef3ef7adde
Brian Zhao 2020-05-12 12:29:59 -07:00 committed by TensorFlower Gardener
parent 88dfd8ce6d
commit 5100abc4af
17 changed files with 462 additions and 89 deletions
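
The new public API is exercised end to end by the tests added in this commit (tensorhandle_test.cc). A minimal usage sketch assembled from the headers and tests below; the main() harness, the float buffer, and the `cc` namespace alias are illustrative additions, and error handling is reduced to early returns:

#include <stdint.h>

#include <memory>
#include <vector>

#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"
#include "tensorflow/cc/experimental/base/public/tensorhandle.h"

namespace cc = tensorflow::experimental::cc;

int main() {
  cc::Status status;

  // Build a runtime. The Runtime must outlive every TensorHandle created
  // from it.
  cc::RuntimeBuilder builder;
  std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
  if (!status.ok()) return 1;

  // Wrap an existing buffer in a Tensor. The no-op deleter means the caller
  // keeps ownership of `data` and must keep it alive while the Tensor is used.
  std::vector<float> data = {1.0f, 2.0f, 3.0f};
  std::vector<int64_t> shape = {3};
  cc::Tensor tensor = cc::Tensor::FromBuffer(
      /*dtype=*/TF_FLOAT, /*shape=*/shape,
      /*data=*/data.data(),
      /*len=*/data.size() * sizeof(float),
      /*deleter=*/[](void*, size_t) {}, &status);
  if (!status.ok()) return 1;

  // Hand the Tensor to the runtime as a TensorHandle, then resolve it back.
  cc::TensorHandle handle =
      cc::TensorHandle::FromTensor(tensor, *runtime, &status);
  if (!status.ok()) return 1;

  cc::Tensor resolved = handle.Resolve(&status);
  if (!status.ok()) return 1;
  return resolved.num_elements() == 3 ? 0 : 1;
}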


@@ -924,7 +924,7 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
       context->GetDevicePlacementPolicy());
 }
 
-TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
+TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
   tensorflow::Tensor tensor;
   status->status = tensorflow::TF_TensorToTensor(t, &tensor);
   if (!status->status.ok()) return nullptr;


@@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
 // placed in memory of different devices or remote address spaces.
 typedef struct TFE_TensorHandle TFE_TensorHandle;
 
-TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t,
                                                             TF_Status* status);
 
 // Indicates that the caller will not be using `h` any more.
 TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);


@@ -62,3 +62,17 @@ cc_library(
         "//tensorflow/c:tf_tensor",
     ],
 )
+
+cc_library(
+    name = "tensorhandle",
+    hdrs = [
+        "tensorhandle.h",
+    ],
+    deps = [
+        ":runtime",
+        ":status",
+        ":tensor",
+        "//tensorflow/c/eager:c_api",
+        "//tensorflow/c/eager:c_api_experimental",
+    ],
+)


@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api_experimental.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // Runtime represents an opaque instance of a Tensorflow runtime, with its own
@@ -40,6 +41,7 @@ class Runtime {
  private:
   friend class RuntimeBuilder;
   friend class SavedModelAPI;
+  friend class TensorHandle;
 
   // Wraps a TFE_Context. Takes ownership of ctx.
   explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
@@ -63,6 +65,7 @@ class Runtime {
 };
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_


@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/cc/experimental/base/public/status.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
@@ -79,6 +80,7 @@ inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
 }
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_


@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/c/tf_status.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // Status is a wrapper around an error code and an optional error message.
@@ -57,6 +58,7 @@ class Status {
   friend class RuntimeBuilder;
   friend class Runtime;
   friend class SavedModelAPI;
+  friend class TensorHandle;
 
   // Wraps a TF_Status*, and takes ownership of it.
   explicit Status(TF_Status* status) : status_(status) {}
@@ -88,6 +90,7 @@ inline void Status::SetStatus(TF_Code code, const std::string& msg) {
 }
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_


@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/cc/experimental/base/public/status.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // Tensor represents an n-dimensional array of values.
@@ -168,6 +169,7 @@ inline Tensor Tensor::FromBuffer(TF_DataType dtype,
 }
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_


@@ -0,0 +1,98 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
#include <memory>
#include <vector>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// An opaque representation of a tensor computed/managed by the Tensorflow
// runtime (tensorflow::cc::Runtime). Unlike a tensor, a TensorHandle may refer
// to tensors placed in memory of different devices or remote address spaces.
// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created
// from it.
class TensorHandle {
public:
// Unwraps a Tensor from the given TensorHandle. If an error occurred,
// status->ok() will be false, and the returned Tensor must not be used.
Tensor Resolve(Status* status);
// Constructs a TensorHandle from a Tensor. If an error occurred,
// status->ok() will be false, and the returned TensorHandle must not be used.
static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime,
Status* status);
// TensorHandle is movable, and not copyable
TensorHandle(TensorHandle&&) = default;
TensorHandle& operator=(TensorHandle&&) = default;
private:
// Wraps a TFE_TensorHandle. Takes ownership of handle.
explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {}
// TensorHandle is not copyable
TensorHandle(const TensorHandle&) = delete;
TensorHandle& operator=(const TensorHandle&) = delete;
// Returns the underlying TFE_TensorHandle that this object wraps.
// This object retains ownership of the pointer.
TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); }
// Deletes the currently wrapped TFE_TensorHandle, swaps it with handle,
// and takes ownership of handle.
void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); }
struct TFETensorHandleDeleter {
void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
};
std::unique_ptr<TFE_TensorHandle, TFETensorHandleDeleter> handle_;
};
inline Tensor TensorHandle::Resolve(Status* status) {
TF_Tensor* tensor =
TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus());
if (!status->ok()) {
return Tensor(nullptr);
}
return Tensor(tensor);
}
inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor,
const Runtime& runtime,
Status* status) {
TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor(
runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus());
if (!status->ok()) {
return TensorHandle(nullptr);
}
return TensorHandle(tensor_handle);
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
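
TensorHandle stores the raw TFE_TensorHandle in a std::unique_ptr with a custom deleter, which is what gives the wrapper exactly-once cleanup of the C handle and move-only semantics. A standalone sketch of that idiom, using a made-up FooHandle C API rather than anything from TensorFlow:

#include <memory>
#include <utility>

// Made-up C-style handle API, standing in for TFE_TensorHandle and
// TFE_DeleteTensorHandle; not part of TensorFlow.
struct FooHandle { int payload = 0; };
FooHandle* FooCreate() { return new FooHandle; }
void FooDestroy(FooHandle* p) { delete p; }

class Foo {
 public:
  Foo() : handle_(FooCreate()) {}

  // Movable: ownership of the C handle transfers; the moved-from object holds
  // nullptr, so its deleter becomes a no-op.
  Foo(Foo&&) = default;
  Foo& operator=(Foo&&) = default;

  // Not copyable: two owners would destroy the same handle twice.
  Foo(const Foo&) = delete;
  Foo& operator=(const Foo&) = delete;

 private:
  struct FooDeleter {
    void operator()(FooHandle* p) const { FooDestroy(p); }
  };
  // The deleter runs exactly once, when the last (and only) owner goes away.
  std::unique_ptr<FooHandle, FooDeleter> handle_;
};

int main() {
  Foo a;
  Foo b = std::move(a);  // OK: the handle changes hands.
  // Foo c = b;          // Does not compile: copying is deleted.
  return 0;
}

The copy operations are deleted for the same reason as in TensorHandle above: a second owner would call the C deletion function on the same pointer twice.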


@@ -5,12 +5,22 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
+cc_library(
+    name = "tensor_types_test_util",
+    testonly = True,
+    hdrs = ["tensor_types_test_util.h"],
+    deps = [
+        "//tensorflow/c:tf_datatype",
+    ],
+)
+
 tf_cc_test(
     name = "tensor_test",
     srcs = [
         "tensor_test.cc",
     ],
     deps = [
+        ":tensor_types_test_util",
         "//tensorflow/c:tf_datatype",
         "//tensorflow/cc/experimental/base/public:status",
         "//tensorflow/cc/experimental/base/public:tensor",
@@ -19,3 +29,22 @@ tf_cc_test(
         "//tensorflow/core:test_main",
     ],
 )
+
+tf_cc_test(
+    name = "tensorhandle_test",
+    srcs = [
+        "tensorhandle_test.cc",
+    ],
+    deps = [
+        ":tensor_types_test_util",
+        "//tensorflow/c:tf_datatype",
+        "//tensorflow/cc/experimental/base/public:runtime",
+        "//tensorflow/cc/experimental/base/public:runtime_builder",
+        "//tensorflow/cc/experimental/base/public:status",
+        "//tensorflow/cc/experimental/base/public:tensor",
+        "//tensorflow/cc/experimental/base/public:tensorhandle",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)


@@ -16,69 +16,22 @@ limitations under the License.
 #include "tensorflow/cc/experimental/base/public/tensor.h"
 
 #include <stddef.h>
-#include <cstdint>
+#include <stdint.h>
 
 #include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/platform/test.h"
 
-namespace tensorflow {
 namespace {
 
-// Each of the following struct types have two members: a kDType that
-// corresponds to a TF_Datatype enum value, and a typedef "type"
-// of its corresponding C++ type. These types allow us to write Dtype-agnostic
-// tests via GoogleTest's TypedTests:
-// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
-struct FloatType {
-  using type = float;
-  static constexpr TF_DataType kDType = TF_FLOAT;
-};
-
-struct DoubleType {
-  using type = double;
-  static constexpr TF_DataType kDType = TF_DOUBLE;
-};
-
-struct Int32Type {
-  using type = int32_t;
-  static constexpr TF_DataType kDType = TF_INT32;
-};
-
-struct UINT8Type {
-  using type = uint8_t;
-  static constexpr TF_DataType kDType = TF_UINT8;
-};
-
-struct INT8Type {
-  using type = int8_t;
-  static constexpr TF_DataType kDType = TF_INT8;
-};
-
-struct INT64Type {
-  using type = int64_t;
-  static constexpr TF_DataType kDType = TF_INT64;
-};
-
-struct UINT16Type {
-  using type = uint16_t;
-  static constexpr TF_DataType kDType = TF_UINT16;
-};
-
-struct UINT32Type {
-  using type = uint32_t;
-  static constexpr TF_DataType kDType = TF_UINT32;
-};
-
-struct UINT64Type {
-  using type = uint64_t;
-  static constexpr TF_DataType kDType = TF_UINT64;
-};
-
-using SimpleTypes =
-    ::testing::Types<FloatType, DoubleType, Int32Type, UINT8Type, INT8Type,
-                     INT64Type, UINT16Type, UINT32Type, UINT64Type>;
+using tensorflow::experimental::cc::Status;
+using tensorflow::experimental::cc::Tensor;
+
+using SimpleTypes = ::testing::Types<
+    tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
+    tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
+    tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
 
 template <typename T>
 class ConstructScalarTensorTest : public ::testing::Test {};
@@ -88,11 +41,10 @@ TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes);
 // and verifies the expected dimensions, dtype, value, number of bytes, and
 // number of elements.
 TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
-  cc::Status status;
+  Status status;
   TF_DataType dtype = TypeParam::kDType;
   typename TypeParam::type value = 42;
-  cc::Tensor tensor =
-      cc::Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
+  Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
                              /*data=*/&value,
                              /*len=*/sizeof(value),
                              /*deleter=*/[](void*, size_t) {}, &status);
@@ -113,7 +65,7 @@ TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes);
 // and verifies the expected dimensions, dtype, value, number of bytes, and
 // number of elements.
 TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
-  cc::Status status;
+  Status status;
   TF_DataType dtype = TypeParam::kDType;
   // This is our 1D tensor of varying dtype.
   std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
@@ -121,7 +73,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
   std::vector<int64_t> shape;
   shape.push_back(value.size());
-  cc::Tensor tensor = cc::Tensor::FromBuffer(
+  Tensor tensor = Tensor::FromBuffer(
       /*dtype=*/dtype, /*shape=*/shape,
       /*data=*/value.data(),
       /*len=*/value.size() * sizeof(typename TypeParam::type),
@@ -130,7 +82,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
   EXPECT_EQ(tensor.dims(), 1);
   EXPECT_EQ(tensor.dtype(), dtype);
-  gtl::ArraySlice<typename TypeParam::type> tensor_view(
+  tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
       reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
   EXPECT_EQ(tensor_view[0], 42);
   EXPECT_EQ(tensor_view[1], 100);
@@ -152,14 +104,14 @@ TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes);
 // and verifies the expected dimensions, dtype, value, number of bytes, and
 // number of elements.
 TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
-  cc::Status status;
+  Status status;
   TF_DataType dtype = TypeParam::kDType;
   // This is our 1D tensor of varying dtype.
   std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
   // Shape is Rank 2 vector with shape 2 x 3.
   std::vector<int64_t> shape({2, 3});
-  cc::Tensor tensor = cc::Tensor::FromBuffer(
+  Tensor tensor = Tensor::FromBuffer(
      /*dtype=*/dtype, /*shape=*/shape,
      /*data=*/value.data(),
      /*len=*/value.size() * sizeof(typename TypeParam::type),
@@ -169,7 +121,7 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
   EXPECT_EQ(tensor.dims(), 2);
   EXPECT_EQ(tensor.dtype(), dtype);
-  gtl::ArraySlice<typename TypeParam::type> tensor_view(
+  tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
      reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
   EXPECT_EQ(tensor_view[0], 42);
   EXPECT_EQ(tensor_view[1], 100);
@@ -185,19 +137,19 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
 TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
   bool done = false;
-  cc::Status status;
+  Status status;
   std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
   {
     // data_vector is a rank 1 tensor.
     std::vector<int64_t> shape;
     shape.push_back(data_vector.size());
-    cc::Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
+    Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
       done = true;
     };
-    cc::Tensor tensor =
-        cc::Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
+    Tensor tensor =
+        Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
                            /*data=*/data_vector.data(),
                            /*len=*/data_vector.size() * sizeof(int32_t),
                            /*deleter=*/callback, &status);
@@ -209,4 +161,3 @@ TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
 }
 
 }  // namespace
-}  // namespace tensorflow


@@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
#include <stdint.h>
#include "tensorflow/c/tf_datatype.h"
namespace tensorflow {
// Each of the following struct types has two members: a kDType that
// corresponds to a TF_DataType enum value, and a typedef "type"
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
// tests via GoogleTest's TypedTests:
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
struct FloatType {
using type = float;
static constexpr TF_DataType kDType = TF_FLOAT;
};
struct DoubleType {
using type = double;
static constexpr TF_DataType kDType = TF_DOUBLE;
};
struct Int32Type {
using type = int32_t;
static constexpr TF_DataType kDType = TF_INT32;
};
struct UINT8Type {
using type = uint8_t;
static constexpr TF_DataType kDType = TF_UINT8;
};
struct INT8Type {
using type = int8_t;
static constexpr TF_DataType kDType = TF_INT8;
};
struct INT64Type {
using type = int64_t;
static constexpr TF_DataType kDType = TF_INT64;
};
struct UINT16Type {
using type = uint16_t;
static constexpr TF_DataType kDType = TF_UINT16;
};
struct UINT32Type {
using type = uint32_t;
static constexpr TF_DataType kDType = TF_UINT32;
};
struct UINT64Type {
using type = uint64_t;
static constexpr TF_DataType kDType = TF_UINT64;
};
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
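
These structs exist so a single test body can be stamped out once per dtype through GoogleTest's typed tests, which is how tensor_test.cc and tensorhandle_test.cc consume them. A minimal sketch of the pattern; the DTypeSmokeTest suite and its assertion are illustrative and not part of this commit, and TF_DataTypeSize() is the element-size helper from tensorflow/c/tf_datatype.h:

#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace {

// The fixture is instantiated once per entry in SimpleTypes.
template <typename T>
class DTypeSmokeTest : public ::testing::Test {};

using SimpleTypes =
    ::testing::Types<tensorflow::FloatType, tensorflow::Int32Type,
                     tensorflow::UINT64Type>;
TYPED_TEST_SUITE(DTypeSmokeTest, SimpleTypes);

// Inside the body, TypeParam is one of the structs above, so both the C++
// type and the matching TF_DataType enum value are available.
TYPED_TEST(DTypeSmokeTest, CppTypeMatchesDTypeSize) {
  EXPECT_EQ(sizeof(typename TypeParam::type),
            TF_DataTypeSize(TypeParam::kDType));
}

}  // namespace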


@@ -0,0 +1,184 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/experimental/base/public/tensorhandle.h"
#include <stddef.h>
#include <stdint.h>
#include <memory>
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
using tensorflow::experimental::cc::Runtime;
using tensorflow::experimental::cc::RuntimeBuilder;
using tensorflow::experimental::cc::Status;
using tensorflow::experimental::cc::Tensor;
using tensorflow::experimental::cc::TensorHandle;
using SimpleTypes = ::testing::Types<
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
template <typename T>
class ConstructScalarTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);
// This test constructs a scalar tensor for each of the types in "SimpleTypes",
// then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and
// verify the expected dims, dtype, value, num bytes, and num elements.
TYPED_TEST(ConstructScalarTensorHandleTest,
ValidTensorAttributesAfterConstruction) {
Status status;
RuntimeBuilder runtime_builder;
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
TF_DataType dtype = TypeParam::kDType;
typename TypeParam::type value = 42;
Tensor original_tensor =
Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
/*data=*/&value,
/*len=*/sizeof(value),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
TensorHandle handle =
TensorHandle::FromTensor(original_tensor, *runtime, &status);
ASSERT_TRUE(status.ok()) << status.message();
Tensor tensor = handle.Resolve(&status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 0);
EXPECT_EQ(tensor.dtype(), dtype);
EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
EXPECT_EQ(tensor.num_elements(), 1);
}
template <typename T>
class Construct1DTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);
// This test constructs a 1D tensor for each of the types in "SimpleTypes",
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct1DTensorHandleTest,
ValidTensorAttributesAfterConstruction) {
Status status;
RuntimeBuilder runtime_builder;
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
// Shape is Rank 1 vector.
std::vector<int64_t> shape;
shape.push_back(value.size());
Tensor original_tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
TensorHandle handle =
TensorHandle::FromTensor(original_tensor, *runtime, &status);
ASSERT_TRUE(status.ok()) << status.message();
Tensor tensor = handle.Resolve(&status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 1);
EXPECT_EQ(tensor.dtype(), dtype);
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
EXPECT_EQ(tensor_view[2], 0);
EXPECT_EQ(tensor_view[3], 1);
EXPECT_EQ(tensor_view[4], 4);
EXPECT_EQ(tensor_view[5], 29);
EXPECT_EQ(tensor.num_bytes(),
value.size() * sizeof(typename TypeParam::type));
EXPECT_EQ(tensor.num_elements(), value.size());
}
template <typename T>
class Construct2DTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);
// This test constructs a 2D tensor for each of the types in "SimpleTypes",
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct2DTensorHandleTest,
ValidTensorAttributesAfterConstruction) {
Status status;
RuntimeBuilder runtime_builder;
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
// Shape is Rank 2 vector with shape 2 x 3.
std::vector<int64_t> shape({2, 3});
Tensor original_tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
TensorHandle handle =
TensorHandle::FromTensor(original_tensor, *runtime, &status);
ASSERT_TRUE(status.ok()) << status.message();
Tensor tensor = handle.Resolve(&status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 2);
EXPECT_EQ(tensor.dtype(), dtype);
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
EXPECT_EQ(tensor_view[2], 0);
EXPECT_EQ(tensor_view[3], 1);
EXPECT_EQ(tensor_view[4], 4);
EXPECT_EQ(tensor_view[5], 29);
EXPECT_EQ(tensor.num_bytes(),
value.size() * sizeof(typename TypeParam::type));
EXPECT_EQ(tensor.num_elements(), value.size());
}
} // namespace
} // namespace tensorflow


@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
@@ -54,6 +55,7 @@ inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() {
 }
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_


@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // ConcreteFunctionList helps convert an opaque pointer to an array of
@@ -56,6 +57,7 @@ inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
 }
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_


@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // FunctionMetadata stores additional function information, including
@@ -40,6 +41,7 @@ class FunctionMetadata final {
 };
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_


@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // SavedModelAPI offers a way to load Tensorflow Saved Models
@@ -155,6 +156,7 @@ inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
 }
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_


@@ -26,10 +26,14 @@ limitations under the License.
 #include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/test.h"
 
-namespace tensorflow {
 namespace {
 
+using tensorflow::experimental::cc::Runtime;
+using tensorflow::experimental::cc::RuntimeBuilder;
+using tensorflow::experimental::cc::SavedModelAPI;
+using tensorflow::experimental::cc::Status;
+
 constexpr char kTestData[] = "cc/saved_model/testdata";
 
 std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
@@ -43,21 +47,21 @@ std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
 class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
 
 TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
-  cc::Status status;
-  cc::RuntimeBuilder builder;
+  Status status;
+  RuntimeBuilder builder;
   bool use_tfrt = GetParam();
   if (use_tfrt) {
     GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
   }
 
   builder.SetUseTFRT(use_tfrt);
-  std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
+  std::unique_ptr<Runtime> runtime = builder.Build(&status);
   ASSERT_TRUE(status.ok()) << status.message();
 
   std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
   std::unordered_set<std::string> tags = {"serve"};
-  std::unique_ptr<cc::SavedModelAPI> model =
-      cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
+  std::unique_ptr<SavedModelAPI> model =
+      SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
 
   // TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
   // That unblocks writing other tests that require a TF_SavedModel*,
@@ -67,20 +71,20 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
 }
 
 TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
-  cc::Status status;
-  cc::RuntimeBuilder builder;
+  Status status;
+  RuntimeBuilder builder;
   bool use_tfrt = GetParam();
   if (use_tfrt) {
     GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
   }
 
   builder.SetUseTFRT(use_tfrt);
-  std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
+  std::unique_ptr<Runtime> runtime = builder.Build(&status);
   ASSERT_TRUE(status.ok()) << status.message();
 
   std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
-  std::unique_ptr<cc::SavedModelAPI> model =
-      cc::SavedModelAPI::Load(model_dir, *runtime, &status);
+  std::unique_ptr<SavedModelAPI> model =
+      SavedModelAPI::Load(model_dir, *runtime, &status);
 
   // TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
   // That unblocks writing other tests that require a TF_SavedModel*,
@@ -94,4 +98,3 @@ INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
 }  // namespace
-}  // namespace tensorflow