Support running a function with packed input handles through C APIs.

Introduce a C API TFE_CreatePackedTensorHandle which creates a TFE_TensorHandle referring to multiple TFE_TensorHandles.

PiperOrigin-RevId: 310610230
Change-Id: Icc0ffd5c58ad7780eca38d552c1a2f4617f04891
This commit is contained in:
Yujing Zhang 2020-05-08 12:47:17 -07:00 committed by TensorFlower Gardener
parent 696f2a8bd7
commit 7e6ea21148
21 changed files with 586 additions and 70 deletions

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/lib/monitoring/counter.h"
@ -638,3 +639,21 @@ TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
return tensorflow::wrap(
tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
}
TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
TFE_TensorHandle** handles,
int* num_handles,
TF_Status* status) {
std::vector<tensorflow::TensorHandle*> tensor_handles;
tensor_handles.reserve(*num_handles);
for (int i = 0; i < *num_handles; ++i) {
tensor_handles.push_back(
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i])));
}
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::TensorHandle* handle = nullptr;
status->status = tensorflow::TensorHandle::CreatePackedHandle(
std::move(tensor_handles), context, &handle);
return tensorflow::wrap(handle);
}

View File

@ -541,6 +541,14 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
TFE_Context* ctx, TF_Tensor* t, TF_Status* status);
// Create a packed TensorHandle with the given list of TensorHandles.
// If `handles` are on the same device, assign the same device to the packed
// handle; if `handles` are on different deivces, assign a CompositeDevice to
// it.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -351,6 +351,192 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
/*heavy_load_on_streaming_rpc=*/true);
}
// Add the values of three variables on three different tasks.
string AddVariablesFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'AddVariablesFunction'"
" input_arg {"
" name: 'var'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'sum'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read1'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read2'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add1'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read1:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add2'"
" op: 'Add'"
" input: 'add1:z:0'"
" input: 'read2:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'sum'"
" value: 'add2:z:0'"
" }",
&def));
return def.SerializeAsString();
}
TEST(CAPI, TestFunctionWithPackedInput) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
// Create one variable per task.
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
// Pack 3 variable handles into one TFE_TensorHandle.
int num_replicas = 3;
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
TFE_TensorHandle* packed_handle =
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
const string composite_device_name =
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
composite_device_name);
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
composite_device_name);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// Register and run a function which returns the sum of 3 variables.
const string function_def = AddVariablesFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, packed_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(packed_handle);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(sum, 6.0);
TFE_DeleteTensorHandle(h0);
TFE_DeleteTensorHandle(h1);
TFE_DeleteTensorHandle(h2);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);

View File

