Support TensorProtos as Operation inputs, in order to support remote inputs passed as Tensors to EagerClusterFunctionLibraryRuntime::Run.
PiperOrigin-RevId: 299972015 Change-Id: I73f4da2aec63e64d32dd16a48a7730a362e64543
This commit is contained in:
parent
de037760b7
commit
2031f7aeeb
tensorflow/core
distributed_runtime
protobuf
@ -112,6 +112,7 @@ cc_library(
|
||||
"//tensorflow/core/common_runtime/eager:execute",
|
||||
"//tensorflow/core/common_runtime/eager:process_function_library_runtime",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"//tensorflow/core/distributed_runtime:message_wrappers",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/core/distributed_runtime:session_mgr",
|
||||
"//tensorflow/core/distributed_runtime:worker_cache",
|
||||
@ -149,6 +150,7 @@ tf_cc_test(
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
|
||||
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/session_mgr.h"
|
||||
@ -359,6 +360,34 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
|
||||
{
|
||||
profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal",
|
||||
profiler::TraceMeLevel::kVerbose);
|
||||
if (!operation.op_inputs().empty() && !operation.inputs().empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Both operation.inputs and operation.op_inputs are specified in the "
|
||||
"same request.");
|
||||
}
|
||||
for (const auto& input : operation.op_inputs()) {
|
||||
tensorflow::TensorHandle* handle;
|
||||
if (input.has_remote_handle()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
|
||||
input.remote_handle(), &handle));
|
||||
op->AddInput(handle);
|
||||
} else {
|
||||
Tensor tensor;
|
||||
if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
|
||||
return errors::InvalidArgument("Invalid TensorProto: ",
|
||||
input.tensor().DebugString());
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
|
||||
std::move(tensor), nullptr, nullptr, eager_context, &handle));
|
||||
op->AddInput(handle);
|
||||
}
|
||||
}
|
||||
// Unref handle since it has a ref as an input now.
|
||||
handle->Unref();
|
||||
}
|
||||
// TODO(b/150963957): Remove this once the migration from operation.inputs
|
||||
// to operation.op_inputs completes.
|
||||
for (const auto& remote_handle : operation.inputs()) {
|
||||
tensorflow::TensorHandle* handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
|
||||
@ -172,7 +173,8 @@ void SetTensorProto(TensorProto* tensor_proto) {
|
||||
|
||||
void AddOperationToEnqueueRequest(
|
||||
int64 id, const string& name,
|
||||
const std::vector<std::pair<int64, int32>>& inputs,
|
||||
const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
|
||||
inputs,
|
||||
const std::unordered_map<string, AttrValue>& attrs, const string& device,
|
||||
EnqueueRequest* request) {
|
||||
auto* operation = request->add_queue()->mutable_operation();
|
||||
@ -181,12 +183,19 @@ void AddOperationToEnqueueRequest(
|
||||
operation->set_name(name);
|
||||
operation->set_device(device);
|
||||
|
||||
for (const auto& tensor_handle_pair : inputs) {
|
||||
auto* input = operation->add_inputs();
|
||||
input->set_op_id(tensor_handle_pair.first);
|
||||
input->set_output_num(tensor_handle_pair.second);
|
||||
input->set_op_device(device);
|
||||
input->set_device(device);
|
||||
for (const auto& input : inputs) {
|
||||
if (input.index() == 0) {
|
||||
*operation->add_op_inputs()->mutable_tensor() =
|
||||
absl::get<TensorProto>(input);
|
||||
} else {
|
||||
const auto& tensor_handle_pair =
|
||||
absl::get<std::pair<int64, int32>>(input);
|
||||
auto* input = operation->add_op_inputs()->mutable_remote_handle();
|
||||
input->set_op_id(tensor_handle_pair.first);
|
||||
input->set_output_num(tensor_handle_pair.second);
|
||||
input->set_op_device(device);
|
||||
input->set_device(device);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& attr_entry : attrs) {
|
||||
@ -323,9 +332,9 @@ TEST_F(EagerServiceImplTest, BasicTest) {
|
||||
attrs.insert({"transpose_a", val});
|
||||
attrs.insert({"transpose_b", val});
|
||||
|
||||
AddOperationToEnqueueRequest(2, "MatMul", {{1, 0}, {1, 0}}, attrs,
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
&remote_enqueue_request);
|
||||
AddOperationToEnqueueRequest(
|
||||
2, "MatMul", {std::make_pair(1, 0), std::make_pair(1, 0)}, attrs,
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0", &remote_enqueue_request);
|
||||
|
||||
TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
|
||||
&remote_enqueue_response));
|
||||
@ -367,7 +376,8 @@ class EagerServiceImplFunctionTest : public EagerServiceImplTest {
|
||||
|
||||
// Creates a context and attempts to execute a function.
|
||||
void TestFunction(const RegisterFunctionOp& register_op,
|
||||
const string& function_name) {
|
||||
const string& function_name,
|
||||
const bool local_inputs = false) {
|
||||
TestEagerServiceImpl eager_service_impl(&worker_env_);
|
||||
|
||||
uint64 context_id = random::New64();
|
||||
@ -392,22 +402,35 @@ class EagerServiceImplFunctionTest : public EagerServiceImplTest {
|
||||
remote_enqueue_request.set_context_id(context_id);
|
||||
EnqueueResponse remote_enqueue_response;
|
||||
|
||||
std::unordered_map<string, AttrValue> const_attrs;
|
||||
AttrValue val;
|
||||
val.set_type(tensorflow::DataType::DT_FLOAT);
|
||||
const_attrs.insert({"dtype", val});
|
||||
val.Clear();
|
||||
if (local_inputs) {
|
||||
TensorProto tensor_proto;
|
||||
SetTensorProto(&tensor_proto);
|
||||
AddOperationToEnqueueRequest(
|
||||
2, function_name, {tensor_proto},
|
||||
std::unordered_map<string, AttrValue>(),
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
&remote_enqueue_request);
|
||||
|
||||
SetTensorProto(val.mutable_tensor());
|
||||
const_attrs.insert({"value", val});
|
||||
} else {
|
||||
std::unordered_map<string, AttrValue> const_attrs;
|
||||
AttrValue val;
|
||||
val.set_type(tensorflow::DataType::DT_FLOAT);
|
||||
const_attrs.insert({"dtype", val});
|
||||
val.Clear();
|
||||
|
||||
AddOperationToEnqueueRequest(1, "Const", {}, const_attrs,
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
&remote_enqueue_request);
|
||||
AddOperationToEnqueueRequest(2, function_name, {{1, 0}},
|
||||
std::unordered_map<string, AttrValue>(),
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
&remote_enqueue_request);
|
||||
SetTensorProto(val.mutable_tensor());
|
||||
const_attrs.insert({"value", val});
|
||||
|
||||
AddOperationToEnqueueRequest(
|
||||
1, "Const", {}, const_attrs,
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
&remote_enqueue_request);
|
||||
AddOperationToEnqueueRequest(
|
||||
2, function_name, {std::make_pair(1, 0)},
|
||||
std::unordered_map<string, AttrValue>(),
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
&remote_enqueue_request);
|
||||
}
|
||||
|
||||
TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
|
||||
&remote_enqueue_response));
|
||||
@ -441,6 +464,12 @@ TEST_F(EagerServiceImplFunctionTest, BasicFunctionTest) {
|
||||
TestFunction(register_op, "MatMulFunction");
|
||||
}
|
||||
|
||||
TEST_F(EagerServiceImplFunctionTest, FunctionWithLocalInputsTest) {
|
||||
RegisterFunctionOp register_op;
|
||||
*register_op.mutable_function_def() = MatMulFunction();
|
||||
TestFunction(register_op, "MatMulFunction", /*local_inputs=*/true);
|
||||
}
|
||||
|
||||
TEST_F(EagerServiceImplFunctionTest, NestedFunctionTest) {
|
||||
RegisterFunctionOp register_op;
|
||||
*register_op.mutable_function_def() = MatMulNestedFunction();
|
||||
@ -526,8 +555,8 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
|
||||
fdef_ = MatMulFunction();
|
||||
TF_ASSERT_OK(func_lib_def_.AddFunctionDef(fdef_));
|
||||
eager_pflr_ = absl::make_unique<EagerProcessFunctionLibraryRuntime>(
|
||||
remote_device_mgr_.get(), Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(),
|
||||
remote_device_mgr_.get(), Env::Default(), /*config=*/
|
||||
nullptr, TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(),
|
||||
/*thread_pool=*/nullptr, eager_cluster_flr_.get());
|
||||
}
|
||||
|
||||
@ -699,9 +728,9 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
|
||||
attrs.insert({"transpose_a", val});
|
||||
attrs.insert({"transpose_b", val});
|
||||
|
||||
AddOperationToEnqueueRequest(2, "MatMul", {{1, 0}, {1, 0}}, attrs,
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
&remote_enqueue_request);
|
||||
AddOperationToEnqueueRequest(
|
||||
2, "MatMul", {std::make_pair(1, 0), std::make_pair(1, 0)}, attrs,
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0", &remote_enqueue_request);
|
||||
|
||||
TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
|
||||
&remote_enqueue_response));
|
||||
|
@ -23,8 +23,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
|
||||
Tensor* out_tensor) {
|
||||
if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
|
||||
@ -37,8 +35,6 @@ bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
const string& InMemoryRunStepRequest::session_handle() const {
|
||||
return session_handle_;
|
||||
}
|
||||
|
@ -737,6 +737,9 @@ class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
|
||||
RunStepResponse* response_; // Not owned.
|
||||
};
|
||||
|
||||
bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
|
||||
Tensor* out_tensor);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
|
||||
|
@ -25,6 +25,15 @@ message Operation {
|
||||
string name = 2;
|
||||
repeated RemoteTensorHandle inputs = 3;
|
||||
|
||||
message Input {
|
||||
oneof item {
|
||||
RemoteTensorHandle remote_handle = 1;
|
||||
TensorProto tensor = 2;
|
||||
}
|
||||
}
|
||||
|
||||
repeated Input op_inputs = 10;
|
||||
|
||||
// Control Operation IDs that will be respected when ops are re-ordered by
|
||||
// async execution. If async execution (+ op re-ordering) is not enabled, this
|
||||
// should have no effect.
|
||||
|
Loading…
Reference in New Issue
Block a user