Support running a function with packed input handles through C APIs.
Introduce a C API TFE_CreatePackedTensorHandle which creates a TFE_TensorHandle referring to multiple TFE_TensorHandles.

PiperOrigin-RevId: 310610230
Change-Id: Icc0ffd5c58ad7780eca38d552c1a2f4617f04891
commit 7e6ea21148
parent 696f2a8bd7
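A minimal usage sketch (not part of the diff), assembled from the new declaration and the C API test below; `ctx`, `h0`, `h1`, `h2`, and the error checking are assumed to be set up as in that test:

    // Pack three existing TFE_TensorHandles into one packed handle.
    TF_Status* status = TF_NewStatus();
    TFE_TensorHandle* handles[] = {h0, h1, h2};
    int num_handles = 3;
    TFE_TensorHandle* packed =
        TFE_CreatePackedTensorHandle(ctx, handles, &num_handles, status);
    // The packed handle is then fed to a function op via TFE_OpAddInput,
    // as TestFunctionWithPackedInput does further down.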
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/c/eager/tfe_op_internal.h"
 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
 #include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/common_runtime/composite_device.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
 #include "tensorflow/core/lib/monitoring/counter.h"
@@ -638,3 +639,21 @@ TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
   return tensorflow::wrap(
       tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
 }
+
+TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
+                                               TFE_TensorHandle** handles,
+                                               int* num_handles,
+                                               TF_Status* status) {
+  std::vector<tensorflow::TensorHandle*> tensor_handles;
+  tensor_handles.reserve(*num_handles);
+  for (int i = 0; i < *num_handles; ++i) {
+    tensor_handles.push_back(
+        tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i])));
+  }
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
+  tensorflow::TensorHandle* handle = nullptr;
+  status->status = tensorflow::TensorHandle::CreatePackedHandle(
+      std::move(tensor_handles), context, &handle);
+  return tensorflow::wrap(handle);
+}
@@ -541,6 +541,14 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
 TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
     TFE_Context* ctx, TF_Tensor* t, TF_Status* status);
+
+// Create a packed TensorHandle with the given list of TensorHandles.
+// If `handles` are on the same device, assign the same device to the packed
+// handle; if `handles` are on different devices, assign a CompositeDevice to
+// it.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
+    TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
+    TF_Status* status);
 
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
@@ -351,6 +351,192 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
       /*heavy_load_on_streaming_rpc=*/true);
 }
 
+// Add the values of three variables on three different tasks.
+string AddVariablesFunction() {
+  tensorflow::FunctionDef def;
+  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
+      "    signature {"
+      "      name: 'AddVariablesFunction'"
+      "      input_arg {"
+      "        name: 'var'"
+      "        type: DT_RESOURCE"
+      "      }"
+      "      output_arg {"
+      "        name: 'sum'"
+      "        type: DT_FLOAT"
+      "      }"
+      "    }"
+      "    node_def {"
+      "      name: 'read0'"
+      "      op: 'ReadVariableOp'"
+      "      input: 'var'"
+      "      device: '/job:localhost/replica:0/task:0/device:CPU:0'"
+      "      attr {"
+      "        key: 'dtype'"
+      "        value {"
+      "          type: DT_FLOAT"
+      "        }"
+      "      }"
+      "    }"
+      "    node_def {"
+      "      name: 'read1'"
+      "      op: 'ReadVariableOp'"
+      "      input: 'var'"
+      "      device: '/job:localhost/replica:0/task:1/device:CPU:0'"
+      "      attr {"
+      "        key: 'dtype'"
+      "        value {"
+      "          type: DT_FLOAT"
+      "        }"
+      "      }"
+      "    }"
+      "    node_def {"
+      "      name: 'read2'"
+      "      op: 'ReadVariableOp'"
+      "      input: 'var'"
+      "      device: '/job:localhost/replica:0/task:2/device:CPU:0'"
+      "      attr {"
+      "        key: 'dtype'"
+      "        value {"
+      "          type: DT_FLOAT"
+      "        }"
+      "      }"
+      "    }"
+      "    node_def {"
+      "      name: 'add1'"
+      "      op: 'Add'"
+      "      input: 'read0:value:0'"
+      "      input: 'read1:value:0'"
+      "      attr {"
+      "        key: 'T'"
+      "        value {"
+      "          type: DT_FLOAT"
+      "        }"
+      "      }"
+      "    }"
+      "    node_def {"
+      "      name: 'add2'"
+      "      op: 'Add'"
+      "      input: 'add1:z:0'"
+      "      input: 'read2:value:0'"
+      "      attr {"
+      "        key: 'T'"
+      "        value {"
+      "          type: DT_FLOAT"
+      "        }"
+      "      }"
+      "    }"
+      "    ret {"
+      "      key: 'sum'"
+      "      value: 'add2:z:0'"
+      "    }",
+      &def));
+  return def.SerializeAsString();
+}
+
+TEST(CAPI, TestFunctionWithPackedInput) {
+  tensorflow::ServerDef server_def = GetServerDef(3);
+
+  // This server def has the task index set to 0.
+  string serialized = server_def.SerializeAsString();
+
+  server_def.set_task_index(1);
+  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def, tensorflow::Env::Default(), &worker_server1)
+                  .ok());
+  ASSERT_TRUE(worker_server1->Start().ok());
+
+  server_def.set_task_index(2);
+  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def, tensorflow::Env::Default(), &worker_server2)
+                  .ok());
+  ASSERT_TRUE(worker_server2->Start().ok());
+
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
+  const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
+  const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
+
+  // Create one variable per task.
+  TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
+  TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
+  TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
+
+  // Pack 3 variable handles into one TFE_TensorHandle.
+  int num_replicas = 3;
+  std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
+  TFE_TensorHandle* packed_handle =
+      TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
+  EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
+  EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
+
+  const string composite_device_name =
+      "/job:localhost/replica:0/task:0/device:COMPOSITE:0";
+  EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
+            composite_device_name);
+  EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
+            composite_device_name);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  // Register and run a function which returns the sum of 3 variables.
+  const string function_def = AddVariablesFunction();
+  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
+                            status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  TFE_OpAddInput(func, packed_handle, status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  TFE_TensorHandle* retvals[1] = {nullptr};
+  int num_retvals = 1;
+  TFE_Execute(func, &retvals[0], &num_retvals, status);
+  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  ASSERT_EQ(1, num_retvals);
+  TFE_DeleteOp(func);
+  TFE_DeleteTensorHandle(packed_handle);
+  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  TFE_DeleteTensorHandle(retvals[0]);
+  float sum = 0;
+  EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
+  memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
+  TF_DeleteTensor(t);
+  EXPECT_EQ(sum, 6.0);
+
+  TFE_DeleteTensorHandle(h0);
+  TFE_DeleteTensorHandle(h1);
+  TFE_DeleteTensorHandle(h2);
+
+  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  TFE_DeleteExecutor(executor);
+  TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
+  TFE_DeleteContext(ctx);
+
+  TF_DeleteStatus(status);
+
+  // TODO(b/136478427): Figure out how to correctly shut the server down.
+  worker_server1.release();
+  worker_server2.release();
+}
+
 void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
   tensorflow::ServerDef server_def = GetServerDef(2);
 
@@ -1132,51 +1132,6 @@ void BM_ExecuteFunction(int iters, int async) {
 }
 BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1);
 
-TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
-                                 TF_Status* status) {
-  // Create the variable handle.
-  TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
-  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
-  TFE_OpSetAttrShape(op, "shape", {}, 0, status);
-  TFE_OpSetAttrString(op, "container", "", 0);
-  TFE_OpSetAttrString(op, "shared_name", "", 0);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
-  TFE_TensorHandle* var_handle = nullptr;
-  int num_retvals = 1;
-  TFE_Execute(op, &var_handle, &num_retvals, status);
-  TFE_DeleteOp(op);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
-  CHECK_EQ(1, num_retvals);
-
-  // Assign 'value' to it.
-  op = TFE_NewOp(ctx, "AssignVariableOp", status);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
-  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
-  TFE_OpAddInput(op, var_handle, status);
-
-  // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
-  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
-      TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
-  memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
-
-  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
-      value_handle(TFE_NewTensorHandle(t.get(), status),
-                   TFE_DeleteTensorHandle);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
-
-  TFE_OpAddInput(op, value_handle.get(), status);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
-
-  num_retvals = 0;
-  TFE_Execute(op, nullptr, &num_retvals, status);
-  TFE_DeleteOp(op);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
-  CHECK_EQ(0, num_retvals);
-
-  return var_handle;
-}
-
 TEST(CAPI, Variables) {
   // Variables use resource handles, so this is really a test for resource
   // tensor handling.
@@ -1186,7 +1141,7 @@ TEST(CAPI, Variables) {
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteContextOptions(opts);
 
-  TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
+  TFE_TensorHandle* var_handle = TestVariable(ctx, 12.0);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 
   TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
@@ -1227,7 +1182,7 @@ void BM_ReadVariable(int iters) {
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteContextOptions(opts);
 
-  TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
+  TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 
   TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
@@ -133,6 +133,57 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) {
   return th;
 }
 
+TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
+                               const tensorflow::string& device_name) {
+  TF_Status* status = TF_NewStatus();
+  // Create the variable handle.
+  TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+  TFE_OpSetAttrShape(op, "shape", {}, 0, status);
+  TFE_OpSetAttrString(op, "container", "", 0);
+  TFE_OpSetAttrString(op, "shared_name", "", 0);
+  if (!device_name.empty()) {
+    TFE_OpSetDevice(op, device_name.c_str(), status);
+  }
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_TensorHandle* var_handle = nullptr;
+  int num_retvals = 1;
+  TFE_Execute(op, &var_handle, &num_retvals, status);
+  TFE_DeleteOp(op);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  CHECK_EQ(1, num_retvals);
+
+  // Assign 'value' to it.
+  op = TFE_NewOp(ctx, "AssignVariableOp", status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+  TFE_OpAddInput(op, var_handle, status);
+
+  // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
+  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
+      TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
+  memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
+
+  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+      value_handle(TFE_NewTensorHandle(t.get(), status),
+                   TFE_DeleteTensorHandle);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  TFE_OpAddInput(op, value_handle.get(), status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  num_retvals = 0;
+  TFE_Execute(op, nullptr, &num_retvals, status);
+  TFE_DeleteOp(op);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  CHECK_EQ(0, num_retvals);
+
+  TF_DeleteStatus(status);
+
+  return var_handle;
+}
+
 TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
   TF_Status* status = TF_NewStatus();
 
@@ -42,6 +42,11 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx);
 // Return a tensor handle containing a 3x2 matrix of floats
 TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx);
 
+// Return a variable handle referring to a variable with the given initial value
+// on the given device.
+TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
+                               const tensorflow::string& device_name = "");
+
 // Return an add op multiplying `a` by `b`.
 TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
 
@@ -305,6 +305,7 @@ tf_cuda_library(
     visibility = ["//tensorflow:internal"],
     deps = [
         ":attr_builder",
+        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/types:optional",
@@ -369,6 +370,7 @@ cc_library(
        ":eager_operation",
        ":kernel_and_device",
        ":tensor_handle",
+        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/strings",
@@ -396,6 +398,24 @@ cc_library(
     }) + if_mkl([":mkl_eager_op_rewrite"]),
 )
 
+tf_cc_test(
+    name = "execute_node_test",
+    srcs = ["execute_node_test.cc"],
+    deps = [
+        ":context",
+        ":core",
+        ":execute",
+        ":kernel_and_device",
+        ":tensor_handle",
+        "//tensorflow/core:core_cpu_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "@com_google_absl//absl/memory",
+    ],
+)
+
 cc_library(
     name = "mkl_eager_op_rewrite",
     srcs = ["mkl_eager_op_rewrite.cc"],
@@ -466,6 +486,7 @@ cc_library(
        ":eager_operation",
        ":kernel_and_device",
        ":tensor_handle",
+        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/strings",
@@ -853,6 +853,18 @@ Status EagerContext::FindDeviceFromName(const char* device_name,
   return status;
 }
 
+Status EagerContext::FindCompositeDeviceFromName(
+    const char* device_name, CompositeDevice** device) const {
+  tf_shared_lock l(composite_devices_mu_);
+  for (const auto& d : composite_devices_) {
+    if (d.second->name() == device_name) {
+      *device = d.second.get();
+      return Status::OK();
+    }
+  }
+  return errors::NotFound("Unknown composite device: ", device_name);
+}
+
 Status EagerContext::FindCustomDeviceFromName(const string& device_name,
                                               CustomDevice** dev) const {
   auto dev_it = custom_devices_.find(device_name);
@@ -904,8 +916,7 @@ Status EagerContext::FindOrCreateCompositeDevice(
       composite_devices_.size(), &s);
   TF_RETURN_IF_ERROR(s);
   *composite_device = device.get();
-  // TODO(b/145922293): Add the composite device to the device set of pflr in
-  // order to make placer recognize it.
+  pflr_->AddCompositeDevice(*composite_device);
   composite_devices_.emplace(hash_key, std::move(device));
   return Status::OK();
 }
@@ -483,6 +483,9 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
 
   Status FindDeviceFromName(const char* device_name, Device** device) const;
 
+  Status FindCompositeDeviceFromName(const char* device_name,
+                                     CompositeDevice** device) const;
+
   Status FindCustomDeviceFromName(const string& device_name,
                                   CustomDevice** dev) const;
 
@@ -180,6 +180,10 @@ TEST_F(EagerContextTest, CompositeDevice) {
                                                        &composite_device_0));
   EXPECT_EQ(composite_device_0->name(),
             "/job:worker/replica:0/task:0/device:COMPOSITE:0");
+  CompositeDevice* device = nullptr;
+  TF_EXPECT_OK(context()->FindCompositeDeviceFromName(
+      "/job:worker/replica:0/task:0/device:COMPOSITE:0", &device));
+  EXPECT_EQ(device, composite_device_0);
   CompositeDevice* composite_device_1 = nullptr;
   TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
                                                        &composite_device_1));
@@ -190,6 +194,12 @@ TEST_F(EagerContextTest, CompositeDevice) {
                                                        &composite_device_2));
   EXPECT_EQ(composite_device_2->name(),
             "/job:worker/replica:0/task:0/device:COMPOSITE:1");
+  TF_EXPECT_OK(context()->FindCompositeDeviceFromName(
+      "/job:worker/replica:0/task:0/device:COMPOSITE:1", &device));
+  EXPECT_EQ(device, composite_device_2);
+
+  EXPECT_TRUE(errors::IsNotFound(context()->FindCompositeDeviceFromName(
+      "/job:worker/replica:0/task:0/device:COMPOSITE:2", &device)));
 }
 
 } // namespace
@@ -367,6 +367,7 @@ Status GetOrCreateKernelAndDevice(
   Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
 
   std::vector<Device*> input_dev_ptrs;
+  absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
   std::unordered_map<int, DtypeAndPartialTensorShape>
       input_resource_variable_dtypes_and_shapes;
   // We can eliminate some overhead by running simple functions using regular
@@ -410,6 +411,13 @@ Status GetOrCreateKernelAndDevice(
       Device* input_device;
       TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device));
       input_dev_ptrs.push_back(input_device);
+      CompositeDevice* composite_device = nullptr;
+      if (ctx.FindCompositeDeviceFromName(input_device->name().c_str(),
+                                          &composite_device)
+              .ok()) {
+        composite_devices[input_device->name()] =
+            composite_device->underlying_devices();
+      }
       cache_key =
           FingerprintCat128(cache_key, Fingerprint128(input_device->name()));
 
@@ -520,6 +528,7 @@ Status GetOrCreateKernelAndDevice(
 #endif  // IS_MOBILE_PLATFORM
     kernel.reset(new KernelAndDeviceFunc(
         flr, ctx.pflr(), std::move(input_dev_ptrs),
+        std::move(composite_devices),
        std::move(input_resource_variable_dtypes_and_shapes), runner,
        ctx.GetCollectiveExecutorHandle(), ctx.HostCPU(), op->Name(),
        [&ctx](const int64 step_id) { return ctx.CreateRendezvous(step_id); },
@@ -17,6 +17,51 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace tensorflow {
 
+#if !defined(IS_MOBILE_PLATFORM)
+bool ExecuteNodeArgs::IsRemote(EagerContext* ctx, Device* input_device,
+                               TensorHandle* handle) {
+  uint64 context_view_id = ctx->GetContextViewId();
+  if (handle->Type() == TensorHandle::REMOTE ||
+      handle->HasRemoteMirror(input_device, context_view_id)) {
+    if (!has_remote_inputs_) {
+      has_remote_inputs_ = true;
+    }
+    return true;
+  }
+  return false;
+}
+#endif  // IS_MOBILE_PLATFORM
+
+Status ExecuteNodeArgs::InitPackedHandle(const int index, EagerContext* ctx,
+                                         Device* input_device,
+                                         TensorHandle* packed_handle) {
+  int num_handles = packed_handle->NumPackedHandles();
+  packed_args_.emplace(index, gtl::InlinedVector<TensorValue, 4>(num_handles));
+  TensorValue* packed_arg_flat = &(packed_args_[index][0]);
+  for (int i = 0; i < num_handles; ++i) {
+    TensorHandle* h = nullptr;
+    TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
+    // We have validated that h->device() is not a CustomDevice when
+    // constructing a pack TensorHandle.
+    const Status status =
+        h->TensorValue(absl::get<Device*>(h->device()), &packed_arg_flat[i]);
+    if (!status.ok()) {
+#if !defined(IS_MOBILE_PLATFORM)
+      if (IsRemote(ctx, input_device, h)) {
+        continue;
+      }
+#endif  // IS_MOBILE_PLATFORM
+      if (h->Type() == TensorHandle::PACKED) {
+        return errors::InvalidArgument(
+            "Nested packed handles are not supported");
+      }
+      return status;
+    }
+  }
+  return Status::OK();
+}
+
 Status ExecuteNodeArgs::Init(
     EagerContext* ctx, const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
     const core::RefCountPtr<KernelAndDevice>& kernel) {
@@ -35,16 +80,17 @@ Status ExecuteNodeArgs::Init(
     Status s = in->TensorValue(ctx->CanonicalDevice(d), &tensor_args_flat[i]);
     if (!s.ok()) {
 #if !defined(IS_MOBILE_PLATFORM)
-      uint64 context_view_id = ctx->GetContextViewId();
-      if (in->Type() == TensorHandle::REMOTE ||
-          in->HasRemoteMirror(d, context_view_id)) {
-        if (!has_remote_inputs_) {
-          has_remote_inputs_ = true;
-        }
+      if (IsRemote(ctx, d, in)) {
         continue;
       }
 #endif
-      return s;
+      if (in->Type() != TensorHandle::PACKED) {
+        return s;
+      }
+      if (!has_packed_inputs_) {
+        has_packed_inputs_ = true;
+      }
+      TF_RETURN_IF_ERROR(InitPackedHandle(i, ctx, d, in));
     }
   }
 }
@@ -54,24 +100,44 @@ Status ExecuteNodeArgs::Init(
     serialize_remote_handle_ =
         [ctx, &op_inputs](const FunctionArgIndex& index,
                           eager::RemoteTensorHandle* handle) -> Status {
-          if (index.sub_index >= 0) {
-            return errors::InvalidArgument("Got unexpected sub_index ",
-                                           index.sub_index, " for argument ",
-                                           index.index);
+          TensorHandle* h = op_inputs[index.index];
+          if (op_inputs[index.index]->Type() == TensorHandle::PACKED) {
+            TF_RETURN_IF_ERROR(
+                op_inputs[index.index]->ExtractPackedHandle(index.sub_index, &h));
           }
-          VariantDevice variant_device = op_inputs[index.index]->device();
+          VariantDevice variant_device = h->device();
           if (VariantDeviceIsCustom(variant_device)) {
            return errors::Internal(
                "Custom devices and remote execution are currently not supported "
                "together.");
          }
          Device* device = absl::get<Device*>(variant_device);
-          return ctx->RemoteMgr()->SerializeRemoteTensorHandle(
-              op_inputs[index.index], handle, device, device->name());
+          return ctx->RemoteMgr()->SerializeRemoteTensorHandle(h, handle, device,
+                                                               device->name());
        };
  }
 #endif  // !IS_MOBILE_PLATFORM
  return Status::OK();
 }
+
+Status ExecuteNodeArgs::GetLocalArg(const FunctionArgIndex& index,
+                                    Tensor* val) const {
+  Status s = EagerKernelArgs::GetLocalArg(index, val);
+  if (s.ok()) {
+    return Status::OK();
+  }
+  if (packed_args_.contains(index.index)) {
+    Tensor* arg = packed_args_.at(index.index).at(index.sub_index).tensor;
+    if (arg) {
+      *val = *arg;
+      return Status::OK();
+    } else {
+      return errors::NotFound("Argument (", index.index, ",", index.sub_index,
+                              ") has no local tensor.");
+    }
+  } else {
+    return s;
+  }
+}
 
 } // namespace tensorflow
@@ -20,6 +20,7 @@ limitations under the License.
 #include <cstddef>
 #include <memory>
 #include <string>
+#include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/platform.h"
 // clang-format on
@@ -54,6 +55,8 @@ class ExecuteNodeArgs : public EagerKernelArgs {
               const absl::InlinedVector<TensorHandle*, 4>& op_inputs,
               const core::RefCountPtr<KernelAndDevice>& kernel);
 
+  Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override;
+
   bool HasRemoteOrPackedInputs() const override {
     return has_remote_inputs_ || has_packed_inputs_;
   };
@@ -66,8 +69,20 @@ class ExecuteNodeArgs : public EagerKernelArgs {
 #endif  // IS_MOBILE_PLATFORM
 
  private:
+#if !defined(IS_MOBILE_PLATFORM)
+  // Returns whether `handle` is a remote handle or has a remote mirror on
+  // `input_device`
+  bool IsRemote(EagerContext* ctx, Device* input_device, TensorHandle* handle);
+#endif  // IS_MOBILE_PLATFORM
+
+  // Initialize a packed TensorHandle which is the `index`-th argument.
+  Status InitPackedHandle(const int index, EagerContext* ctx,
+                          Device* input_device, TensorHandle* packed_handle);
+
   bool has_remote_inputs_ = false;
   bool has_packed_inputs_ = false;
+  // Maps from the index of a packed arg to a list of sub-args.
+  absl::flat_hash_map<int, gtl::InlinedVector<TensorValue, 4>> packed_args_;
 #if !defined(IS_MOBILE_PLATFORM)
   std::function<Status(const FunctionArgIndex&, eager::RemoteTensorHandle*)>
       serialize_remote_handle_;
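For orientation, a sketch (not part of the diff) of how a packed argument is addressed: a FunctionArgIndex carries an argument index plus a sub_index into the packed components, so the first (local) component of packed argument 1 from the new test below is fetched as:

    Tensor component;
    // FunctionArgIndex(index, sub_index): component 0 of packed argument 1.
    Status s = args.GetLocalArg(FunctionArgIndex(1, 0), &component);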
tensorflow/core/common_runtime/eager/execute_node_test.cc (new file, 126 lines)
@@ -0,0 +1,126 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/execute_node.h"
+
+#include "tensorflow/core/common_runtime/composite_device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+class TestKernelAndDeviceFunc final : public KernelAndDeviceFunc {
+ public:
+  TestKernelAndDeviceFunc(std::vector<Device*> input_devices,
+                          Device* host_cpu_device)
+      : KernelAndDeviceFunc(
+            /*flr=*/nullptr, /*pflr=*/nullptr, /*input_devices=*/{},
+            /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
+            /*runner=*/nullptr, /*collective_executor=*/nullptr,
+            host_cpu_device, /*name=*/"",
+            /*rendezvous_creator=*/nullptr, /*get_op_id=*/nullptr),
+        test_input_devices_(std::move(input_devices)) {}
+
+  Device* InputDevice(int i) const override { return test_input_devices_[i]; }
+
+ private:
+  std::vector<Device*> test_input_devices_;
+};
+
+TEST(ExecuteNodeTest, ExecuteNodeArgs) {
+  StaticDeviceMgr device_mgr(
+      DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
+  Device* device0 = device_mgr.ListDevices().at(0);
+  StaticDeviceMgr remote_device_mgr(
+      DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1"));
+  Device* device1 = remote_device_mgr.ListDevices().at(0);
+
+  Status s;
+  std::unique_ptr<CompositeDevice> composite_device =
+      CompositeDevice::MakeDevice({device0->name(), device1->name()},
+                                  /*unique_device_id=*/0, &s);
+  TF_ASSERT_OK(s);
+
+  auto ctx = new EagerContext(
+      SessionOptions(),
+      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
+      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
+      &device_mgr, false, nullptr, nullptr, nullptr);
+
+  DataType dtype = DT_FLOAT;
+  Tensor t0(dtype, TensorShape({}));
+  // Create two local TensorHandles
+  t0.scalar<float>()() = {1.0f};
+  TensorHandle* h0 =
+      TensorHandle::CreateLocalHandle(std::move(t0), device0, device0, ctx);
+  Tensor t1(dtype, TensorShape({}));
+  t1.scalar<float>()() = {2.0f};
+  TensorHandle* h1 =
+      TensorHandle::CreateLocalHandle(std::move(t1), device0, device0, ctx);
+  // Create two remote TensorHandles
+  TensorHandle* h2 = TensorHandle::CreateLazyRemoteHandle(
+      /*op_id=*/1, /*output_num=*/0, dtype, device1, ctx);
+  TensorHandle* h3 = TensorHandle::CreateLazyRemoteHandle(
+      /*op_id=*/2, /*output_num=*/1, dtype, device1, ctx);
+  // Create a packed TensorHandle
+  TensorHandle* packed_h = nullptr;
+  TF_ASSERT_OK(TensorHandle::CreatePackedHandle({h1, h2}, ctx, &packed_h));
+
+  // LOCAL, PACKED, REMOTE
+  absl::InlinedVector<TensorHandle*, 4> inputs = {h0, packed_h, h3};
+
+  std::vector<Device*> input_devices;
+  for (auto* h : inputs) {
+    input_devices.push_back(absl::get<Device*>(h->DeviceOrHostCPU(*ctx)));
+  }
+  const core::RefCountPtr<KernelAndDevice> kernel(
+      new TestKernelAndDeviceFunc(std::move(input_devices), device0));
+
+  ExecuteNodeArgs args(inputs.size());
+  TF_EXPECT_OK(args.Init(ctx, inputs, kernel));
+  EXPECT_TRUE(args.HasRemoteOrPackedInputs());
+  Tensor local0;
+  TF_EXPECT_OK(args.GetLocalArg(FunctionArgIndex(0), &local0));
+  EXPECT_EQ(local0.flat<float>().size(), 1);
+  EXPECT_EQ(local0.flat<float>()(0), 1.0);
+  Tensor local1;
+  TF_EXPECT_OK(args.GetLocalArg(FunctionArgIndex(1, 0), &local1));
+  EXPECT_EQ(local1.flat<float>().size(), 1);
+  EXPECT_EQ(local1.flat<float>()(0), 2.0);
+  eager::RemoteTensorHandle remote0;
+  TF_EXPECT_OK(args.GetRemoteArg(FunctionArgIndex(1, 1), &remote0));
+  EXPECT_EQ(remote0.op_id(), 1);
+  EXPECT_EQ(remote0.output_num(), 0);
+  eager::RemoteTensorHandle remote1;
+  TF_EXPECT_OK(args.GetRemoteArg(FunctionArgIndex(2), &remote1));
+  EXPECT_EQ(remote1.op_id(), 2);
+  EXPECT_EQ(remote1.output_num(), 1);
+
+  h0->Unref();
+  h1->Unref();
+  h2->Unref();
+  h3->Unref();
+  packed_h->Unref();
+  ctx->Unref();
+}
+
+} // namespace
+} // namespace tensorflow
@@ -158,6 +158,7 @@ Status KernelAndDeviceFunc::InstantiateFunc(const NodeDef& ndef,
   for (const Device* device : input_devices_) {
     options.input_devices.push_back(device->name());
   }
+  options.composite_devices = composite_devices_;
   options.input_resource_dtypes_and_shapes = input_resource_dtypes_and_shapes_;
 
   const auto& it = ndef.attr().find("executor_type");
@@ -425,7 +426,9 @@ Device* KernelAndDeviceOp::InputDevice(int i) const {
 }
 
 Device* KernelAndDeviceFunc::InputDevice(int i) const {
-  if (input_dtypes_[i] == DT_RESOURCE) {
+  if ((input_dtypes_[i] == DT_RESOURCE) &&
+      (composite_devices_.find(input_devices_[i]->name()) ==
+       composite_devices_.end())) {
     return host_cpu_device_;
   } else {
     return input_devices_[i];
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/platform/platform.h"
 // clang-format on
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/types/optional.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
@@ -241,7 +242,7 @@ class KernelAndDeviceOp final : public KernelAndDevice {
 // Represents a multi-device function. Functions can also be run using
 // various function-calling kernels including CallOp and PartitionedCallOp.
 // In such cases, KernelAndDeviceOp is used.
-class KernelAndDeviceFunc final : public KernelAndDevice {
+class KernelAndDeviceFunc : public KernelAndDevice {
  public:
   // `flr` can be nullptr.
   // `pflr` must not be nullptr.
@@ -249,6 +250,7 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
   KernelAndDeviceFunc(
       FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr,
       std::vector<Device*> input_devices,
+      absl::flat_hash_map<string, const std::vector<string>*> composite_devices,
       std::unordered_map<int, DtypeAndPartialTensorShape>
           input_resource_dtypes_and_shapes,
       std::function<void(std::function<void()>)>* runner,
|
|||||||
pflr_(pflr),
|
pflr_(pflr),
|
||||||
handle_(kInvalidHandle),
|
handle_(kInvalidHandle),
|
||||||
input_devices_(std::move(input_devices)),
|
input_devices_(std::move(input_devices)),
|
||||||
|
composite_devices_(std::move(composite_devices)),
|
||||||
input_resource_dtypes_and_shapes_(
|
input_resource_dtypes_and_shapes_(
|
||||||
std::move(input_resource_dtypes_and_shapes)),
|
std::move(input_resource_dtypes_and_shapes)),
|
||||||
name_(name),
|
name_(name),
|
||||||
@ -320,6 +323,8 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
|
|||||||
// CPU devices are not null. Resource handles' devices are actual backing
|
// CPU devices are not null. Resource handles' devices are actual backing
|
||||||
// devices.
|
// devices.
|
||||||
std::vector<Device*> input_devices_;
|
std::vector<Device*> input_devices_;
|
||||||
|
// Maps from a CompositeDevice name to a list of physical device names.
|
||||||
|
absl::flat_hash_map<string, const std::vector<string>*> composite_devices_;
|
||||||
std::unordered_map<int, DtypeAndPartialTensorShape>
|
std::unordered_map<int, DtypeAndPartialTensorShape>
|
||||||
input_resource_dtypes_and_shapes_;
|
input_resource_dtypes_and_shapes_;
|
||||||
|
|
||||||
|
@@ -124,6 +124,10 @@ string TensorHandle::PackedTensorHandleData::DebugString() const {
   return debug_str;
 }
 
+int TensorHandle::PackedTensorHandleData::NumPackedHandles() const {
+  return handles_.size();
+}
+
 Status TensorHandle::PackedTensorHandleData::ExtractPackedHandle(
     const int index, TensorHandle** handle) const {
   if (index < 0 || index >= handles_.size()) {
@@ -185,6 +189,13 @@ Status TensorHandle::GetResourceAllowedDevices(std::vector<string>* result) {
   return GetResourceHandleInfoImpl(get_resource_info);
 }
 
+int TensorHandle::NumPackedHandles() const {
+  if (Type() != PACKED) {
+    return 0;
+  }
+  return absl::get<PackedTensorHandleData>(data_).NumPackedHandles();
+}
+
 Status TensorHandle::ExtractPackedHandle(const int index,
                                          TensorHandle** handle) const {
   if (Type() != PACKED) {
@@ -315,8 +326,8 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
       return errors::InvalidArgument(
           "CustomDevice is not supported for packing.");
     } else {
-      devices.push_back(
-          absl::get<Device*>(handle->DeviceOrHostCPU(*ctx))->name());
+      devices.push_back(handle->op_device() ? handle->op_device()->name()
+                                            : ctx->HostCPU()->name());
     }
   }
 
@@ -231,6 +231,8 @@ class TensorHandle : public AbstractTensorHandleInterface,
       std::vector<DtypeAndPartialTensorShape>* result);
   Status GetResourceAllowedDevices(std::vector<string>* result);
 
+  // Returns the number of packed handles. 0 if the handle type is not PACKED.
+  int NumPackedHandles() const;
   // It's called on a packed TensorHandle. Extract a handle with the given
   // index.
   Status ExtractPackedHandle(const int index, TensorHandle** handle) const;
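A small sketch (not part of the diff; it assumes a `packed_handle` built by CreatePackedHandle, as in the updated test further down) of walking a packed handle with the two accessors added here:

    for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
      TensorHandle* component = nullptr;
      TF_ASSERT_OK(packed_handle->ExtractPackedHandle(i, &component));
      // `component` is one of the handles that was packed.
    }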
@@ -316,6 +318,8 @@ class TensorHandle : public AbstractTensorHandleInterface,
     void Poison(Status status);
     string DebugString() const;
 
+    // Number of packed handles.
+    int NumPackedHandles() const;
     // Extract a handle on the given index.
     Status ExtractPackedHandle(const int index, TensorHandle** handle) const;
 
@@ -164,6 +164,7 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
   h2->Unref();
   h3->Unref();
 
+  EXPECT_EQ(packed_handle->NumPackedHandles(), 4);
   EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
   EXPECT_EQ(packed_handle->dtype, dtype);
   TensorShape packed_shape;
@@ -185,7 +186,7 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
   const std::vector<TensorHandle::HandleType> expected_handle_types = {
       TensorHandle::LOCAL, TensorHandle::LOCAL, TensorHandle::REMOTE,
       TensorHandle::REMOTE};
-  for (int i = 0; i < 4; ++i) {
+  for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
     TensorHandle* h = nullptr;
     TF_ASSERT_OK(packed_handle->ExtractPackedHandle(i, &h));
     EXPECT_EQ(absl::get<Device*>(h->device()), ListDevices().at(i));
@@ -195,6 +195,9 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
   for (Node* n : graph->op_nodes()) {
     if (composite_device_names.find(n->assigned_device_name()) !=
         composite_device_names.end()) {
+      // TODO(b/145922293): Validate that an _Arg node assigned to a
+      // CompositeDevice should have an attribute indicating that the _Arg node
+      // represents a packed input.
       composite_device_to_cluster_nodes[n->assigned_device_name()].push_back(n);
     }
   }
@@ -728,7 +728,9 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) {
   core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
   const int64 op_id = 2;
   kernel.reset(new KernelAndDeviceFunc(
-      flr, eager_pflr_.get(), std::move(input_dev_ptrs), {}, /*runner=*/nullptr,
+      flr, eager_pflr_.get(), std::move(input_dev_ptrs),
+      /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
+      /*runner=*/nullptr,
       /*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
       [ctx](const int64 step_id) { return ctx->CreateRendezvous(step_id); },
       [=]() { return op_id; }));
@@ -773,7 +775,9 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) {
   core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
   const int64 op_id = 2;
   kernel.reset(new KernelAndDeviceFunc(
-      flr, eager_pflr_.get(), std::move(input_dev_ptrs), {}, /*runner=*/nullptr,
+      flr, eager_pflr_.get(), std::move(input_dev_ptrs),
+      /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
+      /*runner=*/nullptr,
      /*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
      [ctx](const int64 step_id) { return ctx->CreateRendezvous(step_id); },
      [=]() { return op_id; }));