diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index a025186dd81..ec880276e97 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -357,32 +357,10 @@ Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx, return Status::OK(); } -// There are a lot of references to devices in this function and around. -// Here is what they mean: -// EagerOperation::Device(): The device on which the user requested the op -// be executed, except if we had to change the device due to resource inputs -// or CPU pinning. If the user did not request a device, the op does not -// take resources, and we did not pin it to CPU, the device can be nullptr. -// KernelAndDevice::Device(): The first time we see an op (combined with -// its attributes), we need to create a KernelAndDevice object for it. -// If op->Device() is a nullptr, we select a device for the op when -// creating the KernelAndDevice. A concrete device will always be selected -// here except when `op` is a function to be executed using function library -// runtime. In this case, we don't select a device because running -// a function with explicitly requested device has different behavior than -// running without an explicitly requested device. -Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, - int* num_retvals) { - ScopedMemoryDebugAnnotation op_annotation( - op->op_name(), op->remote_func_params().has_value() - ? op->remote_func_params().value().step_id.value_or(0) - : 0); - profiler::TraceMe activity( - [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); }, - profiler::TraceMeLevel::kInfo); +Status GetOrCreateKernelAndDevice( + EagerOperation* op, TensorHandle** retvals, int* num_retvals, + core::RefCountPtr* out_kernel) { EagerContext& ctx = op->EagerContext(); - auto& executor = op->Executor(); - TF_RETURN_IF_ERROR(executor.status()); Device* device = absl::get(op->Device()); Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName()); @@ -416,9 +394,10 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, TensorHandle* input = op->Inputs()[i]; if (!ctx.LazyCopyFunctionRemoteInputs() && input->IsRemote()) { TensorHandle* handle = nullptr; - TF_RETURN_IF_ERROR(EagerCopyToDevice( - input, &ctx, &executor, device == nullptr ? ctx.HostCPU() : device, - /* mirror= */ true, &handle)); + TF_RETURN_IF_ERROR( + EagerCopyToDevice(input, &ctx, &op->Executor(), + device == nullptr ? ctx.HostCPU() : device, + /*mirror=*/true, &handle)); op->UpdateInput(i, handle); // Unref handle since it has a ref as an input now handle->Unref(); @@ -569,6 +548,42 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, } } } + kernel->Ref(); // Ownership of reference is passed to out_kernel. + out_kernel->reset(kernel.get()); + return Status::OK(); +} + +// There are a lot of references to devices in this function and around. +// Here is what they mean: +// EagerOperation::Device(): The device on which the user requested the op +// be executed, except if we had to change the device due to resource inputs +// or CPU pinning. If the user did not request a device, the op does not +// take resources, and we did not pin it to CPU, the device can be nullptr. +// KernelAndDevice::Device(): The first time we see an op (combined with +// its attributes), we need to create a KernelAndDevice object for it. +// If op->Device() is a nullptr, we select a device for the op when +// creating the KernelAndDevice. A concrete device will always be selected +// here except when `op` is a function to be executed using function library +// runtime. In this case, we don't select a device because running +// a function with explicitly requested device has different behavior than +// running without an explicitly requested device. +Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, + int* num_retvals) { + ScopedMemoryDebugAnnotation op_annotation( + op->op_name(), op->remote_func_params().has_value() + ? op->remote_func_params().value().step_id.value_or(0) + : 0); + profiler::TraceMe activity( + [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); }, + profiler::TraceMeLevel::kInfo); + EagerContext& ctx = op->EagerContext(); + auto& executor = op->Executor(); + TF_RETURN_IF_ERROR(executor.status()); + + core::RefCountPtr kernel; + TF_RETURN_IF_ERROR( + GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel)); + int num_outputs = kernel->num_outputs(); if (num_outputs > *num_retvals) { return errors::InvalidArgument("Expecting ", num_outputs, @@ -986,6 +1001,61 @@ Status MaybeUpdateOpDevice(EagerOperation* op) { return Status::OK(); } + +Status GetKernelOutputs(std::vector* outputs, int num_outputs, + TensorHandle** retvals, EagerContext* ctx, + KernelAndDevice* kernel) { + for (int i = 0; i < num_outputs; ++i) { + if (retvals[i] == nullptr) { + retvals[i] = TensorHandle::CreateLocalHandle( + std::move((*outputs)[i]), + /* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)), + /* op_device= */ kernel->device(), + /* resource_device= */ kernel->OutputResourceDevice(i), ctx); + } else { + if (TF_PREDICT_FALSE(kernel->device() != retvals[i]->op_device())) { + return errors::Internal( + "Kernel output tensor handle has a different op device than the " + "kernel. This should never happen."); + } + if (TF_PREDICT_FALSE(ctx->CanonicalDevice(kernel->OutputDevice(i)) != + absl::get(retvals[i]->device()))) { + return errors::Internal( + "Kernel output tensor handle locates on a different device than " + "the specified kernel output device. This should never happen."); + } + + TF_RETURN_IF_ERROR( + retvals[i]->SetTensor(std::move((*outputs)[i]), + ctx->CanonicalDevice(kernel->OutputDevice(i)))); + } + } + return Status::OK(); +} + +void CollectGraphs(EagerContext* ctx) { + mutex_lock ml(*ctx->MetadataMu()); + + GraphCollector* collector = ctx->GetGraphCollector(); + mutex_lock mll(collector->mu); + + // Adding to partition graphs for backward compatibility. + for (const auto& graph : collector->partitioned_graphs) { + *ctx->RunMetadataProto()->add_partition_graphs() = graph; + } + + if (collector->dirty) { + auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs(); + *function_graphs->mutable_post_optimization_graph() = + collector->optimized_graph; + *function_graphs->mutable_pre_optimization_graph() = collector->raw_graph; + for (const auto& graph : collector->partitioned_graphs) { + *function_graphs->add_partition_graphs() = graph; + } + } + + collector->ClearGraphs(); +} } // namespace Status EagerExecute(EagerOperation* op, TensorHandle** retvals, @@ -1061,50 +1131,18 @@ Status EagerKernelExecute( TF_RETURN_IF_ERROR(kernel->Run(container, inputs, &outputs, cancellation_manager, remote_func_params)); if (graph_collector != nullptr) { - mutex_lock ml(*ctx->MetadataMu()); - { - GraphCollector* collector = ctx->GetGraphCollector(); - mutex_lock mll(collector->mu); - - // Adding to partition graphs for backward compatibility. - for (const auto& graph : collector->partitioned_graphs) { - *ctx->RunMetadataProto()->add_partition_graphs() = graph; - } - - if (collector->dirty) { - auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs(); - *function_graphs->mutable_post_optimization_graph() = - collector->optimized_graph; - *function_graphs->mutable_pre_optimization_graph() = - collector->raw_graph; - for (const auto& graph : collector->partitioned_graphs) { - *function_graphs->add_partition_graphs() = graph; - } - } - - collector->ClearGraphs(); - } + CollectGraphs(ctx); } - DCHECK_EQ(retvals.size(), outputs.size()); - for (int i = 0; i < retvals.size(); ++i) { - if (retvals[i] == nullptr) { - retvals[i] = TensorHandle::CreateLocalHandle( - std::move(outputs[i]), - /* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)), - /* op_device= */ kernel->device(), - /* resource_device= */ kernel->OutputResourceDevice(i), ctx); - } else { - DCHECK_EQ(kernel->device(), retvals[i]->op_device()); - DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)), - absl::get(retvals[i]->device())); - - TF_RETURN_IF_ERROR( - retvals[i]->SetTensor(std::move(outputs[i]), - ctx->CanonicalDevice(kernel->OutputDevice(i)))); - } + if (TF_PREDICT_FALSE(retvals.size() != outputs.size())) { + return errors::Internal( + "EagerKernelExecute returns a list of ", outputs.size(), + " tensors but ", retvals.size(), + " is expected. This should never " + "happen. Please file a bug with the TensorFlow team."); } - return Status::OK(); + return GetKernelOutputs(&outputs, retvals.size(), retvals.data(), ctx, + kernel.get()); } namespace { diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 79a176efbca..d4cd11e12ad 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -44,8 +44,10 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/protobuf/error_codes.pb.h" @@ -89,6 +91,94 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name, return Status::OK(); } + +Status GetEagerOperation(const Operation& operation, + EagerContext* eager_context, + EagerExecutor* eager_executor, + EagerOperation* eager_op) { + const char* name = operation.name().c_str(); // Shorthand + absl::optional remote_func_params = + absl::nullopt; + if (operation.is_function()) { + if (operation.is_component_function()) { + remote_func_params = {operation.id(), operation.func_step_id()}; + } else { + remote_func_params = {operation.id(), absl::nullopt}; + } + } + TF_RETURN_IF_ERROR(eager_op->Reset(name, operation.device().c_str(), false, + eager_executor, remote_func_params)); + + { + profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal", + profiler::TraceMeLevel::kVerbose); + for (const auto& input : operation.op_inputs()) { + tensorflow::TensorHandle* handle; + if (input.has_remote_handle()) { + TF_RETURN_IF_ERROR( + eager_context->RemoteMgr()->DeserializeRemoteTensorHandle( + input.remote_handle(), &handle)); + TF_RETURN_IF_ERROR(eager_op->AddInput(handle)); + } else { + Tensor tensor; + if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) { + return errors::InvalidArgument("Invalid TensorProto: ", + input.tensor().DebugString()); + } else { + handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr, + nullptr, eager_context); + TF_RETURN_IF_ERROR(eager_op->AddInput(handle)); + } + } + // Unref handle since it has a ref as an input now. + handle->Unref(); + } + } + + for (const auto& attr : operation.attrs()) { + eager_op->MutableAttrs()->Set(attr.first, attr.second); + } + return Status::OK(); +} + +Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) { + const tensorflow::Tensor* t = nullptr; + TF_RETURN_IF_ERROR(handle->Tensor(&t)); + t->AsProtoTensorContent(proto); + return Status::OK(); +} + +Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) { + const tensorflow::Tensor* t = nullptr; + + // TODO(nareshmodi): This call makes async calls sync calls. Fix this. + TF_RETURN_IF_ERROR(handle->Tensor(&t)); + + t->shape().AsProto(proto); + + return Status::OK(); +} + +Status AddOpRetvalsToResponse( + EagerContext* eager_context, int op_id, int num_retvals, + TensorHandle** retvals, std::function add_tensor_proto_fn, + std::function add_shape_proto_fn) { + if (op_id == kInvalidRemoteOpId) { + // Copy the output tensors back along with the response, since the op id + // is invalid which cannot be added to RemoteMgr. + for (int i = 0; i < num_retvals; i++) { + TF_RETURN_IF_ERROR(TensorHandleProto(retvals[i], add_tensor_proto_fn())); + retvals[i]->Unref(); + } + } else { + eager_context->RemoteMgr()->AddOperationOutputs( + absl::MakeSpan(retvals, num_retvals), op_id); + for (int i = 0; i < num_retvals; i++) { + TF_RETURN_IF_ERROR(TensorHandleShape(retvals[i], add_shape_proto_fn())); + } + } + return Status::OK(); +} } // namespace Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, @@ -316,72 +406,13 @@ Status EagerServiceImpl::CreateMasterContext( return Status::OK(); } -Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) { - const tensorflow::Tensor* t = nullptr; - TF_RETURN_IF_ERROR(handle->Tensor(&t)); - t->AsProtoTensorContent(proto); - return Status::OK(); -} - -Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) { - const tensorflow::Tensor* t = nullptr; - - // TODO(nareshmodi): This call makes async calls sync calls. Fix this. - TF_RETURN_IF_ERROR(handle->Tensor(&t)); - - t->shape().AsProto(proto); - - return Status::OK(); -} - Status EagerServiceImpl::ExecuteOp(const Operation& operation, EagerContext* eager_context, EagerExecutor* eager_executor, QueueResponse* queue_response) { - std::unique_ptr op; - const char* name = operation.name().c_str(); // Shorthand - absl::optional remote_func_params = - absl::nullopt; - if (operation.is_function()) { - if (operation.is_component_function()) { - remote_func_params = {operation.id(), operation.func_step_id()}; - } else { - remote_func_params = {operation.id(), absl::nullopt}; - } - } - op.reset(new tensorflow::EagerOperation(eager_context)); - TF_RETURN_IF_ERROR(op->Reset(name, operation.device().c_str(), false, - eager_executor, remote_func_params)); - - { - profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal", - profiler::TraceMeLevel::kVerbose); - for (const auto& input : operation.op_inputs()) { - tensorflow::TensorHandle* handle; - if (input.has_remote_handle()) { - TF_RETURN_IF_ERROR( - eager_context->RemoteMgr()->DeserializeRemoteTensorHandle( - input.remote_handle(), &handle)); - TF_RETURN_IF_ERROR(op->AddInput(handle)); - } else { - Tensor tensor; - if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) { - return errors::InvalidArgument("Invalid TensorProto: ", - input.tensor().DebugString()); - } else { - handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr, - nullptr, eager_context); - TF_RETURN_IF_ERROR(op->AddInput(handle)); - } - } - // Unref handle since it has a ref as an input now. - handle->Unref(); - } - } - - for (const auto& attr : operation.attrs()) { - op->MutableAttrs()->Set(attr.first, attr.second); - } + tensorflow::EagerOperation op(eager_context); + TF_RETURN_IF_ERROR( + GetEagerOperation(operation, eager_context, eager_executor, &op)); int num_retvals = 0; // TODO(nareshmodi): Consider caching this. @@ -390,26 +421,12 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation, absl::FixedArray retvals(num_retvals); VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id(); - TF_RETURN_IF_ERROR(EagerExecute(op.get(), retvals.data(), &num_retvals)); + TF_RETURN_IF_ERROR(EagerExecute(&op, retvals.data(), &num_retvals)); - if (operation.id() == kInvalidRemoteOpId) { - // Copy the output tensors back along with the response, since the op id - // is invalid which cannot be added to RemoteMgr. - for (int i = 0; i < num_retvals; i++) { - TF_RETURN_IF_ERROR( - TensorHandleProto(retvals[i], queue_response->add_tensor())); - retvals[i]->Unref(); - } - } else { - eager_context->RemoteMgr()->AddOperationOutputs( - absl::MakeSpan(retvals.data(), num_retvals), operation.id()); - for (int i = 0; i < num_retvals; i++) { - TF_RETURN_IF_ERROR( - TensorHandleShape(retvals[i], queue_response->add_shape())); - } - } - - return Status::OK(); + return AddOpRetvalsToResponse( + eager_context, operation.id(), num_retvals, retvals.data(), + [queue_response] { return queue_response->add_tensor(); }, + [queue_response] { return queue_response->add_shape(); }); } Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,