Merge pull request #41713 from dnguyen28061:allocate_temp
PiperOrigin-RevId: 326305159 Change-Id: Ic95dcb86e7bc58ced1666c7adcbddb9cc99dd2f4
This commit is contained in:
commit
efa82dd9a7
@ -23,6 +23,7 @@ filegroup(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"c_api.h",
|
"c_api.h",
|
||||||
"c_api_experimental.h",
|
"c_api_experimental.h",
|
||||||
|
"c_api_macros.h",
|
||||||
"tensor_interface.h",
|
"tensor_interface.h",
|
||||||
"tf_attrtype.h",
|
"tf_attrtype.h",
|
||||||
"tf_datatype.h",
|
"tf_datatype.h",
|
||||||
@ -61,6 +62,7 @@ filegroup(
|
|||||||
name = "pywrap_required_hdrs",
|
name = "pywrap_required_hdrs",
|
||||||
srcs = [
|
srcs = [
|
||||||
"c_api_internal.h",
|
"c_api_internal.h",
|
||||||
|
"c_api_macros.h",
|
||||||
"conversion_macros.h",
|
"conversion_macros.h",
|
||||||
"python_api.h",
|
"python_api.h",
|
||||||
"tensor_interface.h",
|
"tensor_interface.h",
|
||||||
@ -79,6 +81,7 @@ tf_cuda_library(
|
|||||||
hdrs = [
|
hdrs = [
|
||||||
"c_api.h",
|
"c_api.h",
|
||||||
"c_api_internal.h",
|
"c_api_internal.h",
|
||||||
|
"c_api_macros.h",
|
||||||
"tf_datatype.h",
|
"tf_datatype.h",
|
||||||
"tf_tensor.h",
|
"tf_tensor.h",
|
||||||
"tf_tstring.h",
|
"tf_tstring.h",
|
||||||
@ -310,6 +313,7 @@ cc_library(
|
|||||||
hdrs = ["tf_tensor.h"],
|
hdrs = ["tf_tensor.h"],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":c_api_macros",
|
||||||
":tensor_interface",
|
":tensor_interface",
|
||||||
":tf_datatype",
|
":tf_datatype",
|
||||||
":tf_status",
|
":tf_status",
|
||||||
@ -336,6 +340,7 @@ tf_cuda_library(
|
|||||||
],
|
],
|
||||||
visibility = ["//tensorflow:internal"],
|
visibility = ["//tensorflow:internal"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":c_api_macros",
|
||||||
":tensor_interface",
|
":tensor_interface",
|
||||||
":tf_datatype",
|
":tf_datatype",
|
||||||
":tf_status",
|
":tf_status",
|
||||||
|
@ -30,4 +30,17 @@ limitations under the License.
|
|||||||
#endif // _WIN32
|
#endif // _WIN32
|
||||||
#endif // SWIG
|
#endif // SWIG
|
||||||
|
|
||||||
|
// TF_Bool is the C API typedef for unsigned char, while TF_BOOL is
|
||||||
|
// the datatype for boolean tensors.
|
||||||
|
#ifndef TF_Bool
|
||||||
|
#define TF_Bool unsigned char
|
||||||
|
#endif // TF_Bool
|
||||||
|
|
||||||
|
// Macro used to calculate struct size for maintaining ABI stability across
|
||||||
|
// different struct implementations.
|
||||||
|
#ifndef TF_OFFSET_OF_END
|
||||||
|
#define TF_OFFSET_OF_END(TYPE, MEMBER) \
|
||||||
|
(offsetof(TYPE, MEMBER) + sizeof(((TYPE *)0)->MEMBER))
|
||||||
|
#endif // TF_OFFSET_OF_END
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_C_API_MACROS_H_
|
#endif // TENSORFLOW_C_C_API_MACROS_H_
|
||||||
|
@ -261,7 +261,6 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
|
|||||||
size_t len, TF_Status* status) {
|
size_t len, TF_Status* status) {
|
||||||
TF_SetStatus(status, TF_OK, "");
|
TF_SetStatus(status, TF_OK, "");
|
||||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
|
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
|
||||||
|
|
||||||
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
|
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
|
||||||
"64-bit int types should match in size");
|
"64-bit int types should match in size");
|
||||||
tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
|
tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
|
||||||
@ -280,3 +279,42 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
|
|||||||
}
|
}
|
||||||
return tf_tensor;
|
return tf_tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
|
||||||
|
int64_t* dims, int num_dims,
|
||||||
|
TF_AllocatorAttributes* attributes,
|
||||||
|
TF_Status* status) {
|
||||||
|
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
|
||||||
|
TF_SetStatus(status, TF_OK, "");
|
||||||
|
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
|
||||||
|
"64-bit int types should match in size");
|
||||||
|
tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
|
||||||
|
reinterpret_cast<tensorflow::int64*>(dims), num_dims);
|
||||||
|
if (attributes && !attributes->struct_size) {
|
||||||
|
TF_SetStatus(
|
||||||
|
status, TF_INVALID_ARGUMENT,
|
||||||
|
"TF_AllocatorAttributes struct "
|
||||||
|
"size member must be set to TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
tensorflow::AllocatorAttributes allocator_attr;
|
||||||
|
if (attributes && attributes->on_host) {
|
||||||
|
allocator_attr.set_on_host(true);
|
||||||
|
}
|
||||||
|
tensorflow::Status s;
|
||||||
|
tensorflow::Tensor tensor;
|
||||||
|
s = cc_ctx->allocate_temp(static_cast<tensorflow::DataType>(dtype),
|
||||||
|
tensorflow::TensorShape(dimarray), &tensor,
|
||||||
|
allocator_attr);
|
||||||
|
if (!s.ok()) {
|
||||||
|
::tensorflow::Set_TF_Status_from_Status(status, s);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
TF_Tensor* tf_tensor;
|
||||||
|
tf_tensor = TF_TensorFromTensor(tensor, &s);
|
||||||
|
if (!s.ok()) {
|
||||||
|
::tensorflow::Set_TF_Status_from_Status(status, s);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return tf_tensor;
|
||||||
|
}
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/c/tf_datatype.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
|
|
||||||
// Macro to control visibility of exported symbols in the shared library (.so,
|
// Macro to control visibility of exported symbols in the shared library (.so,
|
||||||
// .dylib, .dll).
|
// .dylib, .dll).
|
||||||
@ -199,6 +200,15 @@ TF_CAPI_EXPORT TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context,
|
|||||||
int64_t* dims, int num_dims,
|
int64_t* dims, int num_dims,
|
||||||
size_t len, TF_Status* status);
|
size_t len, TF_Status* status);
|
||||||
|
|
||||||
|
// Allocates a temporary Tensor of the specified type and shape. The
|
||||||
|
// Tensor must not be used after kernel construction is
|
||||||
|
// complete.
|
||||||
|
//
|
||||||
|
// num_dims must equal the size of array dims
|
||||||
|
TF_CAPI_EXPORT extern TF_Tensor* TF_AllocateTemp(
|
||||||
|
TF_OpKernelContext* context, TF_DataType dtype, int64_t* dims, int num_dims,
|
||||||
|
TF_AllocatorAttributes* alloc_attrs, TF_Status* status);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} /* end extern "C" */
|
} /* end extern "C" */
|
||||||
#endif
|
#endif
|
||||||
|
@ -368,6 +368,16 @@ class DeviceKernelOpTest : public OpsTestBase {
|
|||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Validates that the tensor has shape and type corresponding to
|
||||||
|
// dims and dtype.
|
||||||
|
void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
|
||||||
|
TF_DataType dtype);
|
||||||
|
|
||||||
|
// Copies data of length tensor_size_bytes from values to tensor.
|
||||||
|
template <typename T>
|
||||||
|
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
|
||||||
|
TF_OpKernelContext* ctx);
|
||||||
|
|
||||||
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
|
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
|
||||||
|
|
||||||
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
|
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
|
||||||
@ -379,22 +389,11 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
|
|||||||
TF_Tensor* output = TF_AllocateOutput(
|
TF_Tensor* output = TF_AllocateOutput(
|
||||||
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
|
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
|
||||||
/*num_dims=*/1, /*len=*/tensor_size_bytes, s);
|
/*num_dims=*/1, /*len=*/tensor_size_bytes, s);
|
||||||
EXPECT_EQ(TF_OK, TF_GetCode(s));
|
validate_tensor(output, &dim, 1, TF_FLOAT);
|
||||||
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
|
|
||||||
EXPECT_EQ(1, TF_NumDims(output));
|
|
||||||
EXPECT_EQ(1, TF_Dim(output, 0));
|
|
||||||
|
|
||||||
// Set output to 3
|
// Set output to 3
|
||||||
float* data = reinterpret_cast<float*>(TF_TensorData(output));
|
float values[1] = {3.0f};
|
||||||
float value = 3.0f;
|
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
|
||||||
#if GOOGLE_CUDA
|
|
||||||
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
|
|
||||||
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, &value,
|
|
||||||
tensor_size_bytes);
|
|
||||||
#else
|
|
||||||
*data = value;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
TF_DeleteStatus(s);
|
TF_DeleteStatus(s);
|
||||||
TF_DeleteTensor(output);
|
TF_DeleteTensor(output);
|
||||||
};
|
};
|
||||||
@ -417,12 +416,8 @@ TEST_F(DeviceKernelOpTest, TestAllocateEmptyOutput) {
|
|||||||
TF_Tensor* output = TF_AllocateOutput(
|
TF_Tensor* output = TF_AllocateOutput(
|
||||||
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
|
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
|
||||||
/*num_dims=*/1, /*len=*/0, s);
|
/*num_dims=*/1, /*len=*/0, s);
|
||||||
|
|
||||||
EXPECT_EQ(TF_OK, TF_GetCode(s));
|
EXPECT_EQ(TF_OK, TF_GetCode(s));
|
||||||
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
|
validate_tensor(output, &dim, 1, TF_FLOAT);
|
||||||
EXPECT_EQ(1, TF_NumDims(output));
|
|
||||||
EXPECT_EQ(0, TF_Dim(output, 0));
|
|
||||||
|
|
||||||
TF_DeleteStatus(s);
|
TF_DeleteStatus(s);
|
||||||
TF_DeleteTensor(output);
|
TF_DeleteTensor(output);
|
||||||
};
|
};
|
||||||
@ -442,27 +437,16 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) {
|
|||||||
TF_Status* s = TF_NewStatus();
|
TF_Status* s = TF_NewStatus();
|
||||||
// Allocate 2x3 output
|
// Allocate 2x3 output
|
||||||
int64_t dim[2] = {2, 3};
|
int64_t dim[2] = {2, 3};
|
||||||
size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT);
|
size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT) * 6;
|
||||||
TF_Tensor* output = TF_AllocateOutput(
|
TF_Tensor* output = TF_AllocateOutput(
|
||||||
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/dim,
|
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/dim,
|
||||||
/*num_dims=*/2, /*len=*/tensor_size_bytes, s);
|
/*num_dims=*/2, /*len=*/tensor_size_bytes, s);
|
||||||
EXPECT_EQ(TF_OK, TF_GetCode(s));
|
EXPECT_EQ(TF_OK, TF_GetCode(s));
|
||||||
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
|
validate_tensor(output, dim, 2, TF_FLOAT);
|
||||||
EXPECT_EQ(2, TF_NumDims(output));
|
|
||||||
EXPECT_EQ(2, TF_Dim(output, 0));
|
|
||||||
EXPECT_EQ(3, TF_Dim(output, 1));
|
|
||||||
|
|
||||||
// Set output to [1 2 3 4 5 6]
|
// Set output to [1 2 3 4 5 6]
|
||||||
void* data = TF_TensorData(output);
|
float values[6] = {1, 2, 3, 4, 5, 6};
|
||||||
float value[6] = {1, 2, 3, 4, 5, 6};
|
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
|
||||||
#if GOOGLE_CUDA
|
|
||||||
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
|
|
||||||
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, value,
|
|
||||||
tensor_size_bytes);
|
|
||||||
#else
|
|
||||||
memcpy(data, value, tensor_size_bytes);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
TF_DeleteStatus(s);
|
TF_DeleteStatus(s);
|
||||||
TF_DeleteTensor(output);
|
TF_DeleteTensor(output);
|
||||||
};
|
};
|
||||||
@ -474,4 +458,132 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) {
|
|||||||
EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
|
EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
|
||||||
output->DebugString(100));
|
output->DebugString(100));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
REGISTER_OP("AllocateTempOp1").Output("output1: float");
|
||||||
|
|
||||||
|
TEST_F(DeviceKernelOpTest, TestAllocateTempSizeOne) {
|
||||||
|
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
|
||||||
|
// Allocate scalar TF_Tensor
|
||||||
|
TF_Status* s = TF_NewStatus();
|
||||||
|
int64_t dim = 1;
|
||||||
|
TF_AllocatorAttributes alloc_attrs;
|
||||||
|
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
alloc_attrs.on_host = 0;
|
||||||
|
#else
|
||||||
|
alloc_attrs.on_host = 1;
|
||||||
|
#endif
|
||||||
|
TF_Tensor* output = TF_AllocateTemp(
|
||||||
|
/*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
|
||||||
|
/*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
|
||||||
|
size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(s));
|
||||||
|
validate_tensor(output, &dim, 1, TF_FLOAT);
|
||||||
|
|
||||||
|
// Set TF_Tensor value to 3
|
||||||
|
float values[1] = {3.0f};
|
||||||
|
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
|
||||||
|
TF_SetOutput(ctx, 0, output, s);
|
||||||
|
TF_DeleteStatus(s);
|
||||||
|
TF_DeleteTensor(output);
|
||||||
|
};
|
||||||
|
|
||||||
|
SetupOp("AllocateTempOp1", "AllocateTemp1", my_compute_func);
|
||||||
|
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor* output = GetOutput(0);
|
||||||
|
EXPECT_EQ("Tensor<type: float shape: [1] values: 3>",
|
||||||
|
output->DebugString(100));
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_OP("AllocateTempOp0").Output("output1: float");
|
||||||
|
|
||||||
|
TEST_F(DeviceKernelOpTest, TestAllocateTempEmpty) {
|
||||||
|
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
|
||||||
|
TF_Status* s = TF_NewStatus();
|
||||||
|
// Allocate empty TF_Tensor
|
||||||
|
int64_t dim = 0;
|
||||||
|
TF_AllocatorAttributes alloc_attrs;
|
||||||
|
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
alloc_attrs.on_host = 0;
|
||||||
|
#else
|
||||||
|
alloc_attrs.on_host = 1;
|
||||||
|
#endif
|
||||||
|
TF_Tensor* output = TF_AllocateTemp(
|
||||||
|
/*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
|
||||||
|
/*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(s));
|
||||||
|
validate_tensor(output, &dim, 1, TF_FLOAT);
|
||||||
|
TF_SetOutput(ctx, 0, output, s);
|
||||||
|
TF_DeleteStatus(s);
|
||||||
|
TF_DeleteTensor(output);
|
||||||
|
};
|
||||||
|
|
||||||
|
SetupOp("AllocateTempOp0", "AllocateTemp0", my_compute_func);
|
||||||
|
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor* output = GetOutput(0);
|
||||||
|
EXPECT_EQ("Tensor<type: float shape: [0] values: >",
|
||||||
|
output->DebugString(100));
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_OP("AllocateTempOp2x3").Output("output1: float");
|
||||||
|
|
||||||
|
TEST_F(DeviceKernelOpTest, TestAllocateTempSize2x3) {
|
||||||
|
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
|
||||||
|
TF_Status* s = TF_NewStatus();
|
||||||
|
size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT);
|
||||||
|
// Allocate 2x3 TF_Tensor
|
||||||
|
int64_t dim[2] = {2, 3};
|
||||||
|
TF_AllocatorAttributes alloc_attrs;
|
||||||
|
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
alloc_attrs.on_host = 0;
|
||||||
|
#else
|
||||||
|
alloc_attrs.on_host = 1;
|
||||||
|
#endif
|
||||||
|
TF_Tensor* output = TF_AllocateTemp(
|
||||||
|
/*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/dim,
|
||||||
|
/*num_dims=*/2, /*allocator_attributes*/ &alloc_attrs, s);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(s));
|
||||||
|
validate_tensor(output, dim, 2, TF_FLOAT);
|
||||||
|
|
||||||
|
// Set TF_Tensor values to [1 2 3 4 5 6]
|
||||||
|
float values[6] = {1, 2, 3, 4, 5, 6};
|
||||||
|
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
|
||||||
|
TF_SetOutput(ctx, 0, output, s);
|
||||||
|
TF_DeleteStatus(s);
|
||||||
|
TF_DeleteTensor(output);
|
||||||
|
};
|
||||||
|
|
||||||
|
SetupOp("AllocateTempOp2x3", "AllocateTempOp2x3", my_compute_func);
|
||||||
|
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor* output = GetOutput(0);
|
||||||
|
EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
|
||||||
|
output->DebugString(100));
|
||||||
|
}
|
||||||
|
|
||||||
|
void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
|
||||||
|
TF_DataType dtype) {
|
||||||
|
EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));
|
||||||
|
EXPECT_EQ(num_dims, TF_NumDims(tensor));
|
||||||
|
for (int i = 0; i < num_dims; ++i) {
|
||||||
|
EXPECT_EQ(dims[i], TF_Dim(tensor, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
|
||||||
|
TF_OpKernelContext* ctx) {
|
||||||
|
T* data = reinterpret_cast<T*>(TF_TensorData(tensor));
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
|
||||||
|
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, values,
|
||||||
|
tensor_size_bytes);
|
||||||
|
#else
|
||||||
|
memcpy(data, values, tensor_size_bytes);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include <stdbool.h>
|
#include <stdbool.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api_macros.h"
|
||||||
#include "tensorflow/c/tf_datatype.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
|
||||||
@ -45,6 +46,16 @@ limitations under the License.
|
|||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Allocator Attributes used for tensor allocation.
|
||||||
|
typedef struct TF_AllocatorAttributes {
|
||||||
|
size_t struct_size;
|
||||||
|
// Set boolean to 1 for CPU allocation, else 0.
|
||||||
|
TF_Bool on_host;
|
||||||
|
} TF_AllocatorAttributes;
|
||||||
|
|
||||||
|
#define TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE \
|
||||||
|
TF_OFFSET_OF_END(TF_AllocatorAttributes, on_host)
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
// TF_Tensor holds a multi-dimensional array of elements of a single data type.
|
// TF_Tensor holds a multi-dimensional array of elements of a single data type.
|
||||||
// For all types other than TF_STRING, the data buffer stores elements
|
// For all types other than TF_STRING, the data buffer stores elements
|
||||||
|
Loading…
Reference in New Issue
Block a user