Support running a function with packed input handles through C APIs.
Introduce a C API TFE_CreatePackedTensorHandle which creates a TFE_TensorHandle referring to multiple TFE_TensorHandles. PiperOrigin-RevId: 310610230 Change-Id: Icc0ffd5c58ad7780eca38d552c1a2f4617f04891
This commit is contained in:
parent
696f2a8bd7
commit
7e6ea21148
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
@ -638,3 +639,21 @@ TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
|
||||
return tensorflow::wrap(
|
||||
tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
|
||||
TFE_TensorHandle** handles,
|
||||
int* num_handles,
|
||||
TF_Status* status) {
|
||||
std::vector<tensorflow::TensorHandle*> tensor_handles;
|
||||
tensor_handles.reserve(*num_handles);
|
||||
for (int i = 0; i < *num_handles; ++i) {
|
||||
tensor_handles.push_back(
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i])));
|
||||
}
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
status->status = tensorflow::TensorHandle::CreatePackedHandle(
|
||||
std::move(tensor_handles), context, &handle);
|
||||
return tensorflow::wrap(handle);
|
||||
}
|
||||
|
@ -541,6 +541,14 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
|
||||
TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
|
||||
TFE_Context* ctx, TF_Tensor* t, TF_Status* status);
|
||||
|
||||
// Create a packed TensorHandle with the given list of TensorHandles.
|
||||
// If `handles` are on the same device, assign the same device to the packed
|
||||
// handle; if `handles` are on different deivces, assign a CompositeDevice to
|
||||
// it.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
|
||||
TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
|
||||
TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -351,6 +351,192 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
|
||||
/*heavy_load_on_streaming_rpc=*/true);
|
||||
}
|
||||
|
||||
// Add the values of three variables on three different tasks.
|
||||
string AddVariablesFunction() {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
" signature {"
|
||||
" name: 'AddVariablesFunction'"
|
||||
" input_arg {"
|
||||
" name: 'var'"
|
||||
" type: DT_RESOURCE"
|
||||
" }"
|
||||
" output_arg {"
|
||||
" name: 'sum'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read0'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read1'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read2'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add1'"
|
||||
" op: 'Add'"
|
||||
" input: 'read0:value:0'"
|
||||
" input: 'read1:value:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add2'"
|
||||
" op: 'Add'"
|
||||
" input: 'add1:z:0'"
|
||||
" input: 'read2:value:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" ret {"
|
||||
" key: 'sum'"
|
||||
" value: 'add2:z:0'"
|
||||
" }",
|
||||
&def));
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
TEST(CAPI, TestFunctionWithPackedInput) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
// Create one variable per task.
|
||||
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
|
||||
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
|
||||
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
|
||||
|
||||
// Pack 3 variable handles into one TFE_TensorHandle.
|
||||
int num_replicas = 3;
|
||||
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
|
||||
TFE_TensorHandle* packed_handle =
|
||||
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
|
||||
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
|
||||
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
|
||||
|
||||
const string composite_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
|
||||
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
|
||||
composite_device_name);
|
||||
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
|
||||
composite_device_name);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// Register and run a function which returns the sum of 3 variables.
|
||||
const string function_def = AddVariablesFunction();
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(func, packed_handle, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(func, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
TFE_DeleteOp(func);
|
||||
TFE_DeleteTensorHandle(packed_handle);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
float sum = 0;
|
||||
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
|
||||
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
EXPECT_EQ(sum, 6.0);
|
||||
|
||||
TFE_DeleteTensorHandle(h0);
|
||||
TFE_DeleteTensorHandle(h1);
|
||||
TFE_DeleteTensorHandle(h2);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
|
@ -1132,51 +1132,6 @@ void BM_ExecuteFunction(int iters, int async) {
|
||||
}
|
||||
BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1);
|
||||
|
||||
TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
|
||||
TF_Status* status) {
|
||||
// Create the variable handle.
|
||||
TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(op, "shape", {}, 0, status);
|
||||
TFE_OpSetAttrString(op, "container", "", 0);
|
||||
TFE_OpSetAttrString(op, "shared_name", "", 0);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &var_handle, &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(1, num_retvals);
|
||||
|
||||
// Assign 'value' to it.
|
||||
op = TFE_NewOp(ctx, "AssignVariableOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
|
||||
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
|
||||
TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
|
||||
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
|
||||
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
value_handle(TFE_NewTensorHandle(t.get(), status),
|
||||
TFE_DeleteTensorHandle);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_OpAddInput(op, value_handle.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(0, num_retvals);
|
||||
|
||||
return var_handle;
|
||||
}
|
||||
|
||||
TEST(CAPI, Variables) {
|
||||
// Variables use resource handles, so this is really a test for resource
|
||||
// tensor handling.
|
||||
@ -1186,7 +1141,7 @@ TEST(CAPI, Variables) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
|
||||
TFE_TensorHandle* var_handle = TestVariable(ctx, 12.0);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
|
||||
@ -1227,7 +1182,7 @@ void BM_ReadVariable(int iters) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
|
||||
TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
|
||||
|
@ -133,6 +133,57 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) {
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
|
||||
const tensorflow::string& device_name) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
// Create the variable handle.
|
||||
TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(op, "shape", {}, 0, status);
|
||||
TFE_OpSetAttrString(op, "container", "", 0);
|
||||
TFE_OpSetAttrString(op, "shared_name", "", 0);
|
||||
if (!device_name.empty()) {
|
||||
TFE_OpSetDevice(op, device_name.c_str(), status);
|
||||
}
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &var_handle, &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(1, num_retvals);
|
||||
|
||||
// Assign 'value' to it.
|
||||
op = TFE_NewOp(ctx, "AssignVariableOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
|
||||
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
|
||||
TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
|
||||
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
|
||||
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
value_handle(TFE_NewTensorHandle(t.get(), status),
|
||||
TFE_DeleteTensorHandle);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_OpAddInput(op, value_handle.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(0, num_retvals);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
return var_handle;
|
||||
}
|
||||
|
||||
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
|
@ -42,6 +42,11 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx);
|
||||
// Return a tensor handle containing a 3x2 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx);
|
||||
|
||||
// Return a variable handle referring to a variable with the given initial value
|
||||
// on the given device.
|
||||
TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
|
||||
const tensorflow::string& device_name = "");
|
||||
|
||||
// Return an add op multiplying `a` by `b`.
|
||||
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
||||
|
||||
|
@ -305,6 +305,7 @@ tf_cuda_library(
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":attr_builder",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
@ -369,6 +370,7 @@ cc_library(
|
||||
":eager_operation",
|
||||
":kernel_and_device",
|
||||
":tensor_handle",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -396,6 +398,24 @@ cc_library(
|
||||
}) + if_mkl([":mkl_eager_op_rewrite"]),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "execute_node_test",
|
||||
srcs = ["execute_node_test.cc"],
|
||||
deps = [
|
||||
":context",
|
||||
":core",
|
||||
":execute",
|
||||
":kernel_and_device",
|
||||
":tensor_handle",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mkl_eager_op_rewrite",
|
||||
srcs = ["mkl_eager_op_rewrite.cc"],
|
||||
@ -466,6 +486,7 @@ cc_library(
|
||||
":eager_operation",
|
||||
":kernel_and_device",
|
||||
":tensor_handle",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -853,6 +853,18 @@ Status EagerContext::FindDeviceFromName(const char* device_name,
|
||||
return status;
|
||||
}
|
||||
|
||||
Status EagerContext::FindCompositeDeviceFromName(
|
||||
const char* device_name, CompositeDevice** device) const {
|
||||
tf_shared_lock l(composite_devices_mu_);
|
||||
for (const auto& d : composite_devices_) {
|
||||
if (d.second->name() == device_name) {
|
||||
*device = d.second.get();
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
return errors::NotFound("Unknown composite device: ", device_name);
|
||||
}
|
||||
|
||||
Status EagerContext::FindCustomDeviceFromName(const string& device_name,
|
||||
CustomDevice** dev) const {
|
||||
auto dev_it = custom_devices_.find(device_name);
|
||||
@ -904,8 +916,7 @@ Status EagerContext::FindOrCreateCompositeDevice(
|
||||
composite_devices_.size(), &s);
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
*composite_device = device.get();
|
||||
// TODO(b/145922293): Add the composite device to the device set of pflr in
|
||||
// order to make placer recognize it.
|
||||
pflr_->AddCompositeDevice(*composite_device);
|
||||
composite_devices_.emplace(hash_key, std::move(device));
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -483,6 +483,9 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
||||
|
||||
Status FindDeviceFromName(const char* device_name, Device** device) const;
|
||||
|
||||
Status FindCompositeDeviceFromName(const char* device_name,
|
||||
CompositeDevice** device) const;
|
||||
|
||||
Status FindCustomDeviceFromName(const string& device_name,
|
||||
CustomDevice** dev) const;
|
||||
|
||||
|
@ -180,6 +180,10 @@ TEST_F(EagerContextTest, CompositeDevice) {
|
||||
&composite_device_0));
|
||||
EXPECT_EQ(composite_device_0->name(),
|
||||
"/job:worker/replica:0/task:0/device:COMPOSITE:0");
|
||||
CompositeDevice* device = nullptr;
|
||||
TF_EXPECT_OK(context()->FindCompositeDeviceFromName(
|
||||
"/job:worker/replica:0/task:0/device:COMPOSITE:0", &device));
|
||||
EXPECT_EQ(device, composite_device_0);
|
||||
CompositeDevice* composite_device_1 = nullptr;
|
||||
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
|
||||
&composite_device_1));
|
||||
@ -190,6 +194,12 @@ TEST_F(EagerContextTest, CompositeDevice) {
|
||||
&composite_device_2));
|
||||
EXPECT_EQ(composite_device_2->name(),
|
||||
"/job:worker/replica:0/task:0/device:COMPOSITE:1");
|
||||
TF_EXPECT_OK(context()->FindCompositeDeviceFromName(
|
||||
"/job:worker/replica:0/task:0/device:COMPOSITE:1", &device));
|
||||
EXPECT_EQ(device, composite_device_2);
|
||||
|
||||
EXPECT_TRUE(errors::IsNotFound(context()->FindCompositeDeviceFromName(
|
||||
"/job:worker/replica:0/task:0/device:COMPOSITE:2", &device)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -367,6 +367,7 @@ Status GetOrCreateKernelAndDevice(
|
||||
Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
|
||||
|
||||
std::vector<Device*> input_dev_ptrs;
|
||||
absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
|
||||
std::unordered_map<int, DtypeAndPartialTensorShape>
|
||||
input_resource_variable_dtypes_and_shapes;
|
||||
// We can eliminate some overhead by running simple functions using regular
|
||||
@ -410,6 +411,13 @@ Status GetOrCreateKernelAndDevice(
|
||||
Device* input_device;
|
||||
TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device));
|
||||
input_dev_ptrs.push_back(input_device);
|
||||
CompositeDevice* composite_device = nullptr;
|
||||
if (ctx.FindCompositeDeviceFromName(input_device->name().c_str(),
|
||||
&composite_device)
|
||||
.ok()) {
|
||||
composite_devices[input_device->name()] =
|
||||
composite_device->underlying_devices();
|
||||
}
|
||||
cache_key =
|
||||
FingerprintCat128(cache_key, Fingerprint128(input_device->name()));
|
||||
|
||||
@ -520,6 +528,7 @@ Status GetOrCreateKernelAndDevice(
|
||||
#endif // IS_MOBILE_PLATFORM
|
||||
kernel.reset(new KernelAndDeviceFunc(
|
||||
flr, ctx.pflr(), std::move(input_dev_ptrs),
|
||||
std::move(composite_devices),
|
||||
std::move(input_resource_variable_dtypes_and_shapes), runner,
|
||||
ctx.GetCollectiveExecutorHandle(), ctx.HostCPU(), op->Name(),
|
||||
[&ctx](const int64 step_id) { return ctx.CreateRendezvous(step_id); },
|
||||
|
@ -17,6 +17,51 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
bool ExecuteNodeArgs::IsRemote(EagerContext* ctx, Device* input_device,
|
||||
TensorHandle* handle) {
|
||||
uint64 context_view_id = ctx->GetContextViewId();
|
||||
if (handle->Type() == TensorHandle::REMOTE ||
|
||||
handle->HasRemoteMirror(input_device, context_view_id)) {
|
||||
if (!has_remote_inputs_) {
|
||||
has_remote_inputs_ = true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
#endif // IS_MOBILE_PLATFORM
|
||||
|
||||
Status ExecuteNodeArgs::InitPackedHandle(const int index, EagerContext* ctx,
|
||||
Device* input_device,
|
||||
TensorHandle* packed_handle) {
|
||||
int num_handles = packed_handle->NumPackedHandles();
|
||||
packed_args_.emplace(index, gtl::InlinedVector<TensorValue, 4>(num_handles));
|
||||
TensorValue* packed_arg_flat = &(packed_args_[index][0]);
|
||||
for (int i = 0; i < num_handles; ++i) {
|
||||
TensorHandle* h = nullptr;
|
||||
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
|
||||
// We have validated that h->device() is not a CustomDevice when
|
||||
// constructing a pack TensorHandle.
|
||||
const Status status =
|
||||
h->TensorValue(absl::get<Device*>(h->device()), &packed_arg_flat[i]);
|
||||
if (!status.ok()) {
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
if (IsRemote(ctx, input_device, h)) {
|
||||
continue;
|
||||
}
|
||||
#endif // IS_MOBILE_PLATFORM
|
||||
if (h->Type() == TensorHandle::PACKED) {
|
||||
return errors::InvalidArgument(
|
||||
"Nested packed handles are not supported");
|
||||
}
|
||||
return status;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ExecuteNodeArgs::Init(
|
||||
EagerContext* ctx, const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
|
||||
const core::RefCountPtr<KernelAndDevice>& kernel) {
|
||||
@ -35,16 +80,17 @@ Status ExecuteNodeArgs::Init(
|
||||
Status s = in->TensorValue(ctx->CanonicalDevice(d), &tensor_args_flat[i]);
|
||||
if (!s.ok()) {
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
uint64 context_view_id = ctx->GetContextViewId();
|
||||
if (in->Type() == TensorHandle::REMOTE ||
|
||||
in->HasRemoteMirror(d, context_view_id)) {
|
||||
if (!has_remote_inputs_) {
|
||||
has_remote_inputs_ = true;
|
||||
}
|
||||
if (IsRemote(ctx, d, in)) {
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
return s;
|
||||
if (in->Type() != TensorHandle::PACKED) {
|
||||
return s;
|
||||
}
|
||||
if (!has_packed_inputs_) {
|
||||
has_packed_inputs_ = true;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(InitPackedHandle(i, ctx, d, in));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -54,24 +100,44 @@ Status ExecuteNodeArgs::Init(
|
||||
serialize_remote_handle_ =
|
||||
[ctx, &op_inputs](const FunctionArgIndex& index,
|
||||
eager::RemoteTensorHandle* handle) -> Status {
|
||||
if (index.sub_index >= 0) {
|
||||
return errors::InvalidArgument("Got unexpected sub_index ",
|
||||
index.sub_index, " for argument ",
|
||||
index.index);
|
||||
TensorHandle* h = op_inputs[index.index];
|
||||
if (op_inputs[index.index]->Type() == TensorHandle::PACKED) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
op_inputs[index.index]->ExtractPackedHandle(index.sub_index, &h));
|
||||
}
|
||||
VariantDevice variant_device = op_inputs[index.index]->device();
|
||||
VariantDevice variant_device = h->device();
|
||||
if (VariantDeviceIsCustom(variant_device)) {
|
||||
return errors::Internal(
|
||||
"Custom devices and remote execution are currently not supported "
|
||||
"together.");
|
||||
}
|
||||
Device* device = absl::get<Device*>(variant_device);
|
||||
return ctx->RemoteMgr()->SerializeRemoteTensorHandle(
|
||||
op_inputs[index.index], handle, device, device->name());
|
||||
return ctx->RemoteMgr()->SerializeRemoteTensorHandle(h, handle, device,
|
||||
device->name());
|
||||
};
|
||||
}
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ExecuteNodeArgs::GetLocalArg(const FunctionArgIndex& index,
|
||||
Tensor* val) const {
|
||||
Status s = EagerKernelArgs::GetLocalArg(index, val);
|
||||
if (s.ok()) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (packed_args_.contains(index.index)) {
|
||||
Tensor* arg = packed_args_.at(index.index).at(index.sub_index).tensor;
|
||||
if (arg) {
|
||||
*val = *arg;
|
||||
return Status::OK();
|
||||
} else {
|
||||
return errors::NotFound("Argument (", index.index, ",", index.sub_index,
|
||||
") has no local tensor.");
|
||||
}
|
||||
} else {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
@ -54,6 +55,8 @@ class ExecuteNodeArgs : public EagerKernelArgs {
|
||||
const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
|
||||
const core::RefCountPtr<KernelAndDevice>& kernel);
|
||||
|
||||
Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override;
|
||||
|
||||
bool HasRemoteOrPackedInputs() const override {
|
||||
return has_remote_inputs_ || has_packed_inputs_;
|
||||
};
|
||||
@ -66,8 +69,20 @@ class ExecuteNodeArgs : public EagerKernelArgs {
|
||||
#endif // IS_MOBILE_PLATFORM
|
||||
|
||||
private:
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
// Returns whether `handle` is a remote handle or has a remote mirror on
|
||||
// `input_device`
|
||||
bool IsRemote(EagerContext* ctx, Device* input_device, TensorHandle* handle);
|
||||
#endif // IS_MOBILE_PLATFORM
|
||||
|
||||
// Initialize a packed TensorHandle which is the `index`-th argument.
|
||||
Status InitPackedHandle(const int index, EagerContext* ctx,
|
||||
Device* input_device, TensorHandle* packed_handle);
|
||||
|
||||
bool has_remote_inputs_ = false;
|
||||
bool has_packed_inputs_ = false;
|
||||
// Maps from the index of a packed arg to a list of sub-args.
|
||||
absl::flat_hash_map<int, gtl::InlinedVector<TensorValue, 4>> packed_args_;
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
std::function<Status(const FunctionArgIndex&, eager::RemoteTensorHandle*)>
|
||||
serialize_remote_handle_;
|
||||
|
126
tensorflow/core/common_runtime/eager/execute_node_test.cc
Normal file
126
tensorflow/core/common_runtime/eager/execute_node_test.cc
Normal file
@ -0,0 +1,126 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/execute_node.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class TestKernelAndDeviceFunc final : public KernelAndDeviceFunc {
|
||||
public:
|
||||
TestKernelAndDeviceFunc(std::vector<Device*> input_devices,
|
||||
Device* host_cpu_device)
|
||||
: KernelAndDeviceFunc(
|
||||
/*flr=*/nullptr, /*pflr=*/nullptr, /*input_devices=*/{},
|
||||
/*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
|
||||
/*runner=*/nullptr, /*collective_executor=*/nullptr,
|
||||
host_cpu_device, /*name=*/"",
|
||||
/*rendezvous_creator=*/nullptr, /*get_op_id=*/nullptr),
|
||||
test_input_devices_(std::move(input_devices)) {}
|
||||
|
||||
Device* InputDevice(int i) const override { return test_input_devices_[i]; }
|
||||
|
||||
private:
|
||||
std::vector<Device*> test_input_devices_;
|
||||
};
|
||||
|
||||
// Verifies that ExecuteNodeArgs correctly serves a mixed set of function
// inputs — a plain local handle, a packed handle (local + remote members),
// and a remote handle — exposing local tensors via GetLocalArg and remote
// handles via GetRemoteArg, addressed by FunctionArgIndex(index, sub_index).
TEST(ExecuteNodeTest, ExecuteNodeArgs) {
  // One "local" CPU device (task:0) and one device standing in for a
  // remote worker (task:1, owned by a separate device manager).
  StaticDeviceMgr device_mgr(
      DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
  Device* device0 = device_mgr.ListDevices().at(0);
  StaticDeviceMgr remote_device_mgr(
      DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1"));
  Device* device1 = remote_device_mgr.ListDevices().at(0);

  // A CompositeDevice spanning both devices; presumably required so packed
  // handles spanning device0/device1 are representable — TODO confirm it is
  // exercised indirectly via CreatePackedHandle below.
  Status s;
  std::unique_ptr<CompositeDevice> composite_device =
      CompositeDevice::MakeDevice({device0->name(), device1->name()},
                                  /*unique_device_id=*/0, &s);
  TF_ASSERT_OK(s);

  // Heap-allocated because EagerContext is ref-counted; released with
  // ctx->Unref() at the end of the test.
  auto ctx = new EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
      &device_mgr, false, nullptr, nullptr, nullptr);

  DataType dtype = DT_FLOAT;
  // Create two local TensorHandles holding scalars 1.0f and 2.0f.
  Tensor t0(dtype, TensorShape({}));
  t0.scalar<float>()() = {1.0f};
  TensorHandle* h0 =
      TensorHandle::CreateLocalHandle(std::move(t0), device0, device0, ctx);
  Tensor t1(dtype, TensorShape({}));
  t1.scalar<float>()() = {2.0f};
  TensorHandle* h1 =
      TensorHandle::CreateLocalHandle(std::move(t1), device0, device0, ctx);
  // Create two lazy remote TensorHandles on device1, identified by
  // (op_id, output_num) pairs that the assertions below check round-trip.
  TensorHandle* h2 = TensorHandle::CreateLazyRemoteHandle(
      /*op_id=*/1, /*output_num=*/0, dtype, device1, ctx);
  TensorHandle* h3 = TensorHandle::CreateLazyRemoteHandle(
      /*op_id=*/2, /*output_num=*/1, dtype, device1, ctx);
  // Create a packed TensorHandle combining a local (h1) and a remote (h2)
  // handle; its members are addressed by sub_index 0 and 1 respectively.
  TensorHandle* packed_h = nullptr;
  TF_ASSERT_OK(TensorHandle::CreatePackedHandle({h1, h2}, ctx, &packed_h));

  // Function inputs, one of each flavor: LOCAL, PACKED, REMOTE.
  absl::InlinedVector<TensorHandle*, 4> inputs = {h0, packed_h, h3};

  std::vector<Device*> input_devices;
  for (auto* h : inputs) {
    input_devices.push_back(absl::get<Device*>(h->DeviceOrHostCPU(*ctx)));
  }
  // Test double (defined above in this file) that only reports per-input
  // devices; device0 serves as the kernel's host CPU device.
  const core::RefCountPtr<KernelAndDevice> kernel(
      new TestKernelAndDeviceFunc(std::move(input_devices), device0));

  ExecuteNodeArgs args(inputs.size());
  TF_EXPECT_OK(args.Init(ctx, inputs, kernel));
  // Both a packed and a remote input are present.
  EXPECT_TRUE(args.HasRemoteOrPackedInputs());
  // Arg 0: the plain local handle -> scalar 1.0.
  Tensor local0;
  TF_EXPECT_OK(args.GetLocalArg(FunctionArgIndex(0), &local0));
  EXPECT_EQ(local0.flat<float>().size(), 1);
  EXPECT_EQ(local0.flat<float>()(0), 1.0);
  // Arg (1, 0): first member of the packed handle (local h1) -> scalar 2.0.
  Tensor local1;
  TF_EXPECT_OK(args.GetLocalArg(FunctionArgIndex(1, 0), &local1));
  EXPECT_EQ(local1.flat<float>().size(), 1);
  EXPECT_EQ(local1.flat<float>()(0), 2.0);
  // Arg (1, 1): second member of the packed handle (remote h2) -> op_id 1.
  eager::RemoteTensorHandle remote0;
  TF_EXPECT_OK(args.GetRemoteArg(FunctionArgIndex(1, 1), &remote0));
  EXPECT_EQ(remote0.op_id(), 1);
  EXPECT_EQ(remote0.output_num(), 0);
  // Arg 2: the standalone remote handle (h3) -> op_id 2, output 1.
  eager::RemoteTensorHandle remote1;
  TF_EXPECT_OK(args.GetRemoteArg(FunctionArgIndex(2), &remote1));
  EXPECT_EQ(remote1.op_id(), 2);
  EXPECT_EQ(remote1.output_num(), 1);

  // Release every handle this test created, then the context itself.
  h0->Unref();
  h1->Unref();
  h2->Unref();
  h3->Unref();
  packed_h->Unref();
  ctx->Unref();
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -158,6 +158,7 @@ Status KernelAndDeviceFunc::InstantiateFunc(const NodeDef& ndef,
|
||||
for (const Device* device : input_devices_) {
|
||||
options.input_devices.push_back(device->name());
|
||||
}
|
||||
options.composite_devices = composite_devices_;
|
||||
options.input_resource_dtypes_and_shapes = input_resource_dtypes_and_shapes_;
|
||||
|
||||
const auto& it = ndef.attr().find("executor_type");
|
||||
@ -425,7 +426,9 @@ Device* KernelAndDeviceOp::InputDevice(int i) const {
|
||||
}
|
||||
|
||||
Device* KernelAndDeviceFunc::InputDevice(int i) const {
|
||||
if (input_dtypes_[i] == DT_RESOURCE) {
|
||||
if ((input_dtypes_[i] == DT_RESOURCE) &&
|
||||
(composite_devices_.find(input_devices_[i]->name()) ==
|
||||
composite_devices_.end())) {
|
||||
return host_cpu_device_;
|
||||
} else {
|
||||
return input_devices_[i];
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
||||
@ -241,7 +242,7 @@ class KernelAndDeviceOp final : public KernelAndDevice {
|
||||
// Represents a multi-device function. Functions can also be run using
|
||||
// various function-calling kernels including CallOp and PartitionedCallOp.
|
||||
// In such cases, KernelAndDeviceOp is used.
|
||||
class KernelAndDeviceFunc final : public KernelAndDevice {
|
||||
class KernelAndDeviceFunc : public KernelAndDevice {
|
||||
public:
|
||||
// `flr` can be nullptr.
|
||||
// `pflr` must not be nullptr.
|
||||
@ -249,6 +250,7 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
|
||||
KernelAndDeviceFunc(
|
||||
FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr,
|
||||
std::vector<Device*> input_devices,
|
||||
absl::flat_hash_map<string, const std::vector<string>*> composite_devices,
|
||||
std::unordered_map<int, DtypeAndPartialTensorShape>
|
||||
input_resource_dtypes_and_shapes,
|
||||
std::function<void(std::function<void()>)>* runner,
|
||||
@ -261,6 +263,7 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
|
||||
pflr_(pflr),
|
||||
handle_(kInvalidHandle),
|
||||
input_devices_(std::move(input_devices)),
|
||||
composite_devices_(std::move(composite_devices)),
|
||||
input_resource_dtypes_and_shapes_(
|
||||
std::move(input_resource_dtypes_and_shapes)),
|
||||
name_(name),
|
||||
@ -320,6 +323,8 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
|
||||
// CPU devices are not null. Resource handles' devices are actual backing
|
||||
// devices.
|
||||
std::vector<Device*> input_devices_;
|
||||
// Maps from a CompositeDevice name to a list of physical device names.
|
||||
absl::flat_hash_map<string, const std::vector<string>*> composite_devices_;
|
||||
std::unordered_map<int, DtypeAndPartialTensorShape>
|
||||
input_resource_dtypes_and_shapes_;
|
||||
|
||||
|
@ -124,6 +124,10 @@ string TensorHandle::PackedTensorHandleData::DebugString() const {
|
||||
return debug_str;
|
||||
}
|
||||
|
||||
int TensorHandle::PackedTensorHandleData::NumPackedHandles() const {
|
||||
return handles_.size();
|
||||
}
|
||||
|
||||
Status TensorHandle::PackedTensorHandleData::ExtractPackedHandle(
|
||||
const int index, TensorHandle** handle) const {
|
||||
if (index < 0 || index >= handles_.size()) {
|
||||
@ -185,6 +189,13 @@ Status TensorHandle::GetResourceAllowedDevices(std::vector<string>* result) {
|
||||
return GetResourceHandleInfoImpl(get_resource_info);
|
||||
}
|
||||
|
||||
int TensorHandle::NumPackedHandles() const {
|
||||
if (Type() != PACKED) {
|
||||
return 0;
|
||||
}
|
||||
return absl::get<PackedTensorHandleData>(data_).NumPackedHandles();
|
||||
}
|
||||
|
||||
Status TensorHandle::ExtractPackedHandle(const int index,
|
||||
TensorHandle** handle) const {
|
||||
if (Type() != PACKED) {
|
||||
@ -315,8 +326,8 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
||||
return errors::InvalidArgument(
|
||||
"CustomDevice is not supported for packing.");
|
||||
} else {
|
||||
devices.push_back(
|
||||
absl::get<Device*>(handle->DeviceOrHostCPU(*ctx))->name());
|
||||
devices.push_back(handle->op_device() ? handle->op_device()->name()
|
||||
: ctx->HostCPU()->name());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -231,6 +231,8 @@ class TensorHandle : public AbstractTensorHandleInterface,
|
||||
std::vector<DtypeAndPartialTensorShape>* result);
|
||||
Status GetResourceAllowedDevices(std::vector<string>* result);
|
||||
|
||||
// Returns the number of packed handles. 0 if the handle type is not PACKED.
|
||||
int NumPackedHandles() const;
|
||||
// It's called on a packed TensorHandle. Extract a handle with the given
|
||||
// index.
|
||||
Status ExtractPackedHandle(const int index, TensorHandle** handle) const;
|
||||
@ -316,6 +318,8 @@ class TensorHandle : public AbstractTensorHandleInterface,
|
||||
void Poison(Status status);
|
||||
string DebugString() const;
|
||||
|
||||
// Number of packed handles.
|
||||
int NumPackedHandles() const;
|
||||
// Extract a handle on the given index.
|
||||
Status ExtractPackedHandle(const int index, TensorHandle** handle) const;
|
||||
|
||||
|
@ -164,6 +164,7 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
|
||||
h2->Unref();
|
||||
h3->Unref();
|
||||
|
||||
EXPECT_EQ(packed_handle->NumPackedHandles(), 4);
|
||||
EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
|
||||
EXPECT_EQ(packed_handle->dtype, dtype);
|
||||
TensorShape packed_shape;
|
||||
@ -185,7 +186,7 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
|
||||
const std::vector<TensorHandle::HandleType> expected_handle_types = {
|
||||
TensorHandle::LOCAL, TensorHandle::LOCAL, TensorHandle::REMOTE,
|
||||
TensorHandle::REMOTE};
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
|
||||
TensorHandle* h = nullptr;
|
||||
TF_ASSERT_OK(packed_handle->ExtractPackedHandle(i, &h));
|
||||
EXPECT_EQ(absl::get<Device*>(h->device()), ListDevices().at(i));
|
||||
|
@ -195,6 +195,9 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
|
||||
for (Node* n : graph->op_nodes()) {
|
||||
if (composite_device_names.find(n->assigned_device_name()) !=
|
||||
composite_device_names.end()) {
|
||||
// TODO(b/145922293): Validate that an _Arg node assigned to a
|
||||
// CompositeDevice should have an attribute indicating that the _Arg node
|
||||
// represents a packed input.
|
||||
composite_device_to_cluster_nodes[n->assigned_device_name()].push_back(n);
|
||||
}
|
||||
}
|
||||
|
@ -728,7 +728,9 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) {
|
||||
core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
|
||||
const int64 op_id = 2;
|
||||
kernel.reset(new KernelAndDeviceFunc(
|
||||
flr, eager_pflr_.get(), std::move(input_dev_ptrs), {}, /*runner=*/nullptr,
|
||||
flr, eager_pflr_.get(), std::move(input_dev_ptrs),
|
||||
/*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
|
||||
/*runner=*/nullptr,
|
||||
/*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
|
||||
[ctx](const int64 step_id) { return ctx->CreateRendezvous(step_id); },
|
||||
[=]() { return op_id; }));
|
||||
@ -773,7 +775,9 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) {
|
||||
core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
|
||||
const int64 op_id = 2;
|
||||
kernel.reset(new KernelAndDeviceFunc(
|
||||
flr, eager_pflr_.get(), std::move(input_dev_ptrs), {}, /*runner=*/nullptr,
|
||||
flr, eager_pflr_.get(), std::move(input_dev_ptrs),
|
||||
/*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
|
||||
/*runner=*/nullptr,
|
||||
/*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
|
||||
[ctx](const int64 step_id) { return ctx->CreateRendezvous(step_id); },
|
||||
[=]() { return op_id; }));
|
||||
|
Loading…
Reference in New Issue
Block a user