ParallelDevice: Sync executors when returning non-parallel TensorHandles, add remote tests
The actual delta isn't huge; I'm moving some test utilities to a testlib since the remote tests need them. The remote tests are in a different target because they need to disable global heap checking, which I'd like to keep on for the rest of the tests. PiperOrigin-RevId: 313698670 Change-Id: I846294a748e3b007eba0472901b0e58358b8edd5
This commit is contained in:
parent
3ed7b0732e
commit
69c0447f01
|
@ -39,9 +39,11 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
cc_library(
|
||||||
name = "parallel_device_test",
|
name = "parallel_device_testlib",
|
||||||
srcs = ["parallel_device_test.cc"],
|
testonly = 1,
|
||||||
|
srcs = ["parallel_device_testlib.cc"],
|
||||||
|
hdrs = ["parallel_device_testlib.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":parallel_device",
|
":parallel_device",
|
||||||
":parallel_device_ops",
|
":parallel_device_ops",
|
||||||
|
@ -49,12 +51,49 @@ tf_cc_test(
|
||||||
"//tensorflow/c:c_api_experimental",
|
"//tensorflow/c:c_api_experimental",
|
||||||
"//tensorflow/c/eager:c_api",
|
"//tensorflow/c/eager:c_api",
|
||||||
"//tensorflow/c/eager:c_api_experimental",
|
"//tensorflow/c/eager:c_api_experimental",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "parallel_device_test",
|
||||||
|
srcs = ["parallel_device_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":parallel_device",
|
||||||
|
":parallel_device_ops",
|
||||||
|
":parallel_device_testlib",
|
||||||
|
"//tensorflow/c:c_api",
|
||||||
|
"//tensorflow/c:c_api_experimental",
|
||||||
|
"//tensorflow/c/eager:c_api",
|
||||||
|
"//tensorflow/c/eager:c_api_experimental",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "parallel_device_remote_test",
|
||||||
|
srcs = ["parallel_device_remote_test.cc"],
|
||||||
|
# TODO(b/136478427): Enable global heap checking when servers shut down
|
||||||
|
# cleanly.
|
||||||
|
args = ["--heap_check=local"],
|
||||||
|
deps = [
|
||||||
|
":parallel_device",
|
||||||
|
":parallel_device_ops",
|
||||||
|
":parallel_device_testlib",
|
||||||
|
"//tensorflow/c:c_api",
|
||||||
|
"//tensorflow/c:c_api_experimental",
|
||||||
|
"//tensorflow/c/eager:c_api",
|
||||||
|
"//tensorflow/c/eager:c_api_experimental",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Note: ParallelDevice-specific ops are experimental and not currently linked in
|
# Note: ParallelDevice-specific ops are experimental and not currently linked in
|
||||||
# to TensorFlow by default, just used in a few tests.
|
# to TensorFlow by default, just used in a few tests.
|
||||||
filegroup(
|
filegroup(
|
||||||
|
|
|
@ -319,6 +319,11 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||||
std::vector<MaybeParallelTensorOwned> outputs;
|
std::vector<MaybeParallelTensorOwned> outputs;
|
||||||
outputs.reserve(t->num_tensors());
|
outputs.reserve(t->num_tensors());
|
||||||
for (int i = 0; i < t->num_tensors(); ++i) {
|
for (int i = 0; i < t->num_tensors(); ++i) {
|
||||||
|
// TODO(b/157523095): Syncing the executor here shouldn't be
|
||||||
|
// necessary. Currently async+remote is missing cross-executor
|
||||||
|
// coordination.
|
||||||
|
TFE_ExecutorWaitForAllPendingNodes(executors_[i].get(), status);
|
||||||
|
if (TF_GetCode(status) != TF_OK) return result;
|
||||||
TensorHandlePtr this_output(
|
TensorHandlePtr this_output(
|
||||||
TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
|
TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
|
||||||
outputs.emplace_back(std::move(this_output));
|
outputs.emplace_back(std::move(this_output));
|
||||||
|
|
|
@ -0,0 +1,147 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/c_api_experimental.h"
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
|
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
|
||||||
|
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
tensorflow::ServerDef GetServerDef(const std::string& job_name, int num_tasks) {
|
||||||
|
tensorflow::ServerDef server_def;
|
||||||
|
server_def.set_protocol("grpc");
|
||||||
|
server_def.set_job_name(job_name);
|
||||||
|
server_def.set_task_index(0);
|
||||||
|
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||||
|
tensorflow::JobDef* job_def = cluster_def->add_job();
|
||||||
|
job_def->set_name(job_name);
|
||||||
|
for (int i = 0; i < num_tasks; i++) {
|
||||||
|
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||||
|
job_def->mutable_tasks()->insert(
|
||||||
|
{i, tensorflow::strings::StrCat("localhost", ":", port)});
|
||||||
|
}
|
||||||
|
return server_def;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PARALLEL_DEVICE, TestRemoteBasic) {
|
||||||
|
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||||
|
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||||
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||||
|
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||||
|
tensorflow::ServerDef server_def = GetServerDef("worker", 3);
|
||||||
|
|
||||||
|
// This server def has the task index set to 0.
|
||||||
|
std::string serialized = server_def.SerializeAsString();
|
||||||
|
|
||||||
|
server_def.set_task_index(1);
|
||||||
|
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||||
|
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||||
|
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||||
|
.ok());
|
||||||
|
ASSERT_TRUE(worker_server1->Start().ok());
|
||||||
|
|
||||||
|
server_def.set_task_index(2);
|
||||||
|
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||||
|
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||||
|
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||||
|
.ok());
|
||||||
|
ASSERT_TRUE(worker_server2->Start().ok());
|
||||||
|
|
||||||
|
TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
|
||||||
|
serialized.size(), status.get());
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
|
|
||||||
|
BasicTestsForTwoDevices(context.get(),
|
||||||
|
"/job:worker/replica:0/task:1/device:CPU:0",
|
||||||
|
"/job:worker/replica:0/task:2/device:CPU:0");
|
||||||
|
|
||||||
|
worker_server1.release();
|
||||||
|
worker_server2.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(PARALLEL_DEVICE, TestAsyncCopyOff) {
|
||||||
|
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||||
|
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||||
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||||
|
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||||
|
tensorflow::ServerDef server_def = GetServerDef("worker", 3);
|
||||||
|
|
||||||
|
// This server def has the task index set to 0.
|
||||||
|
std::string serialized = server_def.SerializeAsString();
|
||||||
|
|
||||||
|
server_def.set_task_index(1);
|
||||||
|
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||||
|
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||||
|
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||||
|
.ok());
|
||||||
|
ASSERT_TRUE(worker_server1->Start().ok());
|
||||||
|
|
||||||
|
server_def.set_task_index(2);
|
||||||
|
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||||
|
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||||
|
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||||
|
.ok());
|
||||||
|
ASSERT_TRUE(worker_server2->Start().ok());
|
||||||
|
|
||||||
|
TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
|
||||||
|
serialized.size(), status.get());
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
|
|
||||||
|
const char* first_device = "/job:worker/replica:0/task:1/device:CPU:0";
|
||||||
|
const char* second_device = "/job:worker/replica:0/task:2/device:CPU:0";
|
||||||
|
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||||
|
std::array<const char*, 2> underlying_devices{first_device, second_device};
|
||||||
|
RegisterParallelDevice(context.get(), device_name, underlying_devices,
|
||||||
|
status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
|
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
std::array<TFE_TensorHandle*, 2> in_components{value_one.get(),
|
||||||
|
value_two.get()};
|
||||||
|
TensorHandlePtr combined_value = CreatePerDeviceValues(
|
||||||
|
context.get(), in_components, device_name, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
|
// Loop to make synchronization failures more deterministic
|
||||||
|
for (int i = 0; i < 100; ++i) {
|
||||||
|
TensorHandlePtr multiply_result(
|
||||||
|
Multiply(context.get(), combined_value.get(), combined_value.get(),
|
||||||
|
status.get()));
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
std::array<TensorHandlePtr, 2> out_components;
|
||||||
|
ExtractPerDeviceValues(context.get(), multiply_result.get(),
|
||||||
|
&out_components, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
|
ExpectScalarEq<float>(out_components[0].get(), 9.);
|
||||||
|
ExpectScalarEq<float>(out_components[1].get(), 4.);
|
||||||
|
}
|
||||||
|
|
||||||
|
worker_server1.release();
|
||||||
|
worker_server2.release();
|
||||||
|
}
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||||
#include "tensorflow/c/c_api_experimental.h"
|
#include "tensorflow/c/c_api_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
|
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
// NOTE(allenl): These tests currently go through TFE_Execute and so are
|
// NOTE(allenl): These tests currently go through TFE_Execute and so are
|
||||||
|
@ -28,390 +29,6 @@ limitations under the License.
|
||||||
// correspond fairly well to the implementation, but testing the C++ directly is
|
// correspond fairly well to the implementation, but testing the C++ directly is
|
||||||
// another option.
|
// another option.
|
||||||
|
|
||||||
// Functor for making unique_ptr to TFE_TensorHandle slightly more
|
|
||||||
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
|
|
||||||
// template argument requires passing a function pointer to
|
|
||||||
// TFE_DeleteTensorHandle when constructing the unique_ptr.
|
|
||||||
class TensorHandleDeleter {
|
|
||||||
public:
|
|
||||||
void operator()(TFE_TensorHandle* to_delete) {
|
|
||||||
TFE_DeleteTensorHandle(to_delete);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
|
|
||||||
|
|
||||||
// A helper for performing common operations on variables. A much more
|
|
||||||
// restricted stand-in for tf.Variable in Python.
|
|
||||||
class Variable {
|
|
||||||
public:
|
|
||||||
// Construct a Variable from a resource-dtype TFE_TensorHandle and an
|
|
||||||
// indication of the dtype of the variable's value.
|
|
||||||
//
|
|
||||||
// Note that creating this resource-dtype handle can fail, so `Create` is a
|
|
||||||
// separate static method which returns a status.
|
|
||||||
Variable(TFE_TensorHandle* handle, TF_DataType type)
|
|
||||||
: handle_(handle), type_(type) {}
|
|
||||||
|
|
||||||
// Helper for constructing a resource handle and wrapping it in a `Variable`
|
|
||||||
// object.
|
|
||||||
static Variable* Create(TFE_Context* context, TF_DataType type,
|
|
||||||
const int64_t* dims, const int num_dims,
|
|
||||||
const char* device, TF_Status* status);
|
|
||||||
// Dereferences the backing buffer for the variable. Note that since this can
|
|
||||||
// fail (it runs operations), it must be called explicitly and the resulting
|
|
||||||
// `status` checked.
|
|
||||||
void Destroy(TFE_Context* context, TF_Status* status);
|
|
||||||
|
|
||||||
// Reads from the variable.
|
|
||||||
TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
|
|
||||||
// Assigns a new value to the variable.
|
|
||||||
void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
|
|
||||||
// Adds `value` to the existing value of the variable.
|
|
||||||
void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
|
|
||||||
TF_Status* status);
|
|
||||||
|
|
||||||
private:
|
|
||||||
// Helper for running any single-argument assignment ops (Assign, AssignAdd,
|
|
||||||
// AssignSub, ...).
|
|
||||||
void GeneralAssignment(const char* op_name, TFE_Context* context,
|
|
||||||
TFE_TensorHandle* value, TF_Status* status);
|
|
||||||
|
|
||||||
// The a handle for the resource-dtype tensor pointing to the variable's
|
|
||||||
// buffer.
|
|
||||||
TFE_TensorHandle* handle_;
|
|
||||||
// The dtype of the variable's buffer (input dtype for assignments, output
|
|
||||||
// dtype of read operations).
|
|
||||||
TF_DataType type_;
|
|
||||||
};
|
|
||||||
|
|
||||||
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
|
|
||||||
const int64_t* dims, const int num_dims,
|
|
||||||
const char* device, TF_Status* status) {
|
|
||||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
|
||||||
TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
TFE_OpSetAttrType(op.get(), "dtype", type);
|
|
||||||
TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
|
|
||||||
TFE_OpSetAttrString(op.get(), "container", "", 0);
|
|
||||||
// Use the special GUID for no buffer sharing
|
|
||||||
//
|
|
||||||
// TODO(allenl): Should we provide a better API for this? AFAIK this is the
|
|
||||||
// only reasonable way to make variables with no aliasing using the eager C
|
|
||||||
// API.
|
|
||||||
std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
|
|
||||||
TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
|
|
||||||
no_sharing.length());
|
|
||||||
TFE_OpSetDevice(op.get(), device, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
TFE_TensorHandle* var_handle = nullptr;
|
|
||||||
int num_retvals = 1;
|
|
||||||
TFE_Execute(op.get(), &var_handle, &num_retvals, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
return new Variable(var_handle, type);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
|
|
||||||
// Free the backing buffer for the variable.
|
|
||||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
|
||||||
TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
TFE_OpAddInput(op.get(), handle_, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
TFE_OpSetDevice(op.get(), device, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
int num_retvals = 0;
|
|
||||||
TFE_Execute(op.get(), nullptr, &num_retvals, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
// Delete the variable handle itself.
|
|
||||||
TFE_DeleteTensorHandle(handle_);
|
|
||||||
}
|
|
||||||
|
|
||||||
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
|
|
||||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
|
||||||
TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
TFE_OpAddInput(op.get(), handle_, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
TFE_OpSetDevice(op.get(), device, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
TFE_OpSetAttrType(op.get(), "dtype", type_);
|
|
||||||
int num_retvals = 1;
|
|
||||||
TFE_TensorHandle* var_value = nullptr;
|
|
||||||
TFE_Execute(op.get(), &var_value, &num_retvals, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
return TensorHandlePtr(var_value);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
|
|
||||||
TFE_TensorHandle* value, TF_Status* status) {
|
|
||||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
|
||||||
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
TFE_OpSetAttrType(op.get(), "dtype", type_);
|
|
||||||
TFE_OpAddInput(op.get(), handle_, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
TFE_OpAddInput(op.get(), value, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
TFE_OpSetDevice(op.get(), device, status);
|
|
||||||
|
|
||||||
int num_retvals = 0;
|
|
||||||
TFE_Execute(op.get(), nullptr, &num_retvals, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
|
|
||||||
TF_Status* status) {
|
|
||||||
GeneralAssignment("AssignAddVariableOp", context, value, status);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
|
|
||||||
TF_Status* status) {
|
|
||||||
GeneralAssignment("AssignVariableOp", context, value, status);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Passed to `TF_NewTensor` to indicate how an array of floats should be
|
|
||||||
// deleted.
|
|
||||||
static void FloatDeallocator(void* data, size_t, void* arg) {
|
|
||||||
delete[] static_cast<float*>(data);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates a TFE_TensorHandle with value `v`.
|
|
||||||
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
|
|
||||||
const int num_bytes = sizeof(float);
|
|
||||||
float* values = new float[1];
|
|
||||||
values[0] = v;
|
|
||||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
|
||||||
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
|
|
||||||
nullptr),
|
|
||||||
TF_DeleteTensor);
|
|
||||||
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates a rank-one TFE_TensorHandle with value `v`.
|
|
||||||
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
|
|
||||||
TF_Status* status) {
|
|
||||||
const int num_bytes = v.size() * sizeof(float);
|
|
||||||
float* values = new float[v.size()];
|
|
||||||
memcpy(values, v.data(), num_bytes);
|
|
||||||
int64_t dims = v.size();
|
|
||||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
|
||||||
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
|
|
||||||
&FloatDeallocator, nullptr),
|
|
||||||
TF_DeleteTensor);
|
|
||||||
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
|
|
||||||
template <std::size_t num_replicas>
|
|
||||||
void ExtractPerDeviceValues(
|
|
||||||
TFE_Context* context, TFE_TensorHandle* input,
|
|
||||||
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
|
|
||||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
|
||||||
TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
|
|
||||||
TFE_OpAddInput(op.get(), input, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
const char* device = TFE_TensorHandleDeviceName(input, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
TFE_OpSetDevice(op.get(), device, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
|
|
||||||
TFE_TensorHandle* result_handles[num_replicas];
|
|
||||||
int num_retvals = num_replicas;
|
|
||||||
TFE_Execute(op.get(), result_handles, &num_retvals, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
|
||||||
for (int i = 0; i < num_replicas; ++i) {
|
|
||||||
(*components)[i].reset(result_handles[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
|
|
||||||
template <std::size_t num_replicas>
|
|
||||||
TensorHandlePtr CreatePerDeviceValues(
|
|
||||||
TFE_Context* context,
|
|
||||||
const std::array<TFE_TensorHandle*, num_replicas>& components,
|
|
||||||
const char* device, TF_Status* status) {
|
|
||||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
|
||||||
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
|
|
||||||
for (int i = 0; i < num_replicas; ++i) {
|
|
||||||
TFE_OpAddInput(op.get(), components[i], status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
}
|
|
||||||
TFE_OpSetDevice(op.get(), device, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
|
|
||||||
TFE_TensorHandle* result_handle;
|
|
||||||
int num_retvals = 1;
|
|
||||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
return TensorHandlePtr(result_handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
|
|
||||||
TFE_TensorHandle* second, TF_Status* status) {
|
|
||||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
|
||||||
TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
TFE_OpAddInput(op.get(), first, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
TFE_OpAddInput(op.get(), second, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
const char* first_device = TFE_TensorHandleDeviceName(first, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
TFE_OpSetDevice(op.get(), first_device, status);
|
|
||||||
|
|
||||||
TFE_TensorHandle* result_handle;
|
|
||||||
int num_retvals = 1;
|
|
||||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
|
||||||
return TensorHandlePtr(result_handle);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert that `handle` is equal to `expected_value`.
|
|
||||||
template <typename value_type>
|
|
||||||
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
|
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
|
||||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
|
|
||||||
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
|
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
|
||||||
EXPECT_EQ(expected_value,
|
|
||||||
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <std::size_t num_devices>
|
|
||||||
void RegisterParallelDevice(
|
|
||||||
TFE_Context* context, const char* device_name,
|
|
||||||
const std::array<const char*, num_devices>& underlying_devices,
|
|
||||||
TF_Status* status) {
|
|
||||||
TFE_CustomDevice device;
|
|
||||||
void* device_info;
|
|
||||||
tensorflow::eager::AllocateParallelDevice(
|
|
||||||
device_name, underlying_devices.data(), underlying_devices.size(),
|
|
||||||
&device, &device_info);
|
|
||||||
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create and modify a variable placed on a parallel device which composes
|
|
||||||
// `first_device` and `second_device`.
|
|
||||||
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
|
||||||
const char* second_device) {
|
|
||||||
// Register the custom device
|
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
|
||||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
|
||||||
std::array<const char*, 2> underlying_devices{first_device, second_device};
|
|
||||||
RegisterParallelDevice(context, device_name, underlying_devices,
|
|
||||||
status.get());
|
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
|
||||||
|
|
||||||
// Create a variable handle (uninitialized to start) placed on the parallel
|
|
||||||
// device.
|
|
||||||
std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
|
|
||||||
to_delete->Destroy(context, status.get());
|
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
|
||||||
delete to_delete;
|
|
||||||
};
|
|
||||||
std::unique_ptr<Variable, decltype(variable_deleter)> variable(
|
|
||||||
Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
|
|
||||||
status.get()),
|
|
||||||
variable_deleter);
|
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
|
||||||
|
|
||||||
// Assign an initial value to the variable, implicitly mirroring it to each
|
|
||||||
// component device.
|
|
||||||
{
|
|
||||||
TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
|
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
|
||||||
|
|
||||||
variable->Assign(context, initial_value.get(), status.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read from the variable and verify that we have a parallel tensor.
|
|
||||||
{
|
|
||||||
TensorHandlePtr read = variable->Read(context, status.get());
|
|
||||||
std::array<TensorHandlePtr, 2> components;
|
|
||||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
|
||||||
|
|
||||||
ExpectScalarEq<float>(components[0].get(), 20.);
|
|
||||||
ExpectScalarEq<float>(components[1].get(), 20.);
|
|
||||||
|
|
||||||
std::string first_device =
|
|
||||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
|
||||||
ASSERT_EQ(underlying_devices[0], first_device);
|
|
||||||
std::string second_device =
|
|
||||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
|
||||||
ASSERT_EQ(underlying_devices[1], second_device);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add a parallel tensor with different values on each device to the variable.
|
|
||||||
{
|
|
||||||
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
|
|
||||||
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
|
|
||||||
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
|
|
||||||
value_two.get()};
|
|
||||||
TensorHandlePtr combined_value =
|
|
||||||
CreatePerDeviceValues(context, components, device_name, status.get());
|
|
||||||
variable->AssignAdd(context, combined_value.get(), status.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read the variable and verify that each component has the right modified
|
|
||||||
// value.
|
|
||||||
{
|
|
||||||
TensorHandlePtr read = variable->Read(context, status.get());
|
|
||||||
std::array<TensorHandlePtr, 2> components;
|
|
||||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
|
||||||
|
|
||||||
ExpectScalarEq<float>(components[0].get(), 23.);
|
|
||||||
ExpectScalarEq<float>(components[1].get(), 18.);
|
|
||||||
|
|
||||||
std::string first_device =
|
|
||||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
|
||||||
ASSERT_EQ(underlying_devices[0], first_device);
|
|
||||||
std::string second_device =
|
|
||||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
|
||||||
ASSERT_EQ(underlying_devices[1], second_device);
|
|
||||||
}
|
|
||||||
// Compute the device ID twice and verify the result
|
|
||||||
for (int i = 0; i < 2; ++i) {
|
|
||||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
|
||||||
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
|
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
|
||||||
TFE_OpSetDevice(op.get(), device_name, status.get());
|
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
|
||||||
|
|
||||||
TFE_TensorHandle* result_handle;
|
|
||||||
int num_retvals = 1;
|
|
||||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
|
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
|
||||||
std::array<TensorHandlePtr, 2> components;
|
|
||||||
ExtractPerDeviceValues(context, result_handle, &components, status.get());
|
|
||||||
TFE_DeleteTensorHandle(result_handle);
|
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
|
||||||
|
|
||||||
ExpectScalarEq<int64_t>(components[0].get(), 0);
|
|
||||||
ExpectScalarEq<int64_t>(components[1].get(), 1);
|
|
||||||
std::string first_device =
|
|
||||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
|
||||||
ASSERT_EQ(underlying_devices[0], first_device);
|
|
||||||
std::string second_device =
|
|
||||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
|
||||||
ASSERT_EQ(underlying_devices[1], second_device);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(PARALLEL_DEVICE, TestBasicCPU) {
|
TEST(PARALLEL_DEVICE, TestBasicCPU) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
|
|
@ -0,0 +1,308 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/c_api_experimental.h"
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
// NOTE(allenl): These tests currently go through TFE_Execute and so are
|
||||||
|
// integration testing rather than purely testing the parallel device. They
|
||||||
|
// correspond fairly well to the implementation, but testing the C++ directly is
|
||||||
|
// another option.
|
||||||
|
|
||||||
|
|
||||||
|
// Creates an uninitialized resource variable of dtype `type` with shape
// `dims`/`num_dims`, placed on `device`. Returns a heap-allocated Variable
// wrapping the resource handle, or nullptr (with `status` set) on failure.
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
                           const int64_t* dims, const int num_dims,
                           const char* device, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrType(op.get(), "dtype", type);
  TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
  // TFE_OpSetAttrShape reports failures through `status`; check before going
  // on to execute the op (previously this error was silently ignored).
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrString(op.get(), "container", "", 0);
  // Use the special GUID for no buffer sharing
  //
  // TODO(allenl): Should we provide a better API for this? AFAIK this is the
  // only reasonable way to make variables with no aliasing using the eager C
  // API.
  std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
  TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
                      no_sharing.length());
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_TensorHandle* var_handle = nullptr;
  int num_retvals = 1;
  TFE_Execute(op.get(), &var_handle, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return new Variable(var_handle, type);
}
|
||||||
|
|
||||||
|
// Frees the variable's backing buffer by running DestroyResourceOp, then
// deletes the resource handle itself. Errors are reported through `status`;
// on failure the handle is intentionally left alive for the caller to inspect.
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> destroy_op(
      TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpAddInput(destroy_op.get(), handle_, status);
  if (TF_GetCode(status) != TF_OK) return;
  // Run the op on whatever device currently holds the resource.
  const char* handle_device = TFE_TensorHandleDeviceName(handle_, status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpSetDevice(destroy_op.get(), handle_device, status);
  if (TF_GetCode(status) != TF_OK) return;
  int num_retvals = 0;
  TFE_Execute(destroy_op.get(), /*retvals=*/nullptr, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return;
  // Only release the handle once the backing buffer is confirmed destroyed.
  TFE_DeleteTensorHandle(handle_);
}
|
||||||
|
|
||||||
|
// Reads the variable's current value by running ReadVariableOp on the
// variable's device. Returns nullptr (with `status` set) on failure.
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> read_op(
      TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpAddInput(read_op.get(), handle_, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  // Place the read on the device holding the resource handle.
  const char* handle_device = TFE_TensorHandleDeviceName(handle_, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetDevice(read_op.get(), handle_device, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrType(read_op.get(), "dtype", type_);
  TFE_TensorHandle* var_value = nullptr;
  int num_retvals = 1;
  TFE_Execute(read_op.get(), &var_value, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return TensorHandlePtr(var_value);
}
|
||||||
|
|
||||||
|
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
|
||||||
|
TFE_TensorHandle* value, TF_Status* status) {
|
||||||
|
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||||
|
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
|
||||||
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
TFE_OpSetAttrType(op.get(), "dtype", type_);
|
||||||
|
TFE_OpAddInput(op.get(), handle_, status);
|
||||||
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
TFE_OpAddInput(op.get(), value, status);
|
||||||
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||||
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
TFE_OpSetDevice(op.get(), device, status);
|
||||||
|
|
||||||
|
int num_retvals = 0;
|
||||||
|
TFE_Execute(op.get(), nullptr, &num_retvals, status);
|
||||||
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adds `value` to the variable's current value via AssignAddVariableOp.
// Errors are reported through `status`.
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
                         TF_Status* status) {
  GeneralAssignment("AssignAddVariableOp", context, value, status);
}
|
||||||
|
|
||||||
|
// Overwrites the variable's value with `value` via AssignVariableOp.
// Errors are reported through `status`.
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
                      TF_Status* status) {
  GeneralAssignment("AssignVariableOp", context, value, status);
}
|
||||||
|
|
||||||
|
// Passed to `TF_NewTensor` to indicate how an array of floats should be
// deleted. The unnamed size and `arg` parameters are part of the
// TF_NewTensor deallocator signature and are unused here.
static void FloatDeallocator(void* data, size_t, void* arg) {
  delete[] static_cast<float*>(data);
}
|
||||||
|
|
||||||
|
// Creates a TFE_TensorHandle with value `v`.
|
||||||
|
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
|
||||||
|
const int num_bytes = sizeof(float);
|
||||||
|
float* values = new float[1];
|
||||||
|
values[0] = v;
|
||||||
|
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||||
|
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
|
||||||
|
nullptr),
|
||||||
|
TF_DeleteTensor);
|
||||||
|
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a rank-one TFE_TensorHandle with value `v`.
|
||||||
|
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
|
||||||
|
TF_Status* status) {
|
||||||
|
const int num_bytes = v.size() * sizeof(float);
|
||||||
|
float* values = new float[v.size()];
|
||||||
|
memcpy(values, v.data(), num_bytes);
|
||||||
|
int64_t dims = v.size();
|
||||||
|
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||||
|
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
|
||||||
|
&FloatDeallocator, nullptr),
|
||||||
|
TF_DeleteTensor);
|
||||||
|
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
//
// NOTE(review): this template is defined in the .cc file, so only the
// instantiations used in this translation unit are emitted. Confirm that
// targets which include only the header get a definition before calling it
// with other `num_replicas` values.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
    TFE_Context* context, TFE_TensorHandle* input,
    std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
  // TPUReplicatedOutput splits one parallel tensor into its per-device parts.
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
  TFE_OpAddInput(op.get(), input, status);
  if (TF_GetCode(status) != TF_OK) return;
  // Place the op on the (parallel) device that produced `input`.
  const char* device = TFE_TensorHandleDeviceName(input, status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return;

  TFE_TensorHandle* result_handles[num_replicas];
  int num_retvals = num_replicas;
  TFE_Execute(op.get(), result_handles, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return;
  // Transfer ownership of each component handle to the caller's array.
  for (int i = 0; i < num_replicas; ++i) {
    (*components)[i].reset(result_handles[i]);
  }
}
|
||||||
|
|
||||||
|
// Computes `first * second` with the "Mul" op, placing the op on `first`'s
// device. Returns nullptr (with `status` set) on failure.
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
                         TFE_TensorHandle* second, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpAddInput(op.get(), first, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpAddInput(op.get(), second, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  const char* first_device = TFE_TensorHandleDeviceName(first, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetDevice(op.get(), first_device, status);
  // Check placement errors before executing; every other helper in this file
  // checks `status` after TFE_OpSetDevice, but this one previously did not.
  if (TF_GetCode(status) != TF_OK) return nullptr;

  TFE_TensorHandle* result_handle;
  int num_retvals = 1;
  TFE_Execute(op.get(), &result_handle, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return TensorHandlePtr(result_handle);
}
|
||||||
|
|
||||||
|
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
//
// Exercises the parallel device end-to-end through TFE_Execute: registers a
// parallel device over the two components, creates and mutates a variable on
// it, and verifies per-device values and placement. Uses gtest ASSERT/EXPECT
// macros, so it must be called from within a test body.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
                             const char* second_device) {
  // Register the custom device
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{first_device, second_device};
  RegisterParallelDevice(context, device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a variable handle (uninitialized to start) placed on the parallel
  // device.
  //
  // The deleter captures `status` by reference, so it must not outlive this
  // scope; it destroys the backing buffer before freeing the Variable object.
  std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
    to_delete->Destroy(context, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    delete to_delete;
  };
  std::unique_ptr<Variable, decltype(variable_deleter)> variable(
      Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
                       status.get()),
      variable_deleter);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Assign an initial value to the variable, implicitly mirroring it to each
  // component device.
  {
    TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    variable->Assign(context, initial_value.get(), status.get());
  }

  // Read from the variable and verify that we have a parallel tensor.
  {
    TensorHandlePtr read = variable->Read(context, status.get());
    std::array<TensorHandlePtr, 2> components;
    ExtractPerDeviceValues(context, read.get(), &components, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // Both components mirror the initial scalar.
    ExpectScalarEq<float>(components[0].get(), 20.);
    ExpectScalarEq<float>(components[1].get(), 20.);

    // Each component must live on its corresponding underlying device.
    std::string first_device =
        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
    ASSERT_EQ(underlying_devices[0], first_device);
    std::string second_device =
        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
    ASSERT_EQ(underlying_devices[1], second_device);
  }

  // Add a parallel tensor with different values on each device to the variable.
  {
    TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
    TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
    std::array<TFE_TensorHandle*, 2> components{value_one.get(),
                                                value_two.get()};
    TensorHandlePtr combined_value =
        CreatePerDeviceValues(context, components, device_name, status.get());
    variable->AssignAdd(context, combined_value.get(), status.get());
  }

  // Read the variable and verify that each component has the right modified
  // value.
  {
    TensorHandlePtr read = variable->Read(context, status.get());
    std::array<TensorHandlePtr, 2> components;
    ExtractPerDeviceValues(context, read.get(), &components, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // 20 + 3 and 20 + (-2): the per-device addends diverge the components.
    ExpectScalarEq<float>(components[0].get(), 23.);
    ExpectScalarEq<float>(components[1].get(), 18.);

    std::string first_device =
        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
    ASSERT_EQ(underlying_devices[0], first_device);
    std::string second_device =
        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
    ASSERT_EQ(underlying_devices[1], second_device);
  }
  // Compute the device ID twice and verify the result
  for (int i = 0; i < 2; ++i) {
    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
        TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    TFE_OpSetDevice(op.get(), device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    TFE_TensorHandle* result_handle;
    int num_retvals = 1;
    TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    std::array<TensorHandlePtr, 2> components;
    ExtractPerDeviceValues(context, result_handle, &components, status.get());
    TFE_DeleteTensorHandle(result_handle);
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // DeviceID reports each component's position in the parallel device.
    ExpectScalarEq<int64_t>(components[0].get(), 0);
    ExpectScalarEq<int64_t>(components[1].get(), 1);
    std::string first_device =
        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
    ASSERT_EQ(underlying_devices[0], first_device);
    std::string second_device =
        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
    ASSERT_EQ(underlying_devices[1], second_device);
  }
}
|
|
@ -0,0 +1,174 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
|
||||||
|
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/c_api_experimental.h"
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
|
||||||
|
// Functor for making unique_ptr to TFE_TensorHandle slightly more
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
// template argument requires passing a function pointer to
// TFE_DeleteTensorHandle when constructing the unique_ptr.
class TensorHandleDeleter {
 public:
  void operator()(TFE_TensorHandle* to_delete) {
    TFE_DeleteTensorHandle(to_delete);
  }
};

// Owning smart pointer for eager tensor handles.
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
|
||||||
|
|
||||||
|
// A helper for performing common operations on variables. A much more
// restricted stand-in for tf.Variable in Python.
class Variable {
 public:
  // Construct a Variable from a resource-dtype TFE_TensorHandle and an
  // indication of the dtype of the variable's value.
  //
  // Note that creating this resource-dtype handle can fail, so `Create` is a
  // separate static method which returns a status.
  Variable(TFE_TensorHandle* handle, TF_DataType type)
      : handle_(handle), type_(type) {}

  // Helper for constructing a resource handle and wrapping it in a `Variable`
  // object. Returns nullptr with `status` set on failure.
  static Variable* Create(TFE_Context* context, TF_DataType type,
                          const int64_t* dims, const int num_dims,
                          const char* device, TF_Status* status);
  // Dereferences the backing buffer for the variable. Note that since this can
  // fail (it runs operations), it must be called explicitly and the resulting
  // `status` checked.
  void Destroy(TFE_Context* context, TF_Status* status);

  // Reads from the variable.
  TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
  // Assigns a new value to the variable.
  void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
  // Adds `value` to the existing value of the variable.
  void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
                 TF_Status* status);

 private:
  // Helper for running any single-argument assignment ops (Assign, AssignAdd,
  // AssignSub, ...).
  void GeneralAssignment(const char* op_name, TFE_Context* context,
                         TFE_TensorHandle* value, TF_Status* status);

  // A handle for the resource-dtype tensor pointing to the variable's
  // buffer.
  TFE_TensorHandle* handle_;
  // The dtype of the variable's buffer (input dtype for assignments, output
  // dtype of read operations).
  TF_DataType type_;
};
|
||||||
|
|
||||||
|
// Creates a scalar TFE_TensorHandle with value `v`.
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status);

// Creates a rank-one TFE_TensorHandle with value `v`.
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
                                        TF_Status* status);

// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
    TFE_Context* context, TFE_TensorHandle* input,
    std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status);

// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle,
// placed on `device`.
template <std::size_t num_replicas>
TensorHandlePtr CreatePerDeviceValues(
    TFE_Context* context,
    const std::array<TFE_TensorHandle*, num_replicas>& components,
    const char* device, TF_Status* status);

// Computes `first * second`, placing the op on `first`'s device.
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
                         TFE_TensorHandle* second, TF_Status* status);

// Assert that `handle` is equal to `expected_value`. Uses gtest assertions, so
// it must be called from within a test body.
template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value);

// Registers a parallel device named `device_name` composing
// `underlying_devices` with `context`, reporting errors through `status`.
template <std::size_t num_devices>
void RegisterParallelDevice(
    TFE_Context* context, const char* device_name,
    const std::array<const char*, num_devices>& underlying_devices,
    TF_Status* status);

// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
                             const char* second_device);
|
||||||
|
|
||||||
|
// Implementations of templated functions ******************************
|
||||||
|
|
||||||
|
template <std::size_t num_replicas>
|
||||||
|
TensorHandlePtr CreatePerDeviceValues(
|
||||||
|
TFE_Context* context,
|
||||||
|
const std::array<TFE_TensorHandle*, num_replicas>& components,
|
||||||
|
const char* device, TF_Status* status) {
|
||||||
|
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||||
|
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
|
||||||
|
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||||
|
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
|
||||||
|
for (int i = 0; i < num_replicas; ++i) {
|
||||||
|
TFE_OpAddInput(op.get(), components[i], status);
|
||||||
|
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||||
|
}
|
||||||
|
TFE_OpSetDevice(op.get(), device, status);
|
||||||
|
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||||
|
|
||||||
|
TFE_TensorHandle* result_handle;
|
||||||
|
int num_retvals = 1;
|
||||||
|
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
|
||||||
|
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||||
|
return TensorHandlePtr(result_handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename value_type>
|
||||||
|
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
|
||||||
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
|
||||||
|
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
EXPECT_EQ(expected_value,
|
||||||
|
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <std::size_t num_devices>
|
||||||
|
void RegisterParallelDevice(
|
||||||
|
TFE_Context* context, const char* device_name,
|
||||||
|
const std::array<const char*, num_devices>& underlying_devices,
|
||||||
|
TF_Status* status) {
|
||||||
|
TFE_CustomDevice device;
|
||||||
|
void* device_info;
|
||||||
|
tensorflow::eager::AllocateParallelDevice(
|
||||||
|
device_name, underlying_devices.data(), underlying_devices.size(),
|
||||||
|
&device, &device_info);
|
||||||
|
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
|
Loading…
Reference in New Issue