Adding a FromBuffer method to construct tensorflow::cc::Tensors from user provided buffers. This also unblocks adding unit tests to tensorflow::cc::Tensor.
PiperOrigin-RevId: 310575934 Change-Id: I2da02d2f26cb3337e3e2c207bf2718c0f6c0b48a
This commit is contained in:
parent
631d5bc133
commit
5909adaa16
@ -57,6 +57,7 @@ cc_library(
|
||||
"tensor.h",
|
||||
],
|
||||
deps = [
|
||||
":status",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/c:tf_tensor",
|
||||
],
|
||||
|
@ -19,10 +19,13 @@ limitations under the License.
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cc {
|
||||
@ -30,19 +33,38 @@ namespace cc {
|
||||
// Tensor represents an n-dimensional array of values.
|
||||
class Tensor {
|
||||
public:
|
||||
// TODO(bmzhao): Add a factory function that constructs a Tensor from a char
|
||||
// buffer, with an options struct (to specify the buffer's layout, device?,
|
||||
// whether to create a TFRT or TF tensor, whether we should take ownership of
|
||||
// the memory, etc). This requires extending TF_NewTensor with an options
|
||||
// struct:
|
||||
// https://github.com/tensorflow/tensorflow/blob/3c520614a3c056d56afdc79b59979b9b0087f8b9/tensorflow/c/tf_tensor.h#L77-L80
|
||||
using DeleterCallback = std::function<void(void*, size_t)>;
|
||||
|
||||
// Constructs a Tensor from user provided buffer.
|
||||
//
|
||||
// Params:
|
||||
// dtype - The dtype of the tensor's data.
|
||||
// shape - A shape vector, where each element corresponds to the size of
|
||||
// the tensor's corresponding dimension.
|
||||
// data - Pointer to a buffer of memory to construct a Tensor out of.
|
||||
// len - The length (in bytes) of `data`
|
||||
// deleter - A std::function to be called when the Tensor no longer needs the
|
||||
// memory in `data`. This can be used to free `data`, or
|
||||
// perhaps decrement a refcount associated with `data`, etc.
|
||||
// status - Set to OK on success and an error on failure.
|
||||
// Returns:
|
||||
// If an error occurred, status->ok() will be false, and the returned
|
||||
// Tensor must not be used.
|
||||
// TODO(bmzhao): Add Runtime as an argument to this function so we can swap to
|
||||
// a TFRT backed tensor.
|
||||
// TODO(bmzhao): Add benchmarks on overhead for this function; we can
|
||||
// consider using int64_t* + length rather than vector.
|
||||
static Tensor FromBuffer(TF_DataType dtype, const std::vector<int64_t>& shape,
|
||||
void* data, size_t len, DeleterCallback deleter,
|
||||
Status* status);
|
||||
|
||||
// TODO(bmzhao): In the case we construct a tensor from non-owned memory,
|
||||
// we should offer a way to deep copy the tensor into a new tensor, which
|
||||
// owns the underlying memory. This could be a .deepcopy()/clone() method.
|
||||
|
||||
// TODO(bmzhao): In the future, we want to relax the non-copyability
|
||||
// constraint. To do so, we can add a C API function that acts like CopyFrom:
|
||||
// constraint. To do so, we can add a C API function that acts like
|
||||
// CopyFrom:
|
||||
// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/core/framework/tensor.h#L301-L311
|
||||
|
||||
// Tensor is movable, but not copyable
|
||||
@ -85,6 +107,16 @@ class Tensor {
|
||||
// This object retains ownership of the pointer.
|
||||
TF_Tensor* GetTFTensor() const { return tensor_.get(); }
|
||||
|
||||
struct DeleterStruct {
|
||||
std::function<void(void*, size_t)> deleter;
|
||||
};
|
||||
|
||||
static void DeleterFunction(void* memory, size_t len, void* deleter_struct) {
|
||||
DeleterStruct* deleter = reinterpret_cast<DeleterStruct*>(deleter_struct);
|
||||
deleter->deleter(memory, len);
|
||||
delete deleter;
|
||||
}
|
||||
|
||||
struct TFTensorDeleter {
|
||||
void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); }
|
||||
};
|
||||
@ -111,6 +143,30 @@ inline size_t Tensor::num_bytes() const {
|
||||
return TF_TensorByteSize(tensor_.get());
|
||||
}
|
||||
|
||||
inline Tensor Tensor::FromBuffer(TF_DataType dtype,
|
||||
const std::vector<int64_t>& shape, void* data,
|
||||
size_t len, DeleterCallback deleter,
|
||||
Status* status) {
|
||||
// Credit to apassos@ for this technique:
|
||||
// Despite the fact that our API takes a std::function deleter, we are able
|
||||
// to maintain ABI stability because:
|
||||
// 1. Only a function pointer is sent across the C API (&DeleterFunction)
|
||||
// 2. DeleterFunction is defined in the same build artifact that constructed
|
||||
// the std::function (so there isn't confusion about std::function ABI).
|
||||
// Note that 2. is satisifed by the fact that this is a header-only API, where
|
||||
// the function implementations are inline.
|
||||
|
||||
DeleterStruct* deleter_struct = new DeleterStruct{deleter};
|
||||
TF_Tensor* tensor = TF_NewTensor(dtype, shape.data(), shape.size(), data, len,
|
||||
&DeleterFunction, deleter_struct);
|
||||
if (tensor == nullptr) {
|
||||
status->SetStatus(TF_INVALID_ARGUMENT,
|
||||
"Failed to create tensor for input buffer");
|
||||
return Tensor(nullptr);
|
||||
}
|
||||
return Tensor(tensor);
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace tensorflow
|
||||
|
||||
|
21
tensorflow/cc/experimental/base/tests/BUILD
Normal file
21
tensorflow/cc/experimental/base/tests/BUILD
Normal file
@ -0,0 +1,21 @@
|
||||
# Tests for the C++ header-only base types.
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tensor_test",
|
||||
srcs = [
|
||||
"tensor_test.cc",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
"//tensorflow/cc/experimental/base/public:tensor",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
212
tensorflow/cc/experimental/base/tests/tensor_test.cc
Normal file
212
tensorflow/cc/experimental/base/tests/tensor_test.cc
Normal file
@ -0,0 +1,212 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Each of the following struct types have two members: a kDType that
|
||||
// corresponds to a TF_Datatype enum value, and a typedef "type"
|
||||
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
|
||||
// tests via GoogleTest's TypedTests:
|
||||
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
|
||||
struct FloatType {
|
||||
using type = float;
|
||||
static constexpr TF_DataType kDType = TF_FLOAT;
|
||||
};
|
||||
|
||||
struct DoubleType {
|
||||
using type = double;
|
||||
static constexpr TF_DataType kDType = TF_DOUBLE;
|
||||
};
|
||||
|
||||
struct Int32Type {
|
||||
using type = int32_t;
|
||||
static constexpr TF_DataType kDType = TF_INT32;
|
||||
};
|
||||
|
||||
struct UINT8Type {
|
||||
using type = uint8_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT8;
|
||||
};
|
||||
|
||||
struct INT8Type {
|
||||
using type = int8_t;
|
||||
static constexpr TF_DataType kDType = TF_INT8;
|
||||
};
|
||||
|
||||
struct INT64Type {
|
||||
using type = int64_t;
|
||||
static constexpr TF_DataType kDType = TF_INT64;
|
||||
};
|
||||
|
||||
struct UINT16Type {
|
||||
using type = uint16_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT16;
|
||||
};
|
||||
|
||||
struct UINT32Type {
|
||||
using type = uint32_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT32;
|
||||
};
|
||||
|
||||
struct UINT64Type {
|
||||
using type = uint64_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT64;
|
||||
};
|
||||
|
||||
using SimpleTypes =
|
||||
::testing::Types<FloatType, DoubleType, Int32Type, UINT8Type, INT8Type,
|
||||
INT64Type, UINT16Type, UINT32Type, UINT64Type>;
|
||||
|
||||
template <typename T>
|
||||
class ConstructScalarTensorTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes);
|
||||
|
||||
// This test constructs a scalar tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
cc::Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
typename TypeParam::type value = 42;
|
||||
cc::Tensor tensor =
|
||||
cc::Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
|
||||
/*data=*/&value,
|
||||
/*len=*/sizeof(value),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 0);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
|
||||
EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), 1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct1DTensorTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 1D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
cc::Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 1 vector.
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(value.size());
|
||||
|
||||
cc::Tensor tensor = cc::Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 1);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct2DTensorTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 2D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
cc::Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 2 vector with shape 2 x 3.
|
||||
std::vector<int64_t> shape({2, 3});
|
||||
|
||||
cc::Tensor tensor = cc::Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 2);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
|
||||
bool done = false;
|
||||
cc::Status status;
|
||||
std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
|
||||
{
|
||||
// data_vector is a rank 1 tensor.
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(data_vector.size());
|
||||
|
||||
cc::Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
|
||||
done = true;
|
||||
};
|
||||
|
||||
cc::Tensor tensor =
|
||||
cc::Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
|
||||
/*data=*/data_vector.data(),
|
||||
/*len=*/data_vector.size() * sizeof(int32_t),
|
||||
/*deleter=*/callback, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
}
|
||||
// At this point, tensor has been destroyed, and the deleter callback should
|
||||
// have run.
|
||||
EXPECT_TRUE(done);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user