Add Tensor & TensorHandle C APIs taking a Context

The existing TF_AllocateTensor & TFE_NewTensorHandle APIs do not take a
TFE_Context, which is undesirable since the TFE_Context indicates ownership
of the tensor. We therefore add new APIs to supersede the existing ones.
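
A minimal migration sketch (not part of this change; assumes a live TFE_Context* ctx
and TF_Status* status, with error checks elided):

// Before: the tensor and its handle are created with no owning context.
TF_Tensor* t = TF_AllocateTensor(TF_INT32, /*dims=*/nullptr, /*num_dims=*/0,
                                 sizeof(int32_t));
// ... write the value through TF_TensorData(t) ...
TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);

// After: both creation calls go through the TFE_Context that owns the tensor.
TF_Tensor* t2 = TFE_AllocateHostTensor(ctx, TF_INT32, /*dims=*/nullptr,
                                       /*num_dims=*/0, status);
// ... write the value through TF_TensorData(t2) ...
TFE_TensorHandle* h2 = TFE_NewTensorHandleFromTensor(ctx, t2, status);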

PiperOrigin-RevId: 305126310
Change-Id: I9863ebc692d48875c61b79197ab418f29503a8c6
Gaurav Jain 2020-04-06 15:05:24 -07:00 committed by TensorFlower Gardener
parent 2f0ac02d72
commit 9b576164f1
16 changed files with 210 additions and 203 deletions

View File

@ -218,7 +218,7 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
TFE_OpSetAttrType(fill_op, "Tshape", TF_INT32);
float five = 5.0;
TFE_TensorHandle* scalar = TestScalarTensorHandle(five);
TFE_TensorHandle* scalar = TestScalarTensorHandle(tfe_context_, five);
TF_Tensor* scalarTensor = TFE_TensorHandleResolve(scalar, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
CheckOutputShapes(fill_op,

View File

@ -179,6 +179,8 @@ cc_library(
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:span",
],
)
@ -193,6 +195,7 @@ tf_cuda_library(
],
deps = [
":c_api",
":c_api_experimental",
"//tensorflow/c:c_test_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",

View File

@ -21,8 +21,13 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
TEST(CApiDebug, ScalarCPU) {
TFE_TensorHandle* h = TestScalarTensorHandle(1.0f);
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* h = TestScalarTensorHandle(ctx, 1.0f);
TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -30,12 +35,18 @@ TEST(CApiDebug, ScalarCPU) {
TFE_DeleteTensorDebugInfo(debug_info);
TFE_DeleteTensorHandle(h);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CApiDebug, 2DCPU) {
TFE_TensorHandle* h = TestMatrixTensorHandle3X2();
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* h = TestMatrixTensorHandle3X2(ctx);
TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -46,5 +57,6 @@ TEST(CApiDebug, 2DCPU) {
TFE_DeleteTensorDebugInfo(debug_info);
TFE_DeleteTensorHandle(h);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
@ -600,3 +602,33 @@ void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
};
status->status = tensorflow::Status::OK();
}
TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
const int64_t* dims, int num_dims,
TF_Status* status) {
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
if (ctx == nullptr || ctx->context == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid Context");
return nullptr;
}
tensorflow::AbstractTensorInterface* t = ctx->context->CreateTensor(
static_cast<tensorflow::DataType>(dtype), dimvec);
if (t == nullptr) {
status->status =
tensorflow::errors::InvalidArgument("Unsupported dtype: ", dtype);
return nullptr;
}
return new TF_Tensor{t};
}
TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
TF_Status* status) {
return new TFE_TensorHandle{ctx->context->CreateLocalHandle(t->tensor)};
}

View File

@ -524,6 +524,23 @@ TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
TF_Buffer* buf,
TF_Status* status);
// Allocate and return a new Tensor on the host.
//
// The caller must set the Tensor values by writing them to the pointer returned
// by TF_TensorData with length TF_TensorByteSize.
TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
TF_DataType dtype,
const int64_t* dims,
int num_dims,
TF_Status* status);
// Given a Tensor, wrap it with a TensorHandle
//
// Similar to TFE_NewTensorHandle, but includes a pointer to the TFE_Context.
// The context should be identical to that of the Tensor.
TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
TFE_Context* ctx, TF_Tensor* t, TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif
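
A hedged usage sketch for the two declarations above, mirroring the test utilities
updated later in this change (error checks and cleanup of ctx/status elided):

int64_t dims[] = {2, 2};
float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
// Allocate a host tensor owned by ctx, then write values through TF_TensorData.
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, dims, 2, status);
memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
// Wrap the filled tensor in a handle created from the same context.
TFE_TensorHandle* h = TFE_NewTensorHandleFromTensor(ctx, t, status);
// Deleting the TF_Tensor after handle creation is safe, as done in the test utilities.
TF_DeleteTensor(t);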

View File

@ -378,7 +378,7 @@ void Executor_MatMul_CPU(bool async) {
TFE_Executor* executor = TFE_NewExecutor(async);
TFE_ContextSetExecutorForThread(ctx, executor);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
int num_retvals = 2;
@ -423,7 +423,7 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TF_Tensor* m_data = TFE_TensorHandleResolve(m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float* m_float = static_cast<float*>(TF_TensorData(m_data));

View File

@ -75,8 +75,8 @@ void TestRemoteExecute(bool async) {
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
@ -160,8 +160,8 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
@ -267,8 +267,8 @@ void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
// Use large matrices so that RPCs don't return before we get a chance
// to call TFE_DeleteContext.
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
@ -331,7 +331,7 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
const char* remote_device_name,
const char* local_device_name) {
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
TFE_OpSetDevice(matmul, remote_device_name, status);
@ -414,7 +414,7 @@ void TestRemoteExecuteChangeServerDef(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Create a new tensor_handle.
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle();
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);
// Check that copying it to the old remote device (named localhost) fails.
TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);

View File

@ -49,7 +49,7 @@ void BM_InitOp(int iters) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TFE_Op* matmul = MatMulOp(ctx, m, m);
@ -73,7 +73,7 @@ void BM_Execute(int iters, int async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
@ -108,7 +108,7 @@ void BM_Execute_Identity(int iters, int async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* identity = IdentityOp(ctx, m);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
@ -155,11 +155,16 @@ TEST(CAPI, Context) {
}
TEST(CAPI, TensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* h = TestMatrixTensorHandle(ctx);
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
ASSERT_EQ(16, TF_TensorByteSize(t));
float data[4] = {0};
@ -170,6 +175,7 @@ TEST(CAPI, TensorHandle) {
EXPECT_EQ(4.0, data[3]);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(h);
TFE_DeleteContext(ctx);
}
void TensorHandleCopyBetweenDevices(bool async) {
@ -181,7 +187,7 @@ void TensorHandleCopyBetweenDevices(bool async) {
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@ -257,7 +263,7 @@ void TensorHandleCopyBetweenDevicesError(bool async) {
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
const char* kErrorDevice = "NoSuchDevice:0";
TFE_TensorHandle* hdevice =
TFE_TensorHandleCopyToDevice(hcpu, ctx, kErrorDevice, status.get());
@ -298,7 +304,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@ -384,7 +390,7 @@ void TensorHandleSilentCopy(bool async,
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
@ -457,7 +463,7 @@ void SetAndGetOpDevices(bool async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
// Disable the test if no GPU is present.
@ -528,7 +534,7 @@ TEST(CAPI, TensorHandleDevices) {
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
const char* device_name = TFE_TensorHandleDeviceName(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_name, "CPU:0")) << device_name;
@ -586,7 +592,7 @@ void ExecuteAdd(bool async, bool forward_input) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* n = TestMatrixTensorHandle100x100();
TFE_TensorHandle* n = TestMatrixTensorHandle100x100(ctx);
// If a GPU exists, copy the handle to GPU so that we can exercise
// unprotecting a mirror.
std::string gpu_device_name;
@ -598,7 +604,7 @@ void ExecuteAdd(bool async, bool forward_input) {
n = n_gpu;
}
TFE_TensorHandle* m = TestMatrixTensorHandle100x100();
TFE_TensorHandle* m = TestMatrixTensorHandle100x100(ctx);
// Store pointer to raw buffer for validation of forwarding behaviour.
TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
@ -670,7 +676,7 @@ void Execute_MatMul_CPU(bool async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
int num_retvals = 2;
@ -706,8 +712,8 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m1 = TestMatrixTensorHandle();
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2();
TFE_TensorHandle* m1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2(ctx);
TFE_Op* matmul = MatMulOp(ctx, m1, m2);
TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0",
status);
@ -778,8 +784,8 @@ void Execute_MatMul_CPU_Type_Error(bool async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m1 = TestMatrixTensorHandle();
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle();
TFE_TensorHandle* m1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m1, m2);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
@ -808,8 +814,8 @@ TEST(CAPI, Execute_Min_CPU) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
TFE_Op* minOp = MinOp(ctx, input, axis);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
@ -843,7 +849,7 @@ void Execute_MatMul_XLA_CPU(bool async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_OpSetXLACompilation(matmul, true);
@ -885,8 +891,8 @@ void Execute_Min_XLA_CPU(bool async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
TFE_Op* minOp = MinOp(ctx, input, axis);
TFE_OpSetXLACompilation(minOp, true);
@ -926,7 +932,7 @@ void ExecuteWithTracing(bool async) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
@ -1012,7 +1018,7 @@ void FunctionDefAndExecute(bool async) {
if (clear_cache) {
TFE_ContextClearCaches(ctx);
}
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* retval[1] = {nullptr};
int num_retvals = 1;
TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status);
@ -1061,7 +1067,7 @@ void BM_ExecuteFunction(int iters, int async) {
status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(matmul, m, status);
@ -1276,11 +1282,15 @@ TEST(CAPI, StringAttributes) {
}
TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* h = TestMatrixTensorHandle(ctx);
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
TFE_TensorHandle* h_shares_tensor =
TFE_TensorHandleCopySharingTensor(h, status.get());
@ -1298,6 +1308,7 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
TFE_DeleteTensorHandle(h);
TFE_DeleteTensorHandle(h_shares_tensor);
TFE_DeleteContext(ctx);
}
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
@ -1315,8 +1326,8 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
TFE_Op* minOp = TFE_NewOp(ctx, "Min", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(minOp, input, status);
@ -1352,9 +1363,9 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
TFE_TensorHandle* dim = TestScalarTensorHandle(0);
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* dim = TestScalarTensorHandle(ctx, 0);
TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* inputs[] = {input1, input2};
@ -1392,9 +1403,9 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* condition = TestScalarTensorHandle(true);
TFE_TensorHandle* t1 = TestMatrixTensorHandle();
TFE_TensorHandle* t2 = TestAxisTensorHandle();
TFE_TensorHandle* condition = TestScalarTensorHandle(ctx, true);
TFE_TensorHandle* t1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* t2 = TestAxisTensorHandle(ctx);
TFE_Op* assertOp = TFE_NewOp(ctx, "Assert", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(assertOp, condition, status);
@ -1431,9 +1442,9 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
TFE_TensorHandle* dim = TestScalarTensorHandle(0);
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* dim = TestScalarTensorHandle(ctx, 0);
TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* inputs[] = {input1, input2};
@ -1466,8 +1477,8 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengths) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -1514,8 +1525,8 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* inputs[] = {input1, input2};

View File

@ -16,115 +16,117 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
using tensorflow::string;
TFE_TensorHandle* TestScalarTensorHandle(float value) {
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
float data[] = {value};
TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestScalarTensorHandle(int value) {
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value) {
int data[] = {value};
TF_Tensor* t = TF_AllocateTensor(TF_INT32, nullptr, 0, sizeof(int));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, nullptr, 0, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestScalarTensorHandle(bool value) {
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value) {
bool data[] = {value};
TF_Tensor* t = TF_AllocateTensor(TF_BOOL, nullptr, 0, sizeof(bool));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_BOOL, nullptr, 0, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* DoubleTestMatrixTensorHandle() {
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx) {
int64_t dims[] = {2, 2};
double data[] = {1.0, 2.0, 3.0, 4.0};
TF_Tensor* t = TF_AllocateTensor(
TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_DOUBLE, &dims[0],
sizeof(dims) / sizeof(int64_t), status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestMatrixTensorHandle() {
TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx) {
int64_t dims[] = {2, 2};
float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
TF_Tensor* t = TF_AllocateTensor(
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
sizeof(dims) / sizeof(int64_t), status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestMatrixTensorHandle100x100() {
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) {
constexpr int64_t dims[] = {100, 100};
constexpr int num_elements = dims[0] * dims[1];
float data[num_elements];
for (int i = 0; i < num_elements; ++i) {
data[i] = 1.0f;
}
TF_Tensor* t = TF_AllocateTensor(
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
sizeof(dims) / sizeof(int64_t), status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() {
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx) {
int64_t dims[] = {3, 2};
double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
TF_Tensor* t = TF_AllocateTensor(
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
sizeof(dims) / sizeof(int64_t), status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestMatrixTensorHandle3X2() {
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) {
int64_t dims[] = {3, 2};
float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
TF_Tensor* t = TF_AllocateTensor(
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
sizeof(dims) / sizeof(int64_t), status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
@ -187,14 +189,14 @@ TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
return op;
}
TFE_TensorHandle* TestAxisTensorHandle() {
TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx) {
int64_t dims[] = {1};
int data[] = {1};
TF_Tensor* t = TF_AllocateTensor(
TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0],
sizeof(dims) / sizeof(int64_t), status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);

View File

@ -19,28 +19,28 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
// Return a tensor handle containing a float scalar
TFE_TensorHandle* TestScalarTensorHandle(float value);
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value);
// Return a tensor handle containing an int scalar
TFE_TensorHandle* TestScalarTensorHandle(int value);
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value);
// Return a tensor handle containing a bool scalar
TFE_TensorHandle* TestScalarTensorHandle(bool value);
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value);
// Return a tensor handle containing a 2x2 matrix of doubles
TFE_TensorHandle* DoubleTestMatrixTensorHandle();
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx);
// Return a tensor handle containing a 2x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle();
TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx);
// Return a tensor handle containing a 100x100 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle100x100();
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);
// Return a tensor handle containing a 3x2 matrix of doubles
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx);
// Return a tensor handle containing a 3x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle3X2();
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx);
// Return an add op adding `a` to `b`.
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
@ -55,7 +55,7 @@ TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a);
TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a);
// Return a 1-D INT32 tensor containing a single value 1.
TFE_TensorHandle* TestAxisTensorHandle();
TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx);
// Return an op taking the minimum of `input` along the `axis` dimension.
TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,

View File

@ -46,7 +46,7 @@ TEST(UnifedCAPI, TestBasicEager) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_TensorHandle* t = TestScalarTensorHandle(2.0f);
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* at = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(at, t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@ -162,7 +162,7 @@ TEST(UnifedCAPI, TestBasicGraph) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_TensorHandle* input_eager = TestScalarTensorHandle(2.0f);
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* input_t = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

View File

@ -17,10 +17,12 @@ limitations under the License.
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
@ -40,7 +42,7 @@ class AbstractContextInterface {
// destroy an instance of this class.
virtual void Release() = 0;
// Scalar creation functions
// Optimized scalar creation functions
virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0;
virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
virtual AbstractTensorInterface* CreateInt32Scalar(int32 value) = 0;
@ -52,24 +54,8 @@ class AbstractContextInterface {
virtual AbstractTensorInterface* CreateBoolScalar(bool value) = 0;
// Tensor creation functions
virtual AbstractTensorInterface* CreateInt64Tensor(
absl::Span<const int64> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateUint64Tensor(
absl::Span<const int64> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateInt32Tensor(
absl::Span<const int64> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateFloatTensor(
absl::Span<const int64> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateDoubleTensor(
absl::Span<const int64> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateHalfTensor(
absl::Span<const int64> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateStringTensor(
absl::Span<const int64> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateComplex128Tensor(
absl::Span<const int64> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateBoolTensor(
absl::Span<const int64> dim_sizes) = 0;
virtual AbstractTensorInterface* CreateTensor(
DataType dtype, absl::Span<const int64> dim_sizes) = 0;
// Create a handle to wrap and manage a Tensor
virtual AbstractTensorHandleInterface* CreateLocalHandle(

View File

@ -156,7 +156,7 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context);
ASSERT_FALSE(arrived);
TFE_TensorHandle* hdevice =
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
@ -245,7 +245,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
// Assign to the variable, copying to the custom device.
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
TestScalarTensorHandle(context.get(), 111.f), TFE_DeleteTensorHandle);
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpAddInput(op.get(), var_handle, status.get());
@ -331,7 +331,7 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
// Assign to the variable, copying to the custom device.
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
TestScalarTensorHandle(context.get(), 111.f), TFE_DeleteTensorHandle);
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpAddInput(op.get(), var_handle, status.get());

View File

@ -161,49 +161,9 @@ AbstractTensorInterface* EagerContext::CreateBoolScalar(bool value) {
return new TensorInterface(Tensor(value));
}
AbstractTensorInterface* EagerContext::CreateInt64Tensor(
absl::Span<const int64> dim_sizes) {
return new TensorInterface(Tensor(DT_INT64, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateUint64Tensor(
absl::Span<const int64> dim_sizes) {
return new TensorInterface(Tensor(DT_UINT64, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateInt32Tensor(
absl::Span<const int64> dim_sizes) {
return new TensorInterface(Tensor(DT_INT32, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateFloatTensor(
absl::Span<const int64> dim_sizes) {
return new TensorInterface(Tensor(DT_FLOAT, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateDoubleTensor(
absl::Span<const int64> dim_sizes) {
return new TensorInterface(Tensor(DT_DOUBLE, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateHalfTensor(
absl::Span<const int64> dim_sizes) {
return new TensorInterface(Tensor(DT_HALF, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateStringTensor(
absl::Span<const int64> dim_sizes) {
return new TensorInterface(Tensor(DT_STRING, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateComplex128Tensor(
absl::Span<const int64> dim_sizes) {
return new TensorInterface(Tensor(DT_COMPLEX128, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateBoolTensor(
absl::Span<const int64> dim_sizes) {
return new TensorInterface(Tensor(DT_BOOL, TensorShape(dim_sizes)));
AbstractTensorInterface* EagerContext::CreateTensor(
DataType dtype, absl::Span<const int64> dim_sizes) {
return new TensorInterface(Tensor(dtype, TensorShape(dim_sizes)));
}
void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,

View File

@ -160,24 +160,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
tensorflow::complex128 value) override;
AbstractTensorInterface* CreateBoolScalar(bool value) override;
AbstractTensorInterface* CreateInt64Tensor(
absl::Span<const int64> dim_sizes) override;
AbstractTensorInterface* CreateUint64Tensor(
absl::Span<const int64> dim_sizes) override;
AbstractTensorInterface* CreateInt32Tensor(
absl::Span<const int64> dim_sizes) override;
AbstractTensorInterface* CreateFloatTensor(
absl::Span<const int64> dim_sizes) override;
AbstractTensorInterface* CreateDoubleTensor(
absl::Span<const int64> dim_sizes) override;
AbstractTensorInterface* CreateHalfTensor(
absl::Span<const int64> dim_sizes) override;
AbstractTensorInterface* CreateStringTensor(
absl::Span<const int64> dim_sizes) override;
AbstractTensorInterface* CreateComplex128Tensor(
absl::Span<const int64> dim_sizes) override;
AbstractTensorInterface* CreateBoolTensor(
absl::Span<const int64> dim_sizes) override;
AbstractTensorInterface* CreateTensor(
DataType dtype, absl::Span<const int64> dim_sizes) override;
AbstractTensorHandleInterface* CreateLocalHandle(
AbstractTensorInterface* t) override;

View File

@ -321,7 +321,7 @@ struct ConverterTraits<int64> {
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateInt64Tensor(dim_sizes);
return ctx->context->CreateTensor(DT_INT64, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, int64* out) {
@ -361,7 +361,7 @@ struct ConverterTraits<uint64> {
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateUint64Tensor(dim_sizes);
return ctx->context->CreateTensor(DT_UINT64, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, uint64* out) {
@ -398,7 +398,7 @@ struct ConverterTraits<int32> {
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateInt32Tensor(dim_sizes);
return ctx->context->CreateTensor(DT_INT32, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, int32* out) {
@ -505,7 +505,7 @@ struct ConverterTraits<float> {
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateFloatTensor(dim_sizes);
return ctx->context->CreateTensor(DT_FLOAT, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, float* out) {
@ -521,7 +521,7 @@ struct ConverterTraits<double> {
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateDoubleTensor(dim_sizes);
return ctx->context->CreateTensor(DT_DOUBLE, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, double* out) {
@ -541,7 +541,7 @@ struct ConverterTraits<Eigen::half> {
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateHalfTensor(dim_sizes);
return ctx->context->CreateTensor(DT_HALF, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, Eigen::half* out) {
@ -562,7 +562,7 @@ struct ConverterTraits<tstring> {
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateStringTensor(dim_sizes);
return ctx->context->CreateTensor(DT_STRING, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, tstring* out) {
@ -629,7 +629,7 @@ struct ConverterTraits<complex128> {
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateComplex128Tensor(dim_sizes);
return ctx->context->CreateTensor(DT_COMPLEX128, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, complex128* out) {
@ -657,7 +657,7 @@ struct ConverterTraits<bool> {
static AbstractTensorInterface* CreateTensor(
TFE_Context* ctx, absl::Span<const int64> dim_sizes) {
return ctx->context->CreateBoolTensor(dim_sizes);
return ctx->context->CreateTensor(DT_BOOL, dim_sizes);
}
static const char* ConvertScalar(PyObject* v, bool* out) {