@ -1132,51 +1132,6 @@ void BM_ExecuteFunction(int iters, int async) {
}
BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1);
TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
TF_Status* status) {
// Create the variable handle.
TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op, "shape", {}, 0, status);
TFE_OpSetAttrString(op, "container", "", 0);
TFE_OpSetAttrString(op, "shared_name", "", 0);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
TFE_Execute(op, &var_handle, &num_retvals, status);
TFE_DeleteOp(op);
if (TF_GetCode(status) != TF_OK) return nullptr;
CHECK_EQ(1, num_retvals);
// Assign 'value' to it.
op = TFE_NewOp(ctx, "AssignVariableOp", status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var_handle, status);
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
value_handle(TFE_NewTensorHandle(t.get(), status),
TFE_DeleteTensorHandle);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op, value_handle.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
num_retvals = 0;
TFE_Execute(op, nullptr, &num_retvals, status);
TFE_DeleteOp(op);
if (TF_GetCode(status) != TF_OK) return nullptr;
CHECK_EQ(0, num_retvals);
return var_handle;
}
TEST(CAPI, Variables) {
// Variables use resource handles, so this is really a test for resource
// tensor handling.
@ -1186,7 +1141,7 @@ TEST(CAPI, Variables) {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
TFE_TensorHandle* var_handle = TestVariable(ctx, 12.0);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
@ -1227,7 +1182,7 @@ void BM_ReadVariable(int iters) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);

View File

@ -133,6 +133,57 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) {
return th;
}
TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
const tensorflow::string& device_name) {
TF_Status* status = TF_NewStatus();
// Create the variable handle.
TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op, "shape", {}, 0, status);
TFE_OpSetAttrString(op, "container", "", 0);
TFE_OpSetAttrString(op, "shared_name", "", 0);
if (!device_name.empty()) {
TFE_OpSetDevice(op, device_name.c_str(), status);
}
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
TFE_Execute(op, &var_handle, &num_retvals, status);
TFE_DeleteOp(op);
if (TF_GetCode(status) != TF_OK) return nullptr;
CHECK_EQ(1, num_retvals);
// Assign 'value' to it.
op = TFE_NewOp(ctx, "AssignVariableOp", status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var_handle, status);
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
value_handle(TFE_NewTensorHandle(t.get(), status),
TFE_DeleteTensorHandle);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op, value_handle.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
num_retvals = 0;
TFE_Execute(op, nullptr, &num_retvals, status);
TFE_DeleteOp(op);
if (TF_GetCode(status) != TF_OK) return nullptr;
CHECK_EQ(0, num_retvals);
TF_DeleteStatus(status);
return var_handle;
}
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();

View File

@ -42,6 +42,11 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx);
// Return a tensor handle containing a 3x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx);
// Return a variable handle referring to a variable with the given initial value
// on the given device.
TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
const tensorflow::string& device_name = "");
// Return an add op multiplying `a` by `b`.
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);

View File

@ -305,6 +305,7 @@ tf_cuda_library(
visibility = ["//tensorflow:internal"],
deps = [
":attr_builder",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
@ -369,6 +370,7 @@ cc_library(
":eager_operation",
":kernel_and_device",
":tensor_handle",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/strings",
@ -396,6 +398,24 @@ cc_library(
}) + if_mkl([":mkl_eager_op_rewrite"]),
)
tf_cc_test(
name = "execute_node_test",
srcs = ["execute_node_test.cc"],
deps = [
":context",
":core",
":execute",
":kernel_and_device",
":tensor_handle",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/memory",
],
)
cc_library(
name = "mkl_eager_op_rewrite",
srcs = ["mkl_eager_op_rewrite.cc"],
@ -466,6 +486,7 @@ cc_library(
":eager_operation",
":kernel_and_device",
":tensor_handle",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/strings",

View File

@ -853,6 +853,18 @@ Status EagerContext::FindDeviceFromName(const char* device_name,
return status;
}
Status EagerContext::FindCompositeDeviceFromName(
const char* device_name, CompositeDevice** device) const {
tf_shared_lock l(composite_devices_mu_);
for (const auto& d : composite_devices_) {
if (d.second->name() == device_name) {
*device = d.second.get();
return Status::OK();
}
}
return errors::NotFound("Unknown composite device: ", device_name);
}
Status EagerContext::FindCustomDeviceFromName(const string& device_name,
CustomDevice** dev) const {
auto dev_it = custom_devices_.find(device_name);
@ -904,8 +916,7 @@ Status EagerContext::FindOrCreateCompositeDevice(
composite_devices_.size(), &s);
TF_RETURN_IF_ERROR(s);
*composite_device = device.get();
// TODO(b/145922293): Add the composite device to the device set of pflr in
// order to make placer recognize it.
pflr_->AddCompositeDevice(*composite_device);
composite_devices_.emplace(hash_key, std::move(device));
return Status::OK();
}

View File

@ -483,6 +483,9 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
Status FindDeviceFromName(const char* device_name, Device** device) const;
Status FindCompositeDeviceFromName(const char* device_name,
CompositeDevice** device) const;
Status FindCustomDeviceFromName(const string& device_name,
CustomDevice** dev) const;

View File

@ -180,6 +180,10 @@ TEST_F(EagerContextTest, CompositeDevice) {
&composite_device_0));
EXPECT_EQ(composite_device_0->name(),
"/job:worker/replica:0/task:0/device:COMPOSITE:0");
CompositeDevice* device = nullptr;
TF_EXPECT_OK(context()->FindCompositeDeviceFromName(
"/job:worker/replica:0/task:0/device:COMPOSITE:0", &device));
EXPECT_EQ(device, composite_device_0);
CompositeDevice* composite_device_1 = nullptr;
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
&composite_device_1));
@ -190,6 +194,12 @@ TEST_F(EagerContextTest, CompositeDevice) {
&composite_device_2));
EXPECT_EQ(composite_device_2->name(),
"/job:worker/replica:0/task:0/device:COMPOSITE:1");
TF_EXPECT_OK(context()->FindCompositeDeviceFromName(
"/job:worker/replica:0/task:0/device:COMPOSITE:1", &device));
EXPECT_EQ(device, composite_device_2);
EXPECT_TRUE(errors::IsNotFound(context()->FindCompositeDeviceFromName(
"/job:worker/replica:0/task:0/device:COMPOSITE:2", &device)));
}
} // namespace

