From 9b576164f19237a0584ca4757219e0837757bf98 Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Mon, 6 Apr 2020 15:05:24 -0700 Subject: [PATCH] Add Tensor & TensorHandle C APIs taking a Context The existing TF_AllocateTensor & TFE_NewTensorHandle APIs do not take a TFE_Context which is undesirable as the TFE_Context indicates ownership of the tensor. Thus we add new APIs to super-seed the existing ones. PiperOrigin-RevId: 305126310 Change-Id: I9863ebc692d48875c61b79197ab418f29503a8c6 --- tensorflow/c/c_api_experimental_test.cc | 2 +- tensorflow/c/eager/BUILD | 3 + tensorflow/c/eager/c_api_debug_test.cc | 16 ++- tensorflow/c/eager/c_api_experimental.cc | 32 ++++++ tensorflow/c/eager/c_api_experimental.h | 17 +++ tensorflow/c/eager/c_api_experimental_test.cc | 4 +- tensorflow/c/eager/c_api_remote_test.cc | 16 +-- tensorflow/c/eager/c_api_test.cc | 101 ++++++++++-------- tensorflow/c/eager/c_api_test_util.cc | 86 +++++++-------- tensorflow/c/eager/c_api_test_util.h | 18 ++-- .../eager/c_api_unified_experimental_test.cc | 4 +- tensorflow/c/eager/context_interface.h | 24 +---- tensorflow/c/eager/custom_device_test.cc | 6 +- .../core/common_runtime/eager/context.cc | 46 +------- .../core/common_runtime/eager/context.h | 20 +--- tensorflow/python/lib/core/py_seq_tensor.cc | 18 ++-- 16 files changed, 210 insertions(+), 203 deletions(-) diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index dfc29226783..cfeba345f81 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -218,7 +218,7 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) { TFE_OpSetAttrType(fill_op, "Tshape", TF_INT32); float five = 5.0; - TFE_TensorHandle* scalar = TestScalarTensorHandle(five); + TFE_TensorHandle* scalar = TestScalarTensorHandle(tfe_context_, five); TF_Tensor* scalarTensor = TFE_TensorHandleResolve(scalar, status_); CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); CheckOutputShapes(fill_op, diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index f148e55d54b..4326b723f74 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -179,6 +179,8 @@ cc_library( "//tensorflow/c:tensor_interface", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:span", ], ) @@ -193,6 +195,7 @@ tf_cuda_library( ], deps = [ ":c_api", + ":c_api_experimental", "//tensorflow/c:c_test_util", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/c/eager/c_api_debug_test.cc b/tensorflow/c/eager/c_api_debug_test.cc index 4e987c745ec..eff594ed3ed 100644 --- a/tensorflow/c/eager/c_api_debug_test.cc +++ b/tensorflow/c/eager/c_api_debug_test.cc @@ -21,8 +21,13 @@ limitations under the License. #include "tensorflow/core/platform/test.h" TEST(CApiDebug, ScalarCPU) { - TFE_TensorHandle* h = TestScalarTensorHandle(1.0f); TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* h = TestScalarTensorHandle(ctx, 1.0f); TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -30,12 +35,18 @@ TEST(CApiDebug, ScalarCPU) { TFE_DeleteTensorDebugInfo(debug_info); TFE_DeleteTensorHandle(h); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } TEST(CApiDebug, 2DCPU) { - TFE_TensorHandle* h = TestMatrixTensorHandle3X2(); TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* h = TestMatrixTensorHandle3X2(ctx); TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -46,5 +57,6 @@ TEST(CApiDebug, 2DCPU) { TFE_DeleteTensorDebugInfo(debug_info); TFE_DeleteTensorHandle(h); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 6e4ac19c3ce..b43af710c04 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/c/eager/c_api_experimental.h" +#include + #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/tf_status_helper.h" @@ -600,3 +602,33 @@ void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, }; status->status = tensorflow::Status::OK(); } + +TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype, + const int64_t* dims, int num_dims, + TF_Status* status) { + std::vector dimvec(num_dims); + for (int i = 0; i < num_dims; ++i) { + dimvec[i] = static_cast(dims[i]); + } + + if (ctx == nullptr || ctx->context == nullptr) { + status->status = tensorflow::errors::InvalidArgument("Invalid Context"); + return nullptr; + } + + tensorflow::AbstractTensorInterface* t = ctx->context->CreateTensor( + static_cast(dtype), dimvec); + + if (t == nullptr) { + status->status = + tensorflow::errors::InvalidArgument("Unsupported dtype: ", dtype); + return nullptr; + } + + return new TF_Tensor{t}; +} + +TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t, + TF_Status* status) { + return new TFE_TensorHandle{ctx->context->CreateLocalHandle(t->tensor)}; +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 45d15960a9f..0037f2e81c8 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -524,6 +524,23 @@ TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status); +// Allocate and return a new Tensor on the host. +// +// The caller must set the Tensor values by writing them to the pointer returned +// by TF_TensorData with length TF_TensorByteSize. +TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, + TF_DataType dtype, + const int64_t* dims, + int num_dims, + TF_Status* status); + +// Given a Tensor, wrap it with a TensorHandle +// +// Similar to TFE_NewTensorHandle, but includes a pointer to the TFE_Context. +// The context should be identical to that of the Tensor. +TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor( + TFE_Context* ctx, TF_Tensor* t, TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index 6e2063db35e..0c058398299 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -378,7 +378,7 @@ void Executor_MatMul_CPU(bool async) { TFE_Executor* executor = TFE_NewExecutor(async); TFE_ContextSetExecutorForThread(ctx, executor); - TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); TFE_Op* matmul = MatMulOp(ctx, m, m); TFE_TensorHandle* retvals[2] = {nullptr, nullptr}; int num_retvals = 2; @@ -423,7 +423,7 @@ TEST(CAPI, TensorHandleOnDeviceMemory) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); TF_Tensor* m_data = TFE_TensorHandleResolve(m, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float* m_float = static_cast(TF_TensorData(m_data)); diff --git a/tensorflow/c/eager/c_api_remote_test.cc b/tensorflow/c/eager/c_api_remote_test.cc index a084795eef6..91d19280c4c 100644 --- a/tensorflow/c/eager/c_api_remote_test.cc +++ b/tensorflow/c/eager/c_api_remote_test.cc @@ -75,8 +75,8 @@ void TestRemoteExecute(bool async) { TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); - TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx); const char remote_device_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; auto* h0_task1 = @@ -160,8 +160,8 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) { TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); - TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx); const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; @@ -267,8 +267,8 @@ void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) { // Use large matrices so that RPCs don't return before we get a chance // to call TFE_DeleteContext. - TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(); - TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx); const char remote_device_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; auto* h0_task1 = @@ -331,7 +331,7 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx, const char* remote_device_name, const char* local_device_name) { TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx); TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0); TFE_OpSetDevice(matmul, remote_device_name, status); @@ -414,7 +414,7 @@ void TestRemoteExecuteChangeServerDef(bool async) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); // Create a new tensor_handle. - TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(); + TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx); // Check that copying it to the old remote device (named localhost) fails. TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status); diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 6c4877b2ea2..55f1941ce89 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -49,7 +49,7 @@ void BM_InitOp(int iters) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { TFE_Op* matmul = MatMulOp(ctx, m, m); @@ -73,7 +73,7 @@ void BM_Execute(int iters, int async) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); TFE_Op* matmul = MatMulOp(ctx, m, m); TFE_TensorHandle* retvals[1]; int num_retvals = 1; @@ -108,7 +108,7 @@ void BM_Execute_Identity(int iters, int async) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); TFE_Op* identity = IdentityOp(ctx, m); TFE_TensorHandle* retvals[1]; int num_retvals = 1; @@ -155,11 +155,16 @@ TEST(CAPI, Context) { } TEST(CAPI, TensorHandle) { - TFE_TensorHandle* h = TestMatrixTensorHandle(); - EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); - std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* h = TestMatrixTensorHandle(ctx); + EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); + TF_Tensor* t = TFE_TensorHandleResolve(h, status.get()); ASSERT_EQ(16, TF_TensorByteSize(t)); float data[4] = {0}; @@ -170,6 +175,7 @@ TEST(CAPI, TensorHandle) { EXPECT_EQ(4.0, data[3]); TF_DeleteTensor(t); TFE_DeleteTensorHandle(h); + TFE_DeleteContext(ctx); } void TensorHandleCopyBetweenDevices(bool async) { @@ -181,7 +187,7 @@ void TensorHandleCopyBetweenDevices(bool async) { TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx); TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -257,7 +263,7 @@ void TensorHandleCopyBetweenDevicesError(bool async) { TFE_Context* ctx = TFE_NewContext(opts, status.get()); TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx); const char* kErrorDevice = "NoSuchDevice:0"; TFE_TensorHandle* hdevice = TFE_TensorHandleCopyToDevice(hcpu, ctx, kErrorDevice, status.get()); @@ -298,7 +304,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) { TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx); TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -384,7 +390,7 @@ void TensorHandleSilentCopy(bool async, TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); - TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx); TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); @@ -457,7 +463,7 @@ void SetAndGetOpDevices(bool async) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); TFE_Op* matmul = MatMulOp(ctx, m, m); // Disable the test if no GPU is present. @@ -528,7 +534,7 @@ TEST(CAPI, TensorHandleDevices) { TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx); const char* device_name = TFE_TensorHandleDeviceName(hcpu, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); ASSERT_TRUE(absl::StrContains(device_name, "CPU:0")) << device_name; @@ -586,7 +592,7 @@ void ExecuteAdd(bool async, bool forward_input) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* n = TestMatrixTensorHandle100x100(); + TFE_TensorHandle* n = TestMatrixTensorHandle100x100(ctx); // If a GPU exists, copy the handle to GPU so that we can exercise // unprotecting a mirror. std::string gpu_device_name; @@ -598,7 +604,7 @@ void ExecuteAdd(bool async, bool forward_input) { n = n_gpu; } - TFE_TensorHandle* m = TestMatrixTensorHandle100x100(); + TFE_TensorHandle* m = TestMatrixTensorHandle100x100(ctx); // Store pointer to raw buffer for validation of forwarding behaviour. TF_Tensor* orig = TFE_TensorHandleResolve(n, status); @@ -670,7 +676,7 @@ void Execute_MatMul_CPU(bool async) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); TFE_Op* matmul = MatMulOp(ctx, m, m); TFE_TensorHandle* retvals[2] = {nullptr, nullptr}; int num_retvals = 2; @@ -706,8 +712,8 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* m1 = TestMatrixTensorHandle(); - TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2(); + TFE_TensorHandle* m1 = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2(ctx); TFE_Op* matmul = MatMulOp(ctx, m1, m2); TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0", status); @@ -778,8 +784,8 @@ void Execute_MatMul_CPU_Type_Error(bool async) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* m1 = TestMatrixTensorHandle(); - TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle(); + TFE_TensorHandle* m1 = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle(ctx); TFE_Op* matmul = MatMulOp(ctx, m1, m2); TFE_TensorHandle* retvals[1] = {nullptr}; int num_retvals = 1; @@ -808,8 +814,8 @@ TEST(CAPI, Execute_Min_CPU) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* input = TestMatrixTensorHandle(); - TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_TensorHandle* input = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* axis = TestAxisTensorHandle(ctx); TFE_Op* minOp = MinOp(ctx, input, axis); TFE_TensorHandle* retvals[1] = {nullptr}; int num_retvals = 1; @@ -843,7 +849,7 @@ void Execute_MatMul_XLA_CPU(bool async) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); TFE_Op* matmul = MatMulOp(ctx, m, m); TFE_OpSetXLACompilation(matmul, true); @@ -885,8 +891,8 @@ void Execute_Min_XLA_CPU(bool async) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* input = TestMatrixTensorHandle(); - TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_TensorHandle* input = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* axis = TestAxisTensorHandle(ctx); TFE_Op* minOp = MinOp(ctx, input, axis); TFE_OpSetXLACompilation(minOp, true); @@ -926,7 +932,7 @@ void ExecuteWithTracing(bool async) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); TFE_Op* matmul = MatMulOp(ctx, m, m); TFE_TensorHandle* retvals[1] = {nullptr}; int num_retvals = 1; @@ -1012,7 +1018,7 @@ void FunctionDefAndExecute(bool async) { if (clear_cache) { TFE_ContextClearCaches(ctx); } - TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); TFE_TensorHandle* retval[1] = {nullptr}; int num_retvals = 1; TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status); @@ -1061,7 +1067,7 @@ void BM_ExecuteFunction(int iters, int async) { status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_OpAddInput(matmul, m, status); @@ -1276,11 +1282,15 @@ TEST(CAPI, StringAttributes) { } TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { - TFE_TensorHandle* h = TestMatrixTensorHandle(); - EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); - std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* h = TestMatrixTensorHandle(ctx); + EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); TFE_TensorHandle* h_shares_tensor = TFE_TensorHandleCopySharingTensor(h, status.get()); @@ -1298,6 +1308,7 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { TFE_DeleteTensorHandle(h); TFE_DeleteTensorHandle(h_shares_tensor); + TFE_DeleteContext(ctx); } tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) { @@ -1315,8 +1326,8 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* input = TestMatrixTensorHandle(); - TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_TensorHandle* input = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* axis = TestAxisTensorHandle(ctx); TFE_Op* minOp = TFE_NewOp(ctx, "Min", status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_OpAddInput(minOp, input, status); @@ -1352,9 +1363,9 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* input1 = TestMatrixTensorHandle(); - TFE_TensorHandle* input2 = TestMatrixTensorHandle(); - TFE_TensorHandle* dim = TestScalarTensorHandle(0); + TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* dim = TestScalarTensorHandle(ctx, 0); TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* inputs[] = {input1, input2}; @@ -1392,9 +1403,9 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* condition = TestScalarTensorHandle(true); - TFE_TensorHandle* t1 = TestMatrixTensorHandle(); - TFE_TensorHandle* t2 = TestAxisTensorHandle(); + TFE_TensorHandle* condition = TestScalarTensorHandle(ctx, true); + TFE_TensorHandle* t1 = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* t2 = TestAxisTensorHandle(ctx); TFE_Op* assertOp = TFE_NewOp(ctx, "Assert", status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_OpAddInput(assertOp, condition, status); @@ -1431,9 +1442,9 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* input1 = TestMatrixTensorHandle(); - TFE_TensorHandle* input2 = TestMatrixTensorHandle(); - TFE_TensorHandle* dim = TestScalarTensorHandle(0); + TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* dim = TestScalarTensorHandle(ctx, 0); TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* inputs[] = {input1, input2}; @@ -1466,8 +1477,8 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengths) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* input1 = TestMatrixTensorHandle(); - TFE_TensorHandle* input2 = TestMatrixTensorHandle(); + TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx); TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -1514,8 +1525,8 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* input1 = TestMatrixTensorHandle(); - TFE_TensorHandle* input2 = TestMatrixTensorHandle(); + TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx); TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* inputs[] = {input1, input2}; diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index bee76fe296f..e67e17963b3 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -16,115 +16,117 @@ limitations under the License. #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" using tensorflow::string; -TFE_TensorHandle* TestScalarTensorHandle(float value) { +TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) { float data[] = {value}; - TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteTensor(t); TF_DeleteStatus(status); return th; } -TFE_TensorHandle* TestScalarTensorHandle(int value) { +TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value) { int data[] = {value}; - TF_Tensor* t = TF_AllocateTensor(TF_INT32, nullptr, 0, sizeof(int)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, nullptr, 0, status); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteTensor(t); TF_DeleteStatus(status); return th; } -TFE_TensorHandle* TestScalarTensorHandle(bool value) { +TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value) { bool data[] = {value}; - TF_Tensor* t = TF_AllocateTensor(TF_BOOL, nullptr, 0, sizeof(bool)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_BOOL, nullptr, 0, status); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteTensor(t); TF_DeleteStatus(status); return th; } -TFE_TensorHandle* DoubleTestMatrixTensorHandle() { +TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx) { int64_t dims[] = {2, 2}; double data[] = {1.0, 2.0, 3.0, 4.0}; - TF_Tensor* t = TF_AllocateTensor( - TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_DOUBLE, &dims[0], + sizeof(dims) / sizeof(int64_t), status); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteTensor(t); TF_DeleteStatus(status); return th; } -TFE_TensorHandle* TestMatrixTensorHandle() { +TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx) { int64_t dims[] = {2, 2}; float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; - TF_Tensor* t = TF_AllocateTensor( - TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], + sizeof(dims) / sizeof(int64_t), status); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteTensor(t); TF_DeleteStatus(status); return th; } -TFE_TensorHandle* TestMatrixTensorHandle100x100() { +TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) { constexpr int64_t dims[] = {100, 100}; constexpr int num_elements = dims[0] * dims[1]; float data[num_elements]; for (int i = 0; i < num_elements; ++i) { data[i] = 1.0f; } - TF_Tensor* t = TF_AllocateTensor( - TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], + sizeof(dims) / sizeof(int64_t), status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteTensor(t); TF_DeleteStatus(status); return th; } -TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() { +TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx) { int64_t dims[] = {3, 2}; double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - TF_Tensor* t = TF_AllocateTensor( - TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], + sizeof(dims) / sizeof(int64_t), status); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteTensor(t); TF_DeleteStatus(status); return th; } -TFE_TensorHandle* TestMatrixTensorHandle3X2() { +TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) { int64_t dims[] = {3, 2}; float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - TF_Tensor* t = TF_AllocateTensor( - TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], + sizeof(dims) / sizeof(int64_t), status); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteTensor(t); TF_DeleteStatus(status); @@ -187,14 +189,14 @@ TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) { return op; } -TFE_TensorHandle* TestAxisTensorHandle() { +TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx) { int64_t dims[] = {1}; int data[] = {1}; - TF_Tensor* t = TF_AllocateTensor( - TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], + sizeof(dims) / sizeof(int64_t), status); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteTensor(t); TF_DeleteStatus(status); diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index 2c2f8323363..11ae6d1181b 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -19,28 +19,28 @@ limitations under the License. #include "tensorflow/core/platform/types.h" // Return a tensor handle containing a float scalar -TFE_TensorHandle* TestScalarTensorHandle(float value); +TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value); // Return a tensor handle containing a int scalar -TFE_TensorHandle* TestScalarTensorHandle(int value); +TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value); // Return a tensor handle containing a bool scalar -TFE_TensorHandle* TestScalarTensorHandle(bool value); +TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value); // Return a tensor handle containing a 2x2 matrix of doubles -TFE_TensorHandle* DoubleTestMatrixTensorHandle(); +TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx); // Return a tensor handle containing a 2x2 matrix of floats -TFE_TensorHandle* TestMatrixTensorHandle(); +TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx); // Return a tensor handle containing a 100x100 matrix of floats -TFE_TensorHandle* TestMatrixTensorHandle100x100(); +TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx); // Return a tensor handle containing a 3x2 matrix of doubles -TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(); +TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx); // Return a tensor handle containing a 3x2 matrix of floats -TFE_TensorHandle* TestMatrixTensorHandle3X2(); +TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx); // Return an add op multiplying `a` by `b`. TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); @@ -55,7 +55,7 @@ TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a); TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a); // Return an 1-D INT32 tensor containing a single value 1. -TFE_TensorHandle* TestAxisTensorHandle(); +TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx); // Return an op taking minimum of `input` long `axis` dimension. TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index d9066464e8e..104ede9ebbd 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -46,7 +46,7 @@ TEST(UnifedCAPI, TestBasicEager) { ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build an abstract input tensor. - TFE_TensorHandle* t = TestScalarTensorHandle(2.0f); + TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f); TF_AbstractTensor* at = TF_NewAbstractTensor(); TF_AbstractTensorSetEagerTensor(at, t, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -162,7 +162,7 @@ TEST(UnifedCAPI, TestBasicGraph) { ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build an abstract input tensor. - TFE_TensorHandle* input_eager = TestScalarTensorHandle(2.0f); + TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f); TF_AbstractTensor* input_t = TF_NewAbstractTensor(); TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); diff --git a/tensorflow/c/eager/context_interface.h b/tensorflow/c/eager/context_interface.h index e1779fdf73f..157f10c7fec 100644 --- a/tensorflow/c/eager/context_interface.h +++ b/tensorflow/c/eager/context_interface.h @@ -17,10 +17,12 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/tensor_interface.h" #include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/tstring.h" @@ -40,7 +42,7 @@ class AbstractContextInterface { // destroy an instance of this class. virtual void Release() = 0; - // Scalar creation functions + // Optimized scalar creation functions virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0; virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0; virtual AbstractTensorInterface* CreateInt32Scalar(int32 value) = 0; @@ -52,24 +54,8 @@ class AbstractContextInterface { virtual AbstractTensorInterface* CreateBoolScalar(bool value) = 0; // Tensor creation functions - virtual AbstractTensorInterface* CreateInt64Tensor( - absl::Span dim_sizes) = 0; - virtual AbstractTensorInterface* CreateUint64Tensor( - absl::Span dim_sizes) = 0; - virtual AbstractTensorInterface* CreateInt32Tensor( - absl::Span dim_sizes) = 0; - virtual AbstractTensorInterface* CreateFloatTensor( - absl::Span dim_sizes) = 0; - virtual AbstractTensorInterface* CreateDoubleTensor( - absl::Span dim_sizes) = 0; - virtual AbstractTensorInterface* CreateHalfTensor( - absl::Span dim_sizes) = 0; - virtual AbstractTensorInterface* CreateStringTensor( - absl::Span dim_sizes) = 0; - virtual AbstractTensorInterface* CreateComplex128Tensor( - absl::Span dim_sizes) = 0; - virtual AbstractTensorInterface* CreateBoolTensor( - absl::Span dim_sizes) = 0; + virtual AbstractTensorInterface* CreateTensor( + DataType dtype, absl::Span dim_sizes) = 0; // Create a handle to wrap and manage a Tensor virtual AbstractTensorHandleInterface* CreateLocalHandle( diff --git a/tensorflow/c/eager/custom_device_test.cc b/tensorflow/c/eager/custom_device_test.cc index b6e6369bb43..8f13c1e5151 100644 --- a/tensorflow/c/eager/custom_device_test.cc +++ b/tensorflow/c/eager/custom_device_test.cc @@ -156,7 +156,7 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) { const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; RegisterLoggingDevice(context, name, &arrived, &executed, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context); ASSERT_FALSE(arrived); TFE_TensorHandle* hdevice = TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get()); @@ -245,7 +245,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) { // Assign to the variable, copying to the custom device. std::unique_ptr one( - TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle); + TestScalarTensorHandle(context.get(), 111.f), TFE_DeleteTensorHandle); op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get())); TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT); TFE_OpAddInput(op.get(), var_handle, status.get()); @@ -331,7 +331,7 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) { // Assign to the variable, copying to the custom device. std::unique_ptr one( - TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle); + TestScalarTensorHandle(context.get(), 111.f), TFE_DeleteTensorHandle); op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get())); TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT); TFE_OpAddInput(op.get(), var_handle, status.get()); diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 8685e207a45..b0dd43e0256 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -161,49 +161,9 @@ AbstractTensorInterface* EagerContext::CreateBoolScalar(bool value) { return new TensorInterface(Tensor(value)); } -AbstractTensorInterface* EagerContext::CreateInt64Tensor( - absl::Span dim_sizes) { - return new TensorInterface(Tensor(DT_INT64, TensorShape(dim_sizes))); -} - -AbstractTensorInterface* EagerContext::CreateUint64Tensor( - absl::Span dim_sizes) { - return new TensorInterface(Tensor(DT_UINT64, TensorShape(dim_sizes))); -} - -AbstractTensorInterface* EagerContext::CreateInt32Tensor( - absl::Span dim_sizes) { - return new TensorInterface(Tensor(DT_INT32, TensorShape(dim_sizes))); -} - -AbstractTensorInterface* EagerContext::CreateFloatTensor( - absl::Span dim_sizes) { - return new TensorInterface(Tensor(DT_FLOAT, TensorShape(dim_sizes))); -} - -AbstractTensorInterface* EagerContext::CreateDoubleTensor( - absl::Span dim_sizes) { - return new TensorInterface(Tensor(DT_DOUBLE, TensorShape(dim_sizes))); -} - -AbstractTensorInterface* EagerContext::CreateHalfTensor( - absl::Span dim_sizes) { - return new TensorInterface(Tensor(DT_HALF, TensorShape(dim_sizes))); -} - -AbstractTensorInterface* EagerContext::CreateStringTensor( - absl::Span dim_sizes) { - return new TensorInterface(Tensor(DT_STRING, TensorShape(dim_sizes))); -} - -AbstractTensorInterface* EagerContext::CreateComplex128Tensor( - absl::Span dim_sizes) { - return new TensorInterface(Tensor(DT_COMPLEX128, TensorShape(dim_sizes))); -} - -AbstractTensorInterface* EagerContext::CreateBoolTensor( - absl::Span dim_sizes) { - return new TensorInterface(Tensor(DT_BOOL, TensorShape(dim_sizes))); +AbstractTensorInterface* EagerContext::CreateTensor( + DataType dtype, absl::Span dim_sizes) { + return new TensorInterface(Tensor(dtype, TensorShape(dim_sizes))); } void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env, diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index ffe5a7b6d1a..877d8072008 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -160,24 +160,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { tensorflow::complex128 value) override; AbstractTensorInterface* CreateBoolScalar(bool value) override; - AbstractTensorInterface* CreateInt64Tensor( - absl::Span dim_sizes) override; - AbstractTensorInterface* CreateUint64Tensor( - absl::Span dim_sizes) override; - AbstractTensorInterface* CreateInt32Tensor( - absl::Span dim_sizes) override; - AbstractTensorInterface* CreateFloatTensor( - absl::Span dim_sizes) override; - AbstractTensorInterface* CreateDoubleTensor( - absl::Span dim_sizes) override; - AbstractTensorInterface* CreateHalfTensor( - absl::Span dim_sizes) override; - AbstractTensorInterface* CreateStringTensor( - absl::Span dim_sizes) override; - AbstractTensorInterface* CreateComplex128Tensor( - absl::Span dim_sizes) override; - AbstractTensorInterface* CreateBoolTensor( - absl::Span dim_sizes) override; + AbstractTensorInterface* CreateTensor( + DataType dtype, absl::Span dim_sizes) override; AbstractTensorHandleInterface* CreateLocalHandle( AbstractTensorInterface* t) override; diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 6403ca3a0ea..f05afeb22e5 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -321,7 +321,7 @@ struct ConverterTraits { static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateInt64Tensor(dim_sizes); + return ctx->context->CreateTensor(DT_INT64, dim_sizes); } static const char* ConvertScalar(PyObject* v, int64* out) { @@ -361,7 +361,7 @@ struct ConverterTraits { static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateUint64Tensor(dim_sizes); + return ctx->context->CreateTensor(DT_UINT64, dim_sizes); } static const char* ConvertScalar(PyObject* v, uint64* out) { @@ -398,7 +398,7 @@ struct ConverterTraits { static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateInt32Tensor(dim_sizes); + return ctx->context->CreateTensor(DT_INT32, dim_sizes); } static const char* ConvertScalar(PyObject* v, int32* out) { @@ -505,7 +505,7 @@ struct ConverterTraits { static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateFloatTensor(dim_sizes); + return ctx->context->CreateTensor(DT_FLOAT, dim_sizes); } static const char* ConvertScalar(PyObject* v, float* out) { @@ -521,7 +521,7 @@ struct ConverterTraits { static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateDoubleTensor(dim_sizes); + return ctx->context->CreateTensor(DT_DOUBLE, dim_sizes); } static const char* ConvertScalar(PyObject* v, double* out) { @@ -541,7 +541,7 @@ struct ConverterTraits { static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateHalfTensor(dim_sizes); + return ctx->context->CreateTensor(DT_HALF, dim_sizes); } static const char* ConvertScalar(PyObject* v, Eigen::half* out) { @@ -562,7 +562,7 @@ struct ConverterTraits { static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateStringTensor(dim_sizes); + return ctx->context->CreateTensor(DT_STRING, dim_sizes); } static const char* ConvertScalar(PyObject* v, tstring* out) { @@ -629,7 +629,7 @@ struct ConverterTraits { static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateComplex128Tensor(dim_sizes); + return ctx->context->CreateTensor(DT_COMPLEX128, dim_sizes); } static const char* ConvertScalar(PyObject* v, complex128* out) { @@ -657,7 +657,7 @@ struct ConverterTraits { static AbstractTensorInterface* CreateTensor( TFE_Context* ctx, absl::Span dim_sizes) { - return ctx->context->CreateBoolTensor(dim_sizes); + return ctx->context->CreateTensor(DT_BOOL, dim_sizes); } static const char* ConvertScalar(PyObject* v, bool* out) {