Support TensorProtos as Operation inputs, in order to support remote inputs passed as Tensors to EagerClusterFunctionLibraryRuntime::Run.

PiperOrigin-RevId: 299972015
Change-Id: I73f4da2aec63e64d32dd16a48a7730a362e64543
This commit is contained in:
Yujing Zhang 2020-03-09 17:07:02 -07:00 committed by TensorFlower Gardener
parent de037760b7
commit 2031f7aeeb
6 changed files with 102 additions and 34 deletions

View File

@ -112,6 +112,7 @@ cc_library(
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:process_function_library_runtime",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime:message_wrappers",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:worker_cache",
@ -149,6 +150,7 @@ tf_cc_test(
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
],
)

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
@ -359,6 +360,34 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
{
profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal",
profiler::TraceMeLevel::kVerbose);
if (!operation.op_inputs().empty() && !operation.inputs().empty()) {
return errors::InvalidArgument(
"Both operation.inputs and operation.op_inputs are specified in the "
"same request.");
}
for (const auto& input : operation.op_inputs()) {
tensorflow::TensorHandle* handle;
if (input.has_remote_handle()) {
TF_RETURN_IF_ERROR(
eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
input.remote_handle(), &handle));
op->AddInput(handle);
} else {
Tensor tensor;
if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
return errors::InvalidArgument("Invalid TensorProto: ",
input.tensor().DebugString());
} else {
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
std::move(tensor), nullptr, nullptr, eager_context, &handle));
op->AddInput(handle);
}
}
// Unref handle since it has a ref as an input now.
handle->Unref();
}
// TODO(b/150963957): Remove this once the migration from operation.inputs
// to operation.op_inputs completes.
for (const auto& remote_handle : operation.inputs()) {
tensorflow::TensorHandle* handle;
TF_RETURN_IF_ERROR(

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <memory>
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
@ -172,7 +173,8 @@ void SetTensorProto(TensorProto* tensor_proto) {
void AddOperationToEnqueueRequest(
int64 id, const string& name,
const std::vector<std::pair<int64, int32>>& inputs,
const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
inputs,
const std::unordered_map<string, AttrValue>& attrs, const string& device,
EnqueueRequest* request) {
auto* operation = request->add_queue()->mutable_operation();
@ -181,12 +183,19 @@ void AddOperationToEnqueueRequest(
operation->set_name(name);
operation->set_device(device);
for (const auto& tensor_handle_pair : inputs) {
auto* input = operation->add_inputs();
input->set_op_id(tensor_handle_pair.first);
input->set_output_num(tensor_handle_pair.second);
input->set_op_device(device);
input->set_device(device);
for (const auto& input : inputs) {
if (input.index() == 0) {
*operation->add_op_inputs()->mutable_tensor() =
absl::get<TensorProto>(input);
} else {
const auto& tensor_handle_pair =
absl::get<std::pair<int64, int32>>(input);
auto* input = operation->add_op_inputs()->mutable_remote_handle();
input->set_op_id(tensor_handle_pair.first);
input->set_output_num(tensor_handle_pair.second);
input->set_op_device(device);
input->set_device(device);
}
}
for (const auto& attr_entry : attrs) {
@ -323,9 +332,9 @@ TEST_F(EagerServiceImplTest, BasicTest) {
attrs.insert({"transpose_a", val});
attrs.insert({"transpose_b", val});
AddOperationToEnqueueRequest(2, "MatMul", {{1, 0}, {1, 0}}, attrs,
"/job:localhost/replica:0/task:0/device:CPU:0",
&remote_enqueue_request);
AddOperationToEnqueueRequest(
2, "MatMul", {std::make_pair(1, 0), std::make_pair(1, 0)}, attrs,
"/job:localhost/replica:0/task:0/device:CPU:0", &remote_enqueue_request);
TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
&remote_enqueue_response));
@ -367,7 +376,8 @@ class EagerServiceImplFunctionTest : public EagerServiceImplTest {
// Creates a context and attempts to execute a function.
void TestFunction(const RegisterFunctionOp& register_op,
const string& function_name) {
const string& function_name,
const bool local_inputs = false) {
TestEagerServiceImpl eager_service_impl(&worker_env_);
uint64 context_id = random::New64();
@ -392,22 +402,35 @@ class EagerServiceImplFunctionTest : public EagerServiceImplTest {
remote_enqueue_request.set_context_id(context_id);
EnqueueResponse remote_enqueue_response;
std::unordered_map<string, AttrValue> const_attrs;
AttrValue val;
val.set_type(tensorflow::DataType::DT_FLOAT);
const_attrs.insert({"dtype", val});
val.Clear();
if (local_inputs) {
TensorProto tensor_proto;
SetTensorProto(&tensor_proto);
AddOperationToEnqueueRequest(
2, function_name, {tensor_proto},
std::unordered_map<string, AttrValue>(),
"/job:localhost/replica:0/task:0/device:CPU:0",
&remote_enqueue_request);
SetTensorProto(val.mutable_tensor());
const_attrs.insert({"value", val});
} else {
std::unordered_map<string, AttrValue> const_attrs;
AttrValue val;
val.set_type(tensorflow::DataType::DT_FLOAT);
const_attrs.insert({"dtype", val});
val.Clear();
AddOperationToEnqueueRequest(1, "Const", {}, const_attrs,
"/job:localhost/replica:0/task:0/device:CPU:0",
&remote_enqueue_request);
AddOperationToEnqueueRequest(2, function_name, {{1, 0}},
std::unordered_map<string, AttrValue>(),
"/job:localhost/replica:0/task:0/device:CPU:0",
&remote_enqueue_request);
SetTensorProto(val.mutable_tensor());
const_attrs.insert({"value", val});
AddOperationToEnqueueRequest(
1, "Const", {}, const_attrs,
"/job:localhost/replica:0/task:0/device:CPU:0",
&remote_enqueue_request);
AddOperationToEnqueueRequest(
2, function_name, {std::make_pair(1, 0)},
std::unordered_map<string, AttrValue>(),
"/job:localhost/replica:0/task:0/device:CPU:0",
&remote_enqueue_request);
}
TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
&remote_enqueue_response));
@ -441,6 +464,12 @@ TEST_F(EagerServiceImplFunctionTest, BasicFunctionTest) {
TestFunction(register_op, "MatMulFunction");
}
TEST_F(EagerServiceImplFunctionTest, FunctionWithLocalInputsTest) {
RegisterFunctionOp register_op;
*register_op.mutable_function_def() = MatMulFunction();
TestFunction(register_op, "MatMulFunction", /*local_inputs=*/true);
}
TEST_F(EagerServiceImplFunctionTest, NestedFunctionTest) {
RegisterFunctionOp register_op;
*register_op.mutable_function_def() = MatMulNestedFunction();
@ -526,8 +555,8 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
fdef_ = MatMulFunction();
TF_ASSERT_OK(func_lib_def_.AddFunctionDef(fdef_));
eager_pflr_ = absl::make_unique<EagerProcessFunctionLibraryRuntime>(
remote_device_mgr_.get(), Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(),
remote_device_mgr_.get(), Env::Default(), /*config=*/
nullptr, TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(),
/*thread_pool=*/nullptr, eager_cluster_flr_.get());
}
@ -699,9 +728,9 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
attrs.insert({"transpose_a", val});
attrs.insert({"transpose_b", val});
AddOperationToEnqueueRequest(2, "MatMul", {{1, 0}, {1, 0}}, attrs,
"/job:localhost/replica:0/task:0/device:CPU:0",
&remote_enqueue_request);
AddOperationToEnqueueRequest(
2, "MatMul", {std::make_pair(1, 0), std::make_pair(1, 0)}, attrs,
"/job:localhost/replica:0/task:0/device:CPU:0", &remote_enqueue_request);
TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
&remote_enqueue_response));

View File

@ -23,8 +23,6 @@ limitations under the License.
namespace tensorflow {
namespace {
bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
Tensor* out_tensor) {
if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
@ -37,8 +35,6 @@ bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
return false;
}
} // namespace
const string& InMemoryRunStepRequest::session_handle() const {
return session_handle_;
}

View File

@ -737,6 +737,9 @@ class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
RunStepResponse* response_; // Not owned.
};
bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
Tensor* out_tensor);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_

View File

@ -25,6 +25,15 @@ message Operation {
string name = 2;
repeated RemoteTensorHandle inputs = 3;
message Input {
oneof item {
RemoteTensorHandle remote_handle = 1;
TensorProto tensor = 2;
}
}
repeated Input op_inputs = 10;
// Control Operation IDs that will be respected when ops are re-ordered by
// async execution. If async execution (+ op re-ordering) is not enabled, this
// should have no effect.