View File

@ -367,6 +367,7 @@ Status GetOrCreateKernelAndDevice(
Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
std::vector<Device*> input_dev_ptrs;
absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
std::unordered_map<int, DtypeAndPartialTensorShape>
input_resource_variable_dtypes_and_shapes;
// We can eliminate some overhead by running simple functions using regular
@ -410,6 +411,13 @@ Status GetOrCreateKernelAndDevice(
Device* input_device;
TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device));
input_dev_ptrs.push_back(input_device);
CompositeDevice* composite_device = nullptr;
if (ctx.FindCompositeDeviceFromName(input_device->name().c_str(),
&composite_device)
.ok()) {
composite_devices[input_device->name()] =
composite_device->underlying_devices();
}
cache_key =
FingerprintCat128(cache_key, Fingerprint128(input_device->name()));
@ -520,6 +528,7 @@ Status GetOrCreateKernelAndDevice(
#endif // IS_MOBILE_PLATFORM
kernel.reset(new KernelAndDeviceFunc(
flr, ctx.pflr(), std::move(input_dev_ptrs),
std::move(composite_devices),
std::move(input_resource_variable_dtypes_and_shapes), runner,
ctx.GetCollectiveExecutorHandle(), ctx.HostCPU(), op->Name(),
[&ctx](const int64 step_id) { return ctx.CreateRendezvous(step_id); },

View File

@ -17,6 +17,51 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
#if !defined(IS_MOBILE_PLATFORM)
bool ExecuteNodeArgs::IsRemote(EagerContext* ctx, Device* input_device,
TensorHandle* handle) {
uint64 context_view_id = ctx->GetContextViewId();
if (handle->Type() == TensorHandle::REMOTE ||
handle->HasRemoteMirror(input_device, context_view_id)) {
if (!has_remote_inputs_) {
has_remote_inputs_ = true;
}
return true;
}
return false;
}
#endif // IS_MOBILE_PLATFORM
Status ExecuteNodeArgs::InitPackedHandle(const int index, EagerContext* ctx,
Device* input_device,
TensorHandle* packed_handle) {
int num_handles = packed_handle->NumPackedHandles();
packed_args_.emplace(index, gtl::InlinedVector<TensorValue, 4>(num_handles));
TensorValue* packed_arg_flat = &(packed_args_[index][0]);
for (int i = 0; i < num_handles; ++i) {
TensorHandle* h = nullptr;
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
// We have validated that h->device() is not a CustomDevice when
// constructing a pack TensorHandle.
const Status status =
h->TensorValue(absl::get<Device*>(h->device()), &packed_arg_flat[i]);
if (!status.ok()) {
#if !defined(IS_MOBILE_PLATFORM)
if (IsRemote(ctx, input_device, h)) {
continue;
}
#endif // IS_MOBILE_PLATFORM
if (h->Type() == TensorHandle::PACKED) {
return errors::InvalidArgument(
"Nested packed handles are not supported");
}
return status;
}
}
return Status::OK();
}
Status ExecuteNodeArgs::Init(
EagerContext* ctx, const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
const core::RefCountPtr<KernelAndDevice>& kernel) {
@ -35,16 +80,17 @@ Status ExecuteNodeArgs::Init(
Status s = in->TensorValue(ctx->CanonicalDevice(d), &tensor_args_flat[i]);
if (!s.ok()) {
#if !defined(IS_MOBILE_PLATFORM)
uint64 context_view_id = ctx->GetContextViewId();
if (in->Type() == TensorHandle::REMOTE ||
in->HasRemoteMirror(d, context_view_id)) {
if (!has_remote_inputs_) {
has_remote_inputs_ = true;
}
if (IsRemote(ctx, d, in)) {
continue;
}
#endif
return s;
if (in->Type() != TensorHandle::PACKED) {
return s;
}
if (!has_packed_inputs_) {
has_packed_inputs_ = true;
}
TF_RETURN_IF_ERROR(InitPackedHandle(i, ctx, d, in));
}
}
}
@ -54,24 +100,44 @@ Status ExecuteNodeArgs::Init(
serialize_remote_handle_ =
[ctx, &op_inputs](const FunctionArgIndex& index,
eager::RemoteTensorHandle* handle) -> Status {
if (index.sub_index >= 0) {
return errors::InvalidArgument("Got unexpected sub_index ",
index.sub_index, " for argument ",
index.index);
TensorHandle* h = op_inputs[index.index];
if (op_inputs[index.index]->Type() == TensorHandle::PACKED) {
TF_RETURN_IF_ERROR(
op_inputs[index.index]->ExtractPackedHandle(index.sub_index, &h));
}
VariantDevice variant_device = op_inputs[index.index]->device();
VariantDevice variant_device = h->device();
if (VariantDeviceIsCustom(variant_device)) {
return errors::Internal(
"Custom devices and remote execution are currently not supported "
"together.");
}
Device* device = absl::get<Device*>(variant_device);
return ctx->RemoteMgr()->SerializeRemoteTensorHandle(
op_inputs[index.index], handle, device, device->name());
return ctx->RemoteMgr()->SerializeRemoteTensorHandle(h, handle, device,
device->name());
};
}
#endif // !IS_MOBILE_PLATFORM
return Status::OK();
}
Status ExecuteNodeArgs::GetLocalArg(const FunctionArgIndex& index,
Tensor* val) const {
Status s = EagerKernelArgs::GetLocalArg(index, val);
if (s.ok()) {
return Status::OK();
}
if (packed_args_.contains(index.index)) {
Tensor* arg = packed_args_.at(index.index).at(index.sub_index).tensor;
if (arg) {
*val = *arg;
return Status::OK();
} else {
return errors::NotFound("Argument (", index.index, ",", index.sub_index,
") has no local tensor.");
}
} else {
return s;
}
}
} // namespace tensorflow

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <cstddef>
#include <memory>
#include <string>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on
@ -54,6 +55,8 @@ class ExecuteNodeArgs : public EagerKernelArgs {
const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
const core::RefCountPtr<KernelAndDevice>& kernel);
Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override;
bool HasRemoteOrPackedInputs() const override {
return has_remote_inputs_ || has_packed_inputs_;
};
@ -66,8 +69,20 @@ class ExecuteNodeArgs : public EagerKernelArgs {
#endif // IS_MOBILE_PLATFORM
private:
#if !defined(IS_MOBILE_PLATFORM)
// Returns whether `handle` is a remote handle or has a remote mirror on
// `input_device`
bool IsRemote(EagerContext* ctx, Device* input_device, TensorHandle* handle);
#endif // IS_MOBILE_PLATFORM
// Initialize a packed TensorHandle which is the `index`-th argument.
Status InitPackedHandle(const int index, EagerContext* ctx,
Device* input_device, TensorHandle* packed_handle);
bool has_remote_inputs_ = false;
bool has_packed_inputs_ = false;
// Maps from the index of a packed arg to a list of sub-args.
absl::flat_hash_map<int, gtl::InlinedVector<TensorValue, 4>> packed_args_;
#if !defined(IS_MOBILE_PLATFORM)
std::function<Status(const FunctionArgIndex&, eager::RemoteTensorHandle*)>
serialize_remote_handle_;

View File

@ -0,0 +1,126 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/execute_node.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class TestKernelAndDeviceFunc final : public KernelAndDeviceFunc {
public:
TestKernelAndDeviceFunc(std::vector<Device*> input_devices,
Device* host_cpu_device)
: KernelAndDeviceFunc(
/*flr=*/nullptr, /*pflr=*/nullptr, /*input_devices=*/{},
/*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
/*runner=*/nullptr, /*collective_executor=*/nullptr,
host_cpu_device, /*name=*/"",
/*rendezvous_creator=*/nullptr, /*get_op_id=*/nullptr),
test_input_devices_(std::move(input_devices)) {}
Device* InputDevice(int i) const override { return test_input_devices_[i]; }
private:
std::vector<Device*> test_input_devices_;
};
TEST(ExecuteNodeTest, ExecuteNodeArgs) {
StaticDeviceMgr device_mgr(
DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
Device* device0 = device_mgr.ListDevices().at(0);
StaticDeviceMgr remote_device_mgr(
DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1"));
Device* device1 = remote_device_mgr.ListDevices().at(0);
Status s;
std::unique_ptr<CompositeDevice> composite_device =
CompositeDevice::MakeDevice({device0->name(), device1->name()},
/*unique_device_id=*/0, &s);
TF_ASSERT_OK(s);
auto ctx = new EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
&device_mgr, false, nullptr, nullptr, nullptr);
DataType dtype = DT_FLOAT;
Tensor t0(dtype, TensorShape({}));
// Create two local TensorHandles
t0.scalar<float>()() = {1.0f};
TensorHandle* h0 =
TensorHandle::CreateLocalHandle(std::move(t0), device0, device0, ctx);
Tensor t1(dtype, TensorShape({}));
t1.scalar<float>()() = {2.0f};
TensorHandle* h1 =
TensorHandle::CreateLocalHandle(std::move(t1), device0, device0, ctx);
// Create two remote TensorHandles
TensorHandle* h2 = TensorHandle::CreateLazyRemoteHandle(
/*op_id=*/1, /*output_num=*/0, dtype, device1, ctx);
TensorHandle* h3 = TensorHandle::CreateLazyRemoteHandle(
/*op_id=*/2, /*output_num=*/1, dtype, device1, ctx);
// Create a packed TensorHandle
TensorHandle* packed_h = nullptr;
TF_ASSERT_OK(TensorHandle::CreatePackedHandle({h1, h2}, ctx, &packed_h));
// LOCAL, PACKED, REMOTE
absl::InlinedVector<TensorHandle*, 4> inputs = {h0, packed_h, h3};
std::vector<Device*> input_devices;
for (auto* h : inputs) {
input_devices.push_back(absl::get<Device*>(h->DeviceOrHostCPU(*ctx)));
}
const core::RefCountPtr<KernelAndDevice> kernel(
new TestKernelAndDeviceFunc(std::move(input_devices), device0));
ExecuteNodeArgs args(inputs.size());
TF_EXPECT_OK(args.Init(ctx, inputs, kernel));
EXPECT_TRUE(args.HasRemoteOrPackedInputs());
Tensor local0;
TF_EXPECT_OK(args.GetLocalArg(FunctionArgIndex(0), &local0));
EXPECT_EQ(local0.flat<float>().size(), 1);
EXPECT_EQ(local0.flat<float>()(0), 1.0);
Tensor local1;
TF_EXPECT_OK(args.GetLocalArg(FunctionArgIndex(1, 0), &local1));
EXPECT_EQ(local1.flat<float>().size(), 1);
EXPECT_EQ(local1.flat<float>()(0), 2.0);
eager::RemoteTensorHandle remote0;
TF_EXPECT_OK(args.GetRemoteArg(FunctionArgIndex(1, 1), &remote0));
EXPECT_EQ(remote0.op_id(), 1);
EXPECT_EQ(remote0.output_num(), 0);
eager::RemoteTensorHandle remote1;
TF_EXPECT_OK(args.GetRemoteArg(FunctionArgIndex(2), &remote1));
EXPECT_EQ(remote1.op_id(), 2);
EXPECT_EQ(remote1.output_num(), 1);
h0->Unref();
h1->Unref();
h2->Unref();
h3->Unref();
packed_h->Unref();
ctx->Unref();
}
} // namespace
} // namespace tensorflow

View File

@ -158,6 +158,7 @@ Status KernelAndDeviceFunc::InstantiateFunc(const NodeDef& ndef,
for (const Device* device : input_devices_) {
options.input_devices.push_back(device->name());
}
options.composite_devices = composite_devices_;
options.input_resource_dtypes_and_shapes = input_resource_dtypes_and_shapes_;
const auto& it = ndef.attr().find("executor_type");
@ -425,7 +426,9 @@ Device* KernelAndDeviceOp::InputDevice(int i) const {
}
Device* KernelAndDeviceFunc::InputDevice(int i) const {
if (input_dtypes_[i] == DT_RESOURCE) {
if ((input_dtypes_[i] == DT_RESOURCE) &&
(composite_devices_.find(input_devices_[i]->name()) ==
composite_devices_.end())) {
return host_cpu_device_;
} else {
return input_devices_[i];

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
@ -241,7 +242,7 @@ class KernelAndDeviceOp final : public KernelAndDevice {
// Represents a multi-device function. Functions can also be run using
// various function-calling kernels including CallOp and PartitionedCallOp.
// In such cases, KernelAndDeviceOp is used.
class KernelAndDeviceFunc final : public KernelAndDevice {
class KernelAndDeviceFunc : public KernelAndDevice {
public:
// `flr` can be nullptr.
// `pflr` must not be nullptr.
@ -249,6 +250,7 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
KernelAndDeviceFunc(
FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr,
std::vector<Device*> input_devices,
absl::flat_hash_map<string, const std::vector<string>*> composite_devices,
std::unordered_map<int, DtypeAndPartialTensorShape>
input_resource_dtypes_and_shapes,
std::function<void(std::function<void()>)>* runner,
@ -261,6 +263,7 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
pflr_(pflr),
handle_(kInvalidHandle),
input_devices_(std::move(input_devices)),
composite_devices_(std::move(composite_devices)),
input_resource_dtypes_and_shapes_(
std::move(input_resource_dtypes_and_shapes)),
name_(name),
@ -320,6 +323,8 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
// CPU devices are not null. Resource handles' devices are actual backing
// devices.
std::vector<Device*> input_devices_;
// Maps from a CompositeDevice name to a list of physical device names.
absl::flat_hash_map<string, const std::vector<string>*> composite_devices_;
std::unordered_map<int, DtypeAndPartialTensorShape>
input_resource_dtypes_and_shapes_;

View File

@ -124,6 +124,10 @@ string TensorHandle::PackedTensorHandleData::DebugString() const {
return debug_str;
}
int TensorHandle::PackedTensorHandleData::NumPackedHandles() const {
return handles_.size();
}
Status TensorHandle::PackedTensorHandleData::ExtractPackedHandle(
const int index, TensorHandle** handle) const {
if (index < 0 || index >= handles_.size()) {
@ -185,6 +189,13 @@ Status TensorHandle::GetResourceAllowedDevices(std::vector<string>* result) {
return GetResourceHandleInfoImpl(get_resource_info);
}
int TensorHandle::NumPackedHandles() const {
if (Type() != PACKED) {
return 0;
}
return absl::get<PackedTensorHandleData>(data_).NumPackedHandles();
}
Status TensorHandle::ExtractPackedHandle(const int index,
TensorHandle** handle) const {
if (Type() != PACKED) {
@ -315,8 +326,8 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
return errors::InvalidArgument(
"CustomDevice is not supported for packing.");
} else {
devices.push_back(
absl::get<Device*>(handle->DeviceOrHostCPU(*ctx))->name());
devices.push_back(handle->op_device() ? handle->op_device()->name()
: ctx->HostCPU()->name());
}
}

View File

@ -231,6 +231,8 @@ class TensorHandle : public AbstractTensorHandleInterface,
std::vector<DtypeAndPartialTensorShape>* result);
Status GetResourceAllowedDevices(std::vector<string>* result);
// Returns the number of packed handles. 0 if the handle type is not PACKED.
int NumPackedHandles() const;
// It's called on a packed TensorHandle. Extract a handle with the given
// index.
Status ExtractPackedHandle(const int index, TensorHandle** handle) const;
@ -316,6 +318,8 @@ class TensorHandle : public AbstractTensorHandleInterface,
void Poison(Status status);
string DebugString() const;
// Number of packed handles.
int NumPackedHandles() const;
// Extract a handle on the given index.
Status ExtractPackedHandle(const int index, TensorHandle** handle) const;

View File

@ -164,6 +164,7 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
h2->Unref();
h3->Unref();
EXPECT_EQ(packed_handle->NumPackedHandles(), 4);
EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
EXPECT_EQ(packed_handle->dtype, dtype);
TensorShape packed_shape;
@ -185,7 +186,7 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
const std::vector<TensorHandle::HandleType> expected_handle_types = {
TensorHandle::LOCAL, TensorHandle::LOCAL, TensorHandle::REMOTE,
TensorHandle::REMOTE};
for (int i = 0; i < 4; ++i) {
for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
TensorHandle* h = nullptr;
TF_ASSERT_OK(packed_handle->ExtractPackedHandle(i, &h));
EXPECT_EQ(absl::get<Device*>(h->device()), ListDevices().at(i));

View File

@ -195,6 +195,9 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
for (Node* n : graph->op_nodes()) {
if (composite_device_names.find(n->assigned_device_name()) !=
composite_device_names.end()) {
// TODO(b/145922293): Validate that an _Arg node assigned to a
// CompositeDevice should have an attribute indicating that the _Arg node
// represents a packed input.
composite_device_to_cluster_nodes[n->assigned_device_name()].push_back(n);
}
}

View File

@ -728,7 +728,9 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) {
core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
const int64 op_id = 2;
kernel.reset(new KernelAndDeviceFunc(
flr, eager_pflr_.get(), std::move(input_dev_ptrs), {}, /*runner=*/nullptr,
flr, eager_pflr_.get(), std::move(input_dev_ptrs),
/*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
/*runner=*/nullptr,
/*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
[ctx](const int64 step_id) { return ctx->CreateRendezvous(step_id); },
[=]() { return op_id; }));
@ -773,7 +775,9 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) {
core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
const int64 op_id = 2;
kernel.reset(new KernelAndDeviceFunc(
flr, eager_pflr_.get(), std::move(input_dev_ptrs), {}, /*runner=*/nullptr,
flr, eager_pflr_.get(), std::move(input_dev_ptrs),
/*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
/*runner=*/nullptr,
/*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
[ctx](const int64 step_id) { return ctx->CreateRendezvous(step_id); },
[=]() { return op_id; }));