Add Tensor & TensorHandle C APIs taking a Context
The existing TF_AllocateTensor & TFE_NewTensorHandle APIs do not take a TFE_Context which is undesirable as the TFE_Context indicates ownership of the tensor. Thus we add new APIs to super-seed the existing ones. PiperOrigin-RevId: 305126310 Change-Id: I9863ebc692d48875c61b79197ab418f29503a8c6
This commit is contained in:
parent
2f0ac02d72
commit
9b576164f1
@ -218,7 +218,7 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
|
||||
TFE_OpSetAttrType(fill_op, "Tshape", TF_INT32);
|
||||
|
||||
float five = 5.0;
|
||||
TFE_TensorHandle* scalar = TestScalarTensorHandle(five);
|
||||
TFE_TensorHandle* scalar = TestScalarTensorHandle(tfe_context_, five);
|
||||
TF_Tensor* scalarTensor = TFE_TensorHandleResolve(scalar, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
CheckOutputShapes(fill_op,
|
||||
|
@ -179,6 +179,8 @@ cc_library(
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
@ -193,6 +195,7 @@ tf_cuda_library(
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -21,8 +21,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
TEST(CApiDebug, ScalarCPU) {
|
||||
TFE_TensorHandle* h = TestScalarTensorHandle(1.0f);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* h = TestScalarTensorHandle(ctx, 1.0f);
|
||||
TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
@ -30,12 +35,18 @@ TEST(CApiDebug, ScalarCPU) {
|
||||
|
||||
TFE_DeleteTensorDebugInfo(debug_info);
|
||||
TFE_DeleteTensorHandle(h);
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(CApiDebug, 2DCPU) {
|
||||
TFE_TensorHandle* h = TestMatrixTensorHandle3X2();
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* h = TestMatrixTensorHandle3X2(ctx);
|
||||
TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
@ -46,5 +57,6 @@ TEST(CApiDebug, 2DCPU) {
|
||||
|
||||
TFE_DeleteTensorDebugInfo(debug_info);
|
||||
TFE_DeleteTensorHandle(h);
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
@ -600,3 +602,33 @@ void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||
};
|
||||
status->status = tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
|
||||
const int64_t* dims, int num_dims,
|
||||
TF_Status* status) {
|
||||
std::vector<tensorflow::int64> dimvec(num_dims);
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||
}
|
||||
|
||||
if (ctx == nullptr || ctx->context == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid Context");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::AbstractTensorInterface* t = ctx->context->CreateTensor(
|
||||
static_cast<tensorflow::DataType>(dtype), dimvec);
|
||||
|
||||
if (t == nullptr) {
|
||||
status->status =
|
||||
tensorflow::errors::InvalidArgument("Unsupported dtype: ", dtype);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return new TF_Tensor{t};
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
|
||||
TF_Status* status) {
|
||||
return new TFE_TensorHandle{ctx->context->CreateLocalHandle(t->tensor)};
|
||||
}
|
||||
|
@ -524,6 +524,23 @@ TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
|
||||
TF_Buffer* buf,
|
||||
TF_Status* status);
|
||||
|
||||
// Allocate and return a new Tensor on the host.
|
||||
//
|
||||
// The caller must set the Tensor values by writing them to the pointer returned
|
||||
// by TF_TensorData with length TF_TensorByteSize.
|
||||
TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
|
||||
TF_DataType dtype,
|
||||
const int64_t* dims,
|
||||
int num_dims,
|
||||
TF_Status* status);
|
||||
|
||||
// Given a Tensor, wrap it with a TensorHandle
|
||||
//
|
||||
// Similar to TFE_NewTensorHandle, but includes a pointer to the TFE_Context.
|
||||
// The context should be identical to that of the Tensor.
|
||||
TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
|
||||
TFE_Context* ctx, TF_Tensor* t, TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -378,7 +378,7 @@ void Executor_MatMul_CPU(bool async) {
|
||||
TFE_Executor* executor = TFE_NewExecutor(async);
|
||||
TFE_ContextSetExecutorForThread(ctx, executor);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
|
||||
int num_retvals = 2;
|
||||
@ -423,7 +423,7 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TF_Tensor* m_data = TFE_TensorHandleResolve(m, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
float* m_float = static_cast<float*>(TF_TensorData(m_data));
|
||||
|
@ -75,8 +75,8 @@ void TestRemoteExecute(bool async) {
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
auto* h0_task1 =
|
||||
@ -160,8 +160,8 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
|
||||
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
@ -267,8 +267,8 @@ void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
|
||||
// Use large matrices so that RPCs don't return before we get a chance
|
||||
// to call TFE_DeleteContext.
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100();
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100();
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
auto* h0_task1 =
|
||||
@ -331,7 +331,7 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
|
||||
const char* remote_device_name,
|
||||
const char* local_device_name) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
|
||||
TFE_OpSetDevice(matmul, remote_device_name, status);
|
||||
@ -414,7 +414,7 @@ void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Create a new tensor_handle.
|
||||
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);
|
||||
|
||||
// Check that copying it to the old remote device (named localhost) fails.
|
||||
TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
|
||||
|
@ -49,7 +49,7 @@ void BM_InitOp(int iters) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
@ -73,7 +73,7 @@ void BM_Execute(int iters, int async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
@ -108,7 +108,7 @@ void BM_Execute_Identity(int iters, int async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* identity = IdentityOp(ctx, m);
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
@ -155,11 +155,16 @@ TEST(CAPI, Context) {
|
||||
}
|
||||
|
||||
TEST(CAPI, TensorHandle) {
|
||||
TFE_TensorHandle* h = TestMatrixTensorHandle();
|
||||
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
|
||||
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* h = TestMatrixTensorHandle(ctx);
|
||||
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
|
||||
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
|
||||
ASSERT_EQ(16, TF_TensorByteSize(t));
|
||||
float data[4] = {0};
|
||||
@ -170,6 +175,7 @@ TEST(CAPI, TensorHandle) {
|
||||
EXPECT_EQ(4.0, data[3]);
|
||||
TF_DeleteTensor(t);
|
||||
TFE_DeleteTensorHandle(h);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
void TensorHandleCopyBetweenDevices(bool async) {
|
||||
@ -181,7 +187,7 @@ void TensorHandleCopyBetweenDevices(bool async) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
@ -257,7 +263,7 @@ void TensorHandleCopyBetweenDevicesError(bool async) {
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
|
||||
const char* kErrorDevice = "NoSuchDevice:0";
|
||||
TFE_TensorHandle* hdevice =
|
||||
TFE_TensorHandleCopyToDevice(hcpu, ctx, kErrorDevice, status.get());
|
||||
@ -298,7 +304,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
@ -384,7 +390,7 @@ void TensorHandleSilentCopy(bool async,
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
|
||||
@ -457,7 +463,7 @@ void SetAndGetOpDevices(bool async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
|
||||
// Disable the test if no GPU is present.
|
||||
@ -528,7 +534,7 @@ TEST(CAPI, TensorHandleDevices) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
|
||||
const char* device_name = TFE_TensorHandleDeviceName(hcpu, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_TRUE(absl::StrContains(device_name, "CPU:0")) << device_name;
|
||||
@ -586,7 +592,7 @@ void ExecuteAdd(bool async, bool forward_input) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* n = TestMatrixTensorHandle100x100();
|
||||
TFE_TensorHandle* n = TestMatrixTensorHandle100x100(ctx);
|
||||
// If a GPU exists, copy the handle to GPU so that we can exercise
|
||||
// unprotecting a mirror.
|
||||
std::string gpu_device_name;
|
||||
@ -598,7 +604,7 @@ void ExecuteAdd(bool async, bool forward_input) {
|
||||
n = n_gpu;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle100x100();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle100x100(ctx);
|
||||
|
||||
// Store pointer to raw buffer for validation of forwarding behaviour.
|
||||
TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
|
||||
@ -670,7 +676,7 @@ void Execute_MatMul_CPU(bool async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
|
||||
int num_retvals = 2;
|
||||
@ -706,8 +712,8 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2();
|
||||
TFE_TensorHandle* m1 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m1, m2);
|
||||
TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
status);
|
||||
@ -778,8 +784,8 @@ void Execute_MatMul_CPU_Type_Error(bool async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m1 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m1, m2);
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
@ -808,8 +814,8 @@ TEST(CAPI, Execute_Min_CPU) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
|
||||
TFE_Op* minOp = MinOp(ctx, input, axis);
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
@ -843,7 +849,7 @@ void Execute_MatMul_XLA_CPU(bool async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
|
||||
TFE_OpSetXLACompilation(matmul, true);
|
||||
@ -885,8 +891,8 @@ void Execute_Min_XLA_CPU(bool async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
|
||||
TFE_Op* minOp = MinOp(ctx, input, axis);
|
||||
|
||||
TFE_OpSetXLACompilation(minOp, true);
|
||||
@ -926,7 +932,7 @@ void ExecuteWithTracing(bool async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
@ -1012,7 +1018,7 @@ void FunctionDefAndExecute(bool async) {
|
||||
if (clear_cache) {
|
||||
TFE_ContextClearCaches(ctx);
|
||||
}
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* retval[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status);
|
||||
@ -1061,7 +1067,7 @@ void BM_ExecuteFunction(int iters, int async) {
|
||||
status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, m, status);
|
||||
@ -1276,11 +1282,15 @@ TEST(CAPI, StringAttributes) {
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
|
||||
TFE_TensorHandle* h = TestMatrixTensorHandle();
|
||||
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
|
||||
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* h = TestMatrixTensorHandle(ctx);
|
||||
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
|
||||
|
||||
TFE_TensorHandle* h_shares_tensor =
|
||||
TFE_TensorHandleCopySharingTensor(h, status.get());
|
||||
@ -1298,6 +1308,7 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
|
||||
|
||||
TFE_DeleteTensorHandle(h);
|
||||
TFE_DeleteTensorHandle(h_shares_tensor);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
|
||||
@ -1315,8 +1326,8 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
|
||||
TFE_Op* minOp = TFE_NewOp(ctx, "Min", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(minOp, input, status);
|
||||
@ -1352,9 +1363,9 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* dim = TestScalarTensorHandle(0);
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* dim = TestScalarTensorHandle(ctx, 0);
|
||||
TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||
@ -1392,9 +1403,9 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* condition = TestScalarTensorHandle(true);
|
||||
TFE_TensorHandle* t1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* t2 = TestAxisTensorHandle();
|
||||
TFE_TensorHandle* condition = TestScalarTensorHandle(ctx, true);
|
||||
TFE_TensorHandle* t1 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* t2 = TestAxisTensorHandle(ctx);
|
||||
TFE_Op* assertOp = TFE_NewOp(ctx, "Assert", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(assertOp, condition, status);
|
||||
@ -1431,9 +1442,9 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* dim = TestScalarTensorHandle(0);
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* dim = TestScalarTensorHandle(ctx, 0);
|
||||
TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||
@ -1466,8 +1477,8 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengths) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
@ -1514,8 +1525,8 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||
|
@ -16,115 +16,117 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
TFE_TensorHandle* TestScalarTensorHandle(float value) {
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
|
||||
float data[] = {value};
|
||||
TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestScalarTensorHandle(int value) {
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value) {
|
||||
int data[] = {value};
|
||||
TF_Tensor* t = TF_AllocateTensor(TF_INT32, nullptr, 0, sizeof(int));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, nullptr, 0, status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestScalarTensorHandle(bool value) {
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value) {
|
||||
bool data[] = {value};
|
||||
TF_Tensor* t = TF_AllocateTensor(TF_BOOL, nullptr, 0, sizeof(bool));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_BOOL, nullptr, 0, status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle() {
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx) {
|
||||
int64_t dims[] = {2, 2};
|
||||
double data[] = {1.0, 2.0, 3.0, 4.0};
|
||||
TF_Tensor* t = TF_AllocateTensor(
|
||||
TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_DOUBLE, &dims[0],
|
||||
sizeof(dims) / sizeof(int64_t), status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestMatrixTensorHandle() {
|
||||
TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx) {
|
||||
int64_t dims[] = {2, 2};
|
||||
float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
|
||||
TF_Tensor* t = TF_AllocateTensor(
|
||||
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
|
||||
sizeof(dims) / sizeof(int64_t), status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestMatrixTensorHandle100x100() {
|
||||
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) {
|
||||
constexpr int64_t dims[] = {100, 100};
|
||||
constexpr int num_elements = dims[0] * dims[1];
|
||||
float data[num_elements];
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
data[i] = 1.0f;
|
||||
}
|
||||
TF_Tensor* t = TF_AllocateTensor(
|
||||
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
|
||||
sizeof(dims) / sizeof(int64_t), status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() {
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx) {
|
||||
int64_t dims[] = {3, 2};
|
||||
double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
|
||||
TF_Tensor* t = TF_AllocateTensor(
|
||||
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
|
||||
sizeof(dims) / sizeof(int64_t), status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestMatrixTensorHandle3X2() {
|
||||
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) {
|
||||
int64_t dims[] = {3, 2};
|
||||
float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
TF_Tensor* t = TF_AllocateTensor(
|
||||
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
|
||||
sizeof(dims) / sizeof(int64_t), status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
@ -187,14 +189,14 @@ TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
|
||||
return op;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestAxisTensorHandle() {
|
||||
TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx) {
|
||||
int64_t dims[] = {1};
|
||||
int data[] = {1};
|
||||
TF_Tensor* t = TF_AllocateTensor(
|
||||
TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0],
|
||||
sizeof(dims) / sizeof(int64_t), status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
|
@ -19,28 +19,28 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
// Return a tensor handle containing a float scalar
|
||||
TFE_TensorHandle* TestScalarTensorHandle(float value);
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value);
|
||||
|
||||
// Return a tensor handle containing a int scalar
|
||||
TFE_TensorHandle* TestScalarTensorHandle(int value);
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value);
|
||||
|
||||
// Return a tensor handle containing a bool scalar
|
||||
TFE_TensorHandle* TestScalarTensorHandle(bool value);
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value);
|
||||
|
||||
// Return a tensor handle containing a 2x2 matrix of doubles
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle();
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx);
|
||||
|
||||
// Return a tensor handle containing a 2x2 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx);
|
||||
|
||||
// Return a tensor handle containing a 100x100 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle100x100();
|
||||
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);
|
||||
|
||||
// Return a tensor handle containing a 3x2 matrix of doubles
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx);
|
||||
|
||||
// Return a tensor handle containing a 3x2 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle3X2();
|
||||
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx);
|
||||
|
||||
// Return an add op multiplying `a` by `b`.
|
||||
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
||||
@ -55,7 +55,7 @@ TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a);
|
||||
TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a);
|
||||
|
||||
// Return an 1-D INT32 tensor containing a single value 1.
|
||||
TFE_TensorHandle* TestAxisTensorHandle();
|
||||
TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx);
|
||||
|
||||
// Return an op taking minimum of `input` long `axis` dimension.
|
||||
TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
|
||||
|
@ -46,7 +46,7 @@ TEST(UnifedCAPI, TestBasicEager) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract input tensor.
|
||||
TFE_TensorHandle* t = TestScalarTensorHandle(2.0f);
|
||||
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
|
||||
TF_AbstractTensor* at = TF_NewAbstractTensor();
|
||||
TF_AbstractTensorSetEagerTensor(at, t, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
@ -162,7 +162,7 @@ TEST(UnifedCAPI, TestBasicGraph) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract input tensor.
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(2.0f);
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
|
||||
TF_AbstractTensor* input_t = TF_NewAbstractTensor();
|
||||
TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
@ -17,10 +17,12 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
|
||||
@ -40,7 +42,7 @@ class AbstractContextInterface {
|
||||
// destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
// Scalar creation functions
|
||||
// Optimized scalar creation functions
|
||||
virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0;
|
||||
virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
|
||||
virtual AbstractTensorInterface* CreateInt32Scalar(int32 value) = 0;
|
||||
@ -52,24 +54,8 @@ class AbstractContextInterface {
|
||||
virtual AbstractTensorInterface* CreateBoolScalar(bool value) = 0;
|
||||
|
||||
// Tensor creation functions
|
||||
virtual AbstractTensorInterface* CreateInt64Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateUint64Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateInt32Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateFloatTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateDoubleTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateHalfTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateStringTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateComplex128Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateBoolTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateTensor(
|
||||
DataType dtype, absl::Span<const int64> dim_sizes) = 0;
|
||||
|
||||
// Create a handle to wrap and manage a Tensor
|
||||
virtual AbstractTensorHandleInterface* CreateLocalHandle(
|
||||
|
@ -156,7 +156,7 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context);
|
||||
ASSERT_FALSE(arrived);
|
||||
TFE_TensorHandle* hdevice =
|
||||
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
|
||||
@ -245,7 +245,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
|
||||
|
||||
// Assign to the variable, copying to the custom device.
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
|
||||
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
|
||||
TestScalarTensorHandle(context.get(), 111.f), TFE_DeleteTensorHandle);
|
||||
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
@ -331,7 +331,7 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
|
||||
|
||||
// Assign to the variable, copying to the custom device.
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
|
||||
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
|
||||
TestScalarTensorHandle(context.get(), 111.f), TFE_DeleteTensorHandle);
|
||||
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
|
@ -161,49 +161,9 @@ AbstractTensorInterface* EagerContext::CreateBoolScalar(bool value) {
|
||||
return new TensorInterface(Tensor(value));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateInt64Tensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_INT64, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateUint64Tensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_UINT64, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateInt32Tensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_INT32, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateFloatTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_FLOAT, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateDoubleTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_DOUBLE, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateHalfTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_HALF, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateStringTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_STRING, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateComplex128Tensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_COMPLEX128, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateBoolTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return new TensorInterface(Tensor(DT_BOOL, TensorShape(dim_sizes)));
|
||||
AbstractTensorInterface* EagerContext::CreateTensor(
|
||||
DataType dtype, absl::Span<const int64> dim_sizes) {
|
||||
return new TensorInterface(Tensor(dtype, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
|
||||
|
@ -160,24 +160,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
||||
tensorflow::complex128 value) override;
|
||||
AbstractTensorInterface* CreateBoolScalar(bool value) override;
|
||||
|
||||
AbstractTensorInterface* CreateInt64Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateUint64Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateInt32Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateFloatTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateDoubleTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateHalfTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateStringTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateComplex128Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateBoolTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateTensor(
|
||||
DataType dtype, absl::Span<const int64> dim_sizes) override;
|
||||
|
||||
AbstractTensorHandleInterface* CreateLocalHandle(
|
||||
AbstractTensorInterface* t) override;
|
||||
|
@ -321,7 +321,7 @@ struct ConverterTraits<int64> {
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateInt64Tensor(dim_sizes);
|
||||
return ctx->context->CreateTensor(DT_INT64, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, int64* out) {
|
||||
@ -361,7 +361,7 @@ struct ConverterTraits<uint64> {
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateUint64Tensor(dim_sizes);
|
||||
return ctx->context->CreateTensor(DT_UINT64, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, uint64* out) {
|
||||
@ -398,7 +398,7 @@ struct ConverterTraits<int32> {
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateInt32Tensor(dim_sizes);
|
||||
return ctx->context->CreateTensor(DT_INT32, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, int32* out) {
|
||||
@ -505,7 +505,7 @@ struct ConverterTraits<float> {
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateFloatTensor(dim_sizes);
|
||||
return ctx->context->CreateTensor(DT_FLOAT, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, float* out) {
|
||||
@ -521,7 +521,7 @@ struct ConverterTraits<double> {
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateDoubleTensor(dim_sizes);
|
||||
return ctx->context->CreateTensor(DT_DOUBLE, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, double* out) {
|
||||
@ -541,7 +541,7 @@ struct ConverterTraits<Eigen::half> {
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateHalfTensor(dim_sizes);
|
||||
return ctx->context->CreateTensor(DT_HALF, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, Eigen::half* out) {
|
||||
@ -562,7 +562,7 @@ struct ConverterTraits<tstring> {
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateStringTensor(dim_sizes);
|
||||
return ctx->context->CreateTensor(DT_STRING, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, tstring* out) {
|
||||
@ -629,7 +629,7 @@ struct ConverterTraits<complex128> {
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateComplex128Tensor(dim_sizes);
|
||||
return ctx->context->CreateTensor(DT_COMPLEX128, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, complex128* out) {
|
||||
@ -657,7 +657,7 @@ struct ConverterTraits<bool> {
|
||||
|
||||
static AbstractTensorInterface* CreateTensor(
|
||||
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
|
||||
return ctx->context->CreateBoolTensor(dim_sizes);
|
||||
return ctx->context->CreateTensor(DT_BOOL, dim_sizes);
|
||||
}
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, bool* out) {
|
||||
|
Loading…
Reference in New Issue
Block a user