From 69c0447f011bd5077abcd6078502c64e701d97a8 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Thu, 28 May 2020 18:26:21 -0700 Subject: [PATCH] ParallelDevice: Sync executors when returning non-parallel TensorHandles, add remote tests The actual delta isn't huge; I'm moving some test utilities to a testlib since the remote tests need them. The remote tests are in a different target because they need to disable global heap checking, which I'd like to keep on for the rest of the tests. PiperOrigin-RevId: 313698670 Change-Id: I846294a748e3b007eba0472901b0e58358b8edd5 --- tensorflow/c/eager/parallel_device/BUILD | 45 +- .../eager/parallel_device/parallel_device.cc | 5 + .../parallel_device_remote_test.cc | 147 +++++++ .../parallel_device/parallel_device_test.cc | 385 +----------------- .../parallel_device_testlib.cc | 308 ++++++++++++++ .../parallel_device/parallel_device_testlib.h | 174 ++++++++ 6 files changed, 677 insertions(+), 387 deletions(-) create mode 100644 tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc create mode 100644 tensorflow/c/eager/parallel_device/parallel_device_testlib.cc create mode 100644 tensorflow/c/eager/parallel_device/parallel_device_testlib.h diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD index 3b2640e14d1..6fce918aab1 100644 --- a/tensorflow/c/eager/parallel_device/BUILD +++ b/tensorflow/c/eager/parallel_device/BUILD @@ -39,9 +39,11 @@ cc_library( ], ) -tf_cc_test( - name = "parallel_device_test", - srcs = ["parallel_device_test.cc"], +cc_library( + name = "parallel_device_testlib", + testonly = 1, + srcs = ["parallel_device_testlib.cc"], + hdrs = ["parallel_device_testlib.h"], deps = [ ":parallel_device", ":parallel_device_ops", @@ -49,12 +51,49 @@ tf_cc_test( "//tensorflow/c:c_api_experimental", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "parallel_device_test", + srcs = ["parallel_device_test.cc"], + deps = [ + ":parallel_device", + ":parallel_device_ops", + ":parallel_device_testlib", + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_experimental", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) +tf_cc_test( + name = "parallel_device_remote_test", + srcs = ["parallel_device_remote_test.cc"], + # TODO(b/136478427): Enable global heap checking when servers shut down + # cleanly. + args = ["--heap_check=local"], + deps = [ + ":parallel_device", + ":parallel_device_ops", + ":parallel_device_testlib", + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_experimental", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + ], +) + # Note: ParallelDevice-specific ops are experimental and not currently linked in # to TensorFlow by default, just used in a few tests. filegroup( diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc index 27c2699c4c2..75d188d0c45 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device.cc @@ -319,6 +319,11 @@ absl::optional> ParallelDevice::Execute( std::vector outputs; outputs.reserve(t->num_tensors()); for (int i = 0; i < t->num_tensors(); ++i) { + // TODO(b/157523095): Syncing the executor here shouldn't be + // necessary. Currently async+remote is missing cross-executor + // coordination. + TFE_ExecutorWaitForAllPendingNodes(executors_[i].get(), status); + if (TF_GetCode(status) != TF_OK) return result; TensorHandlePtr this_output( TFE_TensorHandleCopySharingTensor(t->tensor(i), status)); outputs.emplace_back(std::move(this_output)); diff --git a/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc new file mode 100644 index 00000000000..32a4b440d25 --- /dev/null +++ b/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc @@ -0,0 +1,147 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/parallel_device/parallel_device.h" +#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/platform/test.h" + +tensorflow::ServerDef GetServerDef(const std::string& job_name, int num_tasks) { + tensorflow::ServerDef server_def; + server_def.set_protocol("grpc"); + server_def.set_job_name(job_name); + server_def.set_task_index(0); + tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); + tensorflow::JobDef* job_def = cluster_def->add_job(); + job_def->set_name(job_name); + for (int i = 0; i < num_tasks; i++) { + int port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->insert( + {i, tensorflow::strings::StrCat("localhost", ":", port)}); + } + return server_def; +} + +TEST(PARALLEL_DEVICE, TestRemoteBasic) { + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + tensorflow::ServerDef server_def = GetServerDef("worker", 3); + + // This server def has the task index set to 0. + std::string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + + server_def.set_task_index(2); + std::unique_ptr worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); + + TFE_ContextSetServerDef(context.get(), 0, serialized.data(), + serialized.size(), status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + BasicTestsForTwoDevices(context.get(), + "/job:worker/replica:0/task:1/device:CPU:0", + "/job:worker/replica:0/task:2/device:CPU:0"); + + worker_server1.release(); + worker_server2.release(); +} + +TEST(PARALLEL_DEVICE, TestAsyncCopyOff) { + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + tensorflow::ServerDef server_def = GetServerDef("worker", 3); + + // This server def has the task index set to 0. + std::string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + + server_def.set_task_index(2); + std::unique_ptr worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); + + TFE_ContextSetServerDef(context.get(), 0, serialized.data(), + serialized.size(), status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + const char* first_device = "/job:worker/replica:0/task:1/device:CPU:0"; + const char* second_device = "/job:worker/replica:0/task:2/device:CPU:0"; + const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + std::array underlying_devices{first_device, second_device}; + RegisterParallelDevice(context.get(), device_name, underlying_devices, + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TensorHandlePtr value_one(FloatTensorHandle(3., status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TensorHandlePtr value_two(FloatTensorHandle(-2., status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + std::array in_components{value_one.get(), + value_two.get()}; + TensorHandlePtr combined_value = CreatePerDeviceValues( + context.get(), in_components, device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Loop to make synchronization failures more deterministic + for (int i = 0; i < 100; ++i) { + TensorHandlePtr multiply_result( + Multiply(context.get(), combined_value.get(), combined_value.get(), + status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + std::array out_components; + ExtractPerDeviceValues(context.get(), multiply_result.get(), + &out_components, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + ExpectScalarEq(out_components[0].get(), 9.); + ExpectScalarEq(out_components[1].get(), 4.); + } + + worker_server1.release(); + worker_server2.release(); +} diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc index fdc140407df..d9784ac9fa6 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_test.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h" #include "tensorflow/core/platform/test.h" // NOTE(allenl): These tests currently go through TFE_Execute and so are @@ -28,390 +29,6 @@ limitations under the License. // correspond fairly well to the implementation, but testing the C++ directly is // another option. -// Functor for making unique_ptr to TFE_TensorHandle slightly more -// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second -// template argument requires passing a function pointer to -// TFE_DeleteTensorHandle when constructing the unique_ptr. -class TensorHandleDeleter { - public: - void operator()(TFE_TensorHandle* to_delete) { - TFE_DeleteTensorHandle(to_delete); - } -}; - -using TensorHandlePtr = std::unique_ptr; - -// A helper for performing common operations on variables. A much more -// restricted stand-in for tf.Variable in Python. -class Variable { - public: - // Construct a Variable from a resource-dtype TFE_TensorHandle and an - // indication of the dtype of the variable's value. - // - // Note that creating this resource-dtype handle can fail, so `Create` is a - // separate static method which returns a status. - Variable(TFE_TensorHandle* handle, TF_DataType type) - : handle_(handle), type_(type) {} - - // Helper for constructing a resource handle and wrapping it in a `Variable` - // object. - static Variable* Create(TFE_Context* context, TF_DataType type, - const int64_t* dims, const int num_dims, - const char* device, TF_Status* status); - // Dereferences the backing buffer for the variable. Note that since this can - // fail (it runs operations), it must be called explicitly and the resulting - // `status` checked. - void Destroy(TFE_Context* context, TF_Status* status); - - // Reads from the variable. - TensorHandlePtr Read(TFE_Context* context, TF_Status* status); - // Assigns a new value to the variable. - void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status); - // Adds `value` to the existing value of the variable. - void AssignAdd(TFE_Context* context, TFE_TensorHandle* value, - TF_Status* status); - - private: - // Helper for running any single-argument assignment ops (Assign, AssignAdd, - // AssignSub, ...). - void GeneralAssignment(const char* op_name, TFE_Context* context, - TFE_TensorHandle* value, TF_Status* status); - - // The a handle for the resource-dtype tensor pointing to the variable's - // buffer. - TFE_TensorHandle* handle_; - // The dtype of the variable's buffer (input dtype for assignments, output - // dtype of read operations). - TF_DataType type_; -}; - -Variable* Variable::Create(TFE_Context* context, TF_DataType type, - const int64_t* dims, const int num_dims, - const char* device, TF_Status* status) { - std::unique_ptr op( - TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(op.get(), "dtype", type); - TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status); - TFE_OpSetAttrString(op.get(), "container", "", 0); - // Use the special GUID for no buffer sharing - // - // TODO(allenl): Should we provide a better API for this? AFAIK this is the - // only reasonable way to make variables with no aliasing using the eager C - // API. - std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; - TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(), - no_sharing.length()); - TFE_OpSetDevice(op.get(), device, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_TensorHandle* var_handle = nullptr; - int num_retvals = 1; - TFE_Execute(op.get(), &var_handle, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - return new Variable(var_handle, type); -} - -void Variable::Destroy(TFE_Context* context, TF_Status* status) { - // Free the backing buffer for the variable. - std::unique_ptr op( - TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp); - if (TF_GetCode(status) != TF_OK) return; - TFE_OpAddInput(op.get(), handle_, status); - if (TF_GetCode(status) != TF_OK) return; - const char* device = TFE_TensorHandleDeviceName(handle_, status); - if (TF_GetCode(status) != TF_OK) return; - TFE_OpSetDevice(op.get(), device, status); - if (TF_GetCode(status) != TF_OK) return; - int num_retvals = 0; - TFE_Execute(op.get(), nullptr, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return; - // Delete the variable handle itself. - TFE_DeleteTensorHandle(handle_); -} - -TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) { - std::unique_ptr op( - TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpAddInput(op.get(), handle_, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - const char* device = TFE_TensorHandleDeviceName(handle_, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetDevice(op.get(), device, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(op.get(), "dtype", type_); - int num_retvals = 1; - TFE_TensorHandle* var_value = nullptr; - TFE_Execute(op.get(), &var_value, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - return TensorHandlePtr(var_value); -} - -void Variable::GeneralAssignment(const char* op_name, TFE_Context* context, - TFE_TensorHandle* value, TF_Status* status) { - std::unique_ptr op( - TFE_NewOp(context, op_name, status), &TFE_DeleteOp); - if (TF_GetCode(status) != TF_OK) return; - TFE_OpSetAttrType(op.get(), "dtype", type_); - TFE_OpAddInput(op.get(), handle_, status); - if (TF_GetCode(status) != TF_OK) return; - TFE_OpAddInput(op.get(), value, status); - if (TF_GetCode(status) != TF_OK) return; - const char* device = TFE_TensorHandleDeviceName(handle_, status); - if (TF_GetCode(status) != TF_OK) return; - TFE_OpSetDevice(op.get(), device, status); - - int num_retvals = 0; - TFE_Execute(op.get(), nullptr, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return; -} - -void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value, - TF_Status* status) { - GeneralAssignment("AssignAddVariableOp", context, value, status); -} - -void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value, - TF_Status* status) { - GeneralAssignment("AssignVariableOp", context, value, status); -} - -// Passed to `TF_NewTensor` to indicate how an array of floats should be -// deleted. -static void FloatDeallocator(void* data, size_t, void* arg) { - delete[] static_cast(data); -} - -// Creates a TFE_TensorHandle with value `v`. -TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) { - const int num_bytes = sizeof(float); - float* values = new float[1]; - values[0] = v; - std::unique_ptr tensor( - TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator, - nullptr), - TF_DeleteTensor); - return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status)); -} - -// Creates a rank-one TFE_TensorHandle with value `v`. -TensorHandlePtr VectorFloatTensorHandle(const std::vector& v, - TF_Status* status) { - const int num_bytes = v.size() * sizeof(float); - float* values = new float[v.size()]; - memcpy(values, v.data(), num_bytes); - int64_t dims = v.size(); - std::unique_ptr tensor( - TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes, - &FloatDeallocator, nullptr), - TF_DeleteTensor); - return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status)); -} - -// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle. -template -void ExtractPerDeviceValues( - TFE_Context* context, TFE_TensorHandle* input, - std::array* components, TF_Status* status) { - std::unique_ptr op( - TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp); - if (TF_GetCode(status) != TF_OK) return; - TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas); - TFE_OpAddInput(op.get(), input, status); - if (TF_GetCode(status) != TF_OK) return; - const char* device = TFE_TensorHandleDeviceName(input, status); - if (TF_GetCode(status) != TF_OK) return; - TFE_OpSetDevice(op.get(), device, status); - if (TF_GetCode(status) != TF_OK) return; - - TFE_TensorHandle* result_handles[num_replicas]; - int num_retvals = num_replicas; - TFE_Execute(op.get(), result_handles, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return; - for (int i = 0; i < num_replicas; ++i) { - (*components)[i].reset(result_handles[i]); - } -} - -// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle. -template -TensorHandlePtr CreatePerDeviceValues( - TFE_Context* context, - const std::array& components, - const char* device, TF_Status* status) { - std::unique_ptr op( - TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrInt(op.get(), "N", num_replicas); - for (int i = 0; i < num_replicas; ++i) { - TFE_OpAddInput(op.get(), components[i], status); - if (TF_GetCode(status) != TF_OK) return nullptr; - } - TFE_OpSetDevice(op.get(), device, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - - TFE_TensorHandle* result_handle; - int num_retvals = 1; - TFE_Execute(op.get(), &result_handle, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - return TensorHandlePtr(result_handle); -} - -TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first, - TFE_TensorHandle* second, TF_Status* status) { - std::unique_ptr op( - TFE_NewOp(context, "Mul", status), TFE_DeleteOp); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpAddInput(op.get(), first, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpAddInput(op.get(), second, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - const char* first_device = TFE_TensorHandleDeviceName(first, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetDevice(op.get(), first_device, status); - - TFE_TensorHandle* result_handle; - int num_retvals = 1; - TFE_Execute(op.get(), &result_handle, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - return TensorHandlePtr(result_handle); -} - -// Assert that `handle` is equal to `expected_value`. -template -void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - std::unique_ptr value_zero( - TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - EXPECT_EQ(expected_value, - *static_cast(TF_TensorData(value_zero.get()))); -} - -template -void RegisterParallelDevice( - TFE_Context* context, const char* device_name, - const std::array& underlying_devices, - TF_Status* status) { - TFE_CustomDevice device; - void* device_info; - tensorflow::eager::AllocateParallelDevice( - device_name, underlying_devices.data(), underlying_devices.size(), - &device, &device_info); - TFE_RegisterCustomDevice(context, device, device_name, device_info, status); -} - -// Create and modify a variable placed on a parallel device which composes -// `first_device` and `second_device`. -void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, - const char* second_device) { - // Register the custom device - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - std::array underlying_devices{first_device, second_device}; - RegisterParallelDevice(context, device_name, underlying_devices, - status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - // Create a variable handle (uninitialized to start) placed on the parallel - // device. - std::function variable_deleter = [&](Variable* to_delete) { - to_delete->Destroy(context, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - delete to_delete; - }; - std::unique_ptr variable( - Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name, - status.get()), - variable_deleter); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - // Assign an initial value to the variable, implicitly mirroring it to each - // component device. - { - TensorHandlePtr initial_value = FloatTensorHandle(20., status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - variable->Assign(context, initial_value.get(), status.get()); - } - - // Read from the variable and verify that we have a parallel tensor. - { - TensorHandlePtr read = variable->Read(context, status.get()); - std::array components; - ExtractPerDeviceValues(context, read.get(), &components, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - ExpectScalarEq(components[0].get(), 20.); - ExpectScalarEq(components[1].get(), 20.); - - std::string first_device = - TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); - ASSERT_EQ(underlying_devices[0], first_device); - std::string second_device = - TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); - ASSERT_EQ(underlying_devices[1], second_device); - } - - // Add a parallel tensor with different values on each device to the variable. - { - TensorHandlePtr value_one(FloatTensorHandle(3., status.get())); - TensorHandlePtr value_two(FloatTensorHandle(-2., status.get())); - std::array components{value_one.get(), - value_two.get()}; - TensorHandlePtr combined_value = - CreatePerDeviceValues(context, components, device_name, status.get()); - variable->AssignAdd(context, combined_value.get(), status.get()); - } - - // Read the variable and verify that each component has the right modified - // value. - { - TensorHandlePtr read = variable->Read(context, status.get()); - std::array components; - ExtractPerDeviceValues(context, read.get(), &components, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - ExpectScalarEq(components[0].get(), 23.); - ExpectScalarEq(components[1].get(), 18.); - - std::string first_device = - TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); - ASSERT_EQ(underlying_devices[0], first_device); - std::string second_device = - TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); - ASSERT_EQ(underlying_devices[1], second_device); - } - // Compute the device ID twice and verify the result - for (int i = 0; i < 2; ++i) { - std::unique_ptr op( - TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - TFE_OpSetDevice(op.get(), device_name, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - TFE_TensorHandle* result_handle; - int num_retvals = 1; - TFE_Execute(op.get(), &result_handle, &num_retvals, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - std::array components; - ExtractPerDeviceValues(context, result_handle, &components, status.get()); - TFE_DeleteTensorHandle(result_handle); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - ExpectScalarEq(components[0].get(), 0); - ExpectScalarEq(components[1].get(), 1); - std::string first_device = - TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); - ASSERT_EQ(underlying_devices[0], first_device); - std::string second_device = - TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); - ASSERT_EQ(underlying_devices[1], second_device); - } -} - TEST(PARALLEL_DEVICE, TestBasicCPU) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); diff --git a/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc b/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc new file mode 100644 index 00000000000..fba47865c36 --- /dev/null +++ b/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc @@ -0,0 +1,308 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h" + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/core/platform/test.h" + +// NOTE(allenl): These tests currently go through TFE_Execute and so are +// integration testing rather than purely testing the parallel device. They +// correspond fairly well to the implementation, but testing the C++ directly is +// another option. + + +Variable* Variable::Create(TFE_Context* context, TF_DataType type, + const int64_t* dims, const int num_dims, + const char* device, TF_Status* status) { + std::unique_ptr op( + TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op.get(), "dtype", type); + TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status); + TFE_OpSetAttrString(op.get(), "container", "", 0); + // Use the special GUID for no buffer sharing + // + // TODO(allenl): Should we provide a better API for this? AFAIK this is the + // only reasonable way to make variables with no aliasing using the eager C + // API. + std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; + TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(), + no_sharing.length()); + TFE_OpSetDevice(op.get(), device, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_TensorHandle* var_handle = nullptr; + int num_retvals = 1; + TFE_Execute(op.get(), &var_handle, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + return new Variable(var_handle, type); +} + +void Variable::Destroy(TFE_Context* context, TF_Status* status) { + // Free the backing buffer for the variable. + std::unique_ptr op( + TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return; + TFE_OpAddInput(op.get(), handle_, status); + if (TF_GetCode(status) != TF_OK) return; + const char* device = TFE_TensorHandleDeviceName(handle_, status); + if (TF_GetCode(status) != TF_OK) return; + TFE_OpSetDevice(op.get(), device, status); + if (TF_GetCode(status) != TF_OK) return; + int num_retvals = 0; + TFE_Execute(op.get(), nullptr, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return; + // Delete the variable handle itself. + TFE_DeleteTensorHandle(handle_); +} + +TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) { + std::unique_ptr op( + TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpAddInput(op.get(), handle_, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + const char* device = TFE_TensorHandleDeviceName(handle_, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetDevice(op.get(), device, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op.get(), "dtype", type_); + int num_retvals = 1; + TFE_TensorHandle* var_value = nullptr; + TFE_Execute(op.get(), &var_value, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + return TensorHandlePtr(var_value); +} + +void Variable::GeneralAssignment(const char* op_name, TFE_Context* context, + TFE_TensorHandle* value, TF_Status* status) { + std::unique_ptr op( + TFE_NewOp(context, op_name, status), &TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return; + TFE_OpSetAttrType(op.get(), "dtype", type_); + TFE_OpAddInput(op.get(), handle_, status); + if (TF_GetCode(status) != TF_OK) return; + TFE_OpAddInput(op.get(), value, status); + if (TF_GetCode(status) != TF_OK) return; + const char* device = TFE_TensorHandleDeviceName(handle_, status); + if (TF_GetCode(status) != TF_OK) return; + TFE_OpSetDevice(op.get(), device, status); + + int num_retvals = 0; + TFE_Execute(op.get(), nullptr, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return; +} + +void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value, + TF_Status* status) { + GeneralAssignment("AssignAddVariableOp", context, value, status); +} + +void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value, + TF_Status* status) { + GeneralAssignment("AssignVariableOp", context, value, status); +} + +// Passed to `TF_NewTensor` to indicate how an array of floats should be +// deleted. +static void FloatDeallocator(void* data, size_t, void* arg) { + delete[] static_cast(data); +} + +// Creates a TFE_TensorHandle with value `v`. +TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) { + const int num_bytes = sizeof(float); + float* values = new float[1]; + values[0] = v; + std::unique_ptr tensor( + TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator, + nullptr), + TF_DeleteTensor); + return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status)); +} + +// Creates a rank-one TFE_TensorHandle with value `v`. +TensorHandlePtr VectorFloatTensorHandle(const std::vector& v, + TF_Status* status) { + const int num_bytes = v.size() * sizeof(float); + float* values = new float[v.size()]; + memcpy(values, v.data(), num_bytes); + int64_t dims = v.size(); + std::unique_ptr tensor( + TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes, + &FloatDeallocator, nullptr), + TF_DeleteTensor); + return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status)); +} + +// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle. +template +void ExtractPerDeviceValues( + TFE_Context* context, TFE_TensorHandle* input, + std::array* components, TF_Status* status) { + std::unique_ptr op( + TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return; + TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas); + TFE_OpAddInput(op.get(), input, status); + if (TF_GetCode(status) != TF_OK) return; + const char* device = TFE_TensorHandleDeviceName(input, status); + if (TF_GetCode(status) != TF_OK) return; + TFE_OpSetDevice(op.get(), device, status); + if (TF_GetCode(status) != TF_OK) return; + + TFE_TensorHandle* result_handles[num_replicas]; + int num_retvals = num_replicas; + TFE_Execute(op.get(), result_handles, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return; + for (int i = 0; i < num_replicas; ++i) { + (*components)[i].reset(result_handles[i]); + } +} + +TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first, + TFE_TensorHandle* second, TF_Status* status) { + std::unique_ptr op( + TFE_NewOp(context, "Mul", status), TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpAddInput(op.get(), first, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpAddInput(op.get(), second, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + const char* first_device = TFE_TensorHandleDeviceName(first, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetDevice(op.get(), first_device, status); + + TFE_TensorHandle* result_handle; + int num_retvals = 1; + TFE_Execute(op.get(), &result_handle, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + return TensorHandlePtr(result_handle); +} + +// Create and modify a variable placed on a parallel device which composes +// `first_device` and `second_device`. +void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, + const char* second_device) { + // Register the custom device + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + std::array underlying_devices{first_device, second_device}; + RegisterParallelDevice(context, device_name, underlying_devices, + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Create a variable handle (uninitialized to start) placed on the parallel + // device. + std::function variable_deleter = [&](Variable* to_delete) { + to_delete->Destroy(context, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + delete to_delete; + }; + std::unique_ptr variable( + Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name, + status.get()), + variable_deleter); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Assign an initial value to the variable, implicitly mirroring it to each + // component device. + { + TensorHandlePtr initial_value = FloatTensorHandle(20., status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + variable->Assign(context, initial_value.get(), status.get()); + } + + // Read from the variable and verify that we have a parallel tensor. + { + TensorHandlePtr read = variable->Read(context, status.get()); + std::array components; + ExtractPerDeviceValues(context, read.get(), &components, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + ExpectScalarEq(components[0].get(), 20.); + ExpectScalarEq(components[1].get(), 20.); + + std::string first_device = + TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); + ASSERT_EQ(underlying_devices[0], first_device); + std::string second_device = + TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); + ASSERT_EQ(underlying_devices[1], second_device); + } + + // Add a parallel tensor with different values on each device to the variable. + { + TensorHandlePtr value_one(FloatTensorHandle(3., status.get())); + TensorHandlePtr value_two(FloatTensorHandle(-2., status.get())); + std::array components{value_one.get(), + value_two.get()}; + TensorHandlePtr combined_value = + CreatePerDeviceValues(context, components, device_name, status.get()); + variable->AssignAdd(context, combined_value.get(), status.get()); + } + + // Read the variable and verify that each component has the right modified + // value. + { + TensorHandlePtr read = variable->Read(context, status.get()); + std::array components; + ExtractPerDeviceValues(context, read.get(), &components, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + ExpectScalarEq(components[0].get(), 23.); + ExpectScalarEq(components[1].get(), 18.); + + std::string first_device = + TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); + ASSERT_EQ(underlying_devices[0], first_device); + std::string second_device = + TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); + ASSERT_EQ(underlying_devices[1], second_device); + } + // Compute the device ID twice and verify the result + for (int i = 0; i < 2; ++i) { + std::unique_ptr op( + TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_OpSetDevice(op.get(), device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_TensorHandle* result_handle; + int num_retvals = 1; + TFE_Execute(op.get(), &result_handle, &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + std::array components; + ExtractPerDeviceValues(context, result_handle, &components, status.get()); + TFE_DeleteTensorHandle(result_handle); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + ExpectScalarEq(components[0].get(), 0); + ExpectScalarEq(components[1].get(), 1); + std::string first_device = + TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); + ASSERT_EQ(underlying_devices[0], first_device); + std::string second_device = + TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); + ASSERT_EQ(underlying_devices[1], second_device); + } +} diff --git a/tensorflow/c/eager/parallel_device/parallel_device_testlib.h b/tensorflow/c/eager/parallel_device/parallel_device_testlib.h new file mode 100644 index 00000000000..fdd21087949 --- /dev/null +++ b/tensorflow/c/eager/parallel_device/parallel_device_testlib.h @@ -0,0 +1,174 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_ +#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_ + +#include "tensorflow/c/eager/parallel_device/parallel_device.h" + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/core/platform/test.h" + + +// Functor for making unique_ptr to TFE_TensorHandle slightly more +// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second +// template argument requires passing a function pointer to +// TFE_DeleteTensorHandle when constructing the unique_ptr. +class TensorHandleDeleter { + public: + void operator()(TFE_TensorHandle* to_delete) { + TFE_DeleteTensorHandle(to_delete); + } +}; + +using TensorHandlePtr = std::unique_ptr; + +// A helper for performing common operations on variables. A much more +// restricted stand-in for tf.Variable in Python. +class Variable { + public: + // Construct a Variable from a resource-dtype TFE_TensorHandle and an + // indication of the dtype of the variable's value. + // + // Note that creating this resource-dtype handle can fail, so `Create` is a + // separate static method which returns a status. + Variable(TFE_TensorHandle* handle, TF_DataType type) + : handle_(handle), type_(type) {} + + // Helper for constructing a resource handle and wrapping it in a `Variable` + // object. + static Variable* Create(TFE_Context* context, TF_DataType type, + const int64_t* dims, const int num_dims, + const char* device, TF_Status* status); + // Dereferences the backing buffer for the variable. Note that since this can + // fail (it runs operations), it must be called explicitly and the resulting + // `status` checked. + void Destroy(TFE_Context* context, TF_Status* status); + + // Reads from the variable. + TensorHandlePtr Read(TFE_Context* context, TF_Status* status); + // Assigns a new value to the variable. + void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status); + // Adds `value` to the existing value of the variable. + void AssignAdd(TFE_Context* context, TFE_TensorHandle* value, + TF_Status* status); + + private: + // Helper for running any single-argument assignment ops (Assign, AssignAdd, + // AssignSub, ...). + void GeneralAssignment(const char* op_name, TFE_Context* context, + TFE_TensorHandle* value, TF_Status* status); + + // The a handle for the resource-dtype tensor pointing to the variable's + // buffer. + TFE_TensorHandle* handle_; + // The dtype of the variable's buffer (input dtype for assignments, output + // dtype of read operations). + TF_DataType type_; +}; + +// Creates a TFE_TensorHandle with value `v`. +TensorHandlePtr FloatTensorHandle(float v, TF_Status* status); + +// Creates a rank-one TFE_TensorHandle with value `v`. +TensorHandlePtr VectorFloatTensorHandle(const std::vector& v, + TF_Status* status); + +// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle. +template +void ExtractPerDeviceValues( + TFE_Context* context, TFE_TensorHandle* input, + std::array* components, TF_Status* status); + +// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle. +template +TensorHandlePtr CreatePerDeviceValues( + TFE_Context* context, + const std::array& components, + const char* device, TF_Status* status); + +TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first, + TFE_TensorHandle* second, TF_Status* status); + +// Assert that `handle` is equal to `expected_value`. +template +void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value); + +template +void RegisterParallelDevice( + TFE_Context* context, const char* device_name, + const std::array& underlying_devices, + TF_Status* status); + +// Create and modify a variable placed on a parallel device which composes +// `first_device` and `second_device`. +void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, + const char* second_device); + +// Implementations of templated functions ****************************** + +template +TensorHandlePtr CreatePerDeviceValues( + TFE_Context* context, + const std::array& components, + const char* device, TF_Status* status) { + std::unique_ptr op( + TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrInt(op.get(), "N", num_replicas); + for (int i = 0; i < num_replicas; ++i) { + TFE_OpAddInput(op.get(), components[i], status); + if (TF_GetCode(status) != TF_OK) return nullptr; + } + TFE_OpSetDevice(op.get(), device, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + + TFE_TensorHandle* result_handle; + int num_retvals = 1; + TFE_Execute(op.get(), &result_handle, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + return TensorHandlePtr(result_handle); +} + +template +void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr value_zero( + TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + EXPECT_EQ(expected_value, + *static_cast(TF_TensorData(value_zero.get()))); +} + +template +void RegisterParallelDevice( + TFE_Context* context, const char* device_name, + const std::array& underlying_devices, + TF_Status* status) { + TFE_CustomDevice device; + void* device_info; + tensorflow::eager::AllocateParallelDevice( + device_name, underlying_devices.data(), underlying_devices.size(), + &device, &device_info); + TFE_RegisterCustomDevice(context, device, device_name, device_info, status); +} + +#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_