Non-functional change in preparation for introducing non-blocking function execution.

Modularize components of eager execute and the eager service implementation so they can be shared and reused in the upcoming async code path. Condensed sketches of the refactored entry points accompany each file's diff below.

PiperOrigin-RevId: 308347746
Change-Id: Ida9cca1a1a88d3e6509c61950f4eaa4f18dbe864
Haoyu Zhang 2020-04-24 16:21:59 -07:00 committed by TensorFlower Gardener
parent 5f543e0fa1
commit 47995cbaf7
2 changed files with 205 additions and 150 deletions

tensorflow/core/common_runtime/eager/execute.cc
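
Condensed view of this file's refactoring, assembled from the hunks below (a sketch, not a verbatim excerpt): EagerLocalExecute keeps its signature, but kernel selection and creation move into the new GetOrCreateKernelAndDevice helper, while output handling and RunMetadata collection move into GetKernelOutputs and CollectGraphs, so the upcoming async path can invoke each stage on its own.

Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
                         int* num_retvals) {
  // Device selection, remote-input copies, and kernel caching now
  // live in one reusable helper.
  core::RefCountPtr<KernelAndDevice> kernel;
  TF_RETURN_IF_ERROR(
      GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel));
  // ... output-count validation and the EagerKernelExecute call
  // follow; EagerKernelExecute in turn delegates to CollectGraphs
  // (RunMetadata bookkeeping) and GetKernelOutputs (populating
  // retvals), as shown in the later hunks ...
}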

@@ -357,32 +357,10 @@ Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx,
return Status::OK();
}
// There are a lot of references to devices in this function and around.
// Here is what they mean:
// EagerOperation::Device(): The device on which the user requested the op
// be executed, except if we had to change the device due to resource inputs
// or CPU pinning. If the user did not request a device, the op does not
// take resources, and we did not pin it to CPU, the device can be nullptr.
// KernelAndDevice::Device(): The first time we see an op (combined with
// its attributes), we need to create a KernelAndDevice object for it.
// If op->Device() is a nullptr, we select a device for the op when
// creating the KernelAndDevice. A concrete device will always be selected
// here except when `op` is a function to be executed using function library
// runtime. In this case, we don't select a device because running
// a function with explicitly requested device has different behavior than
// running without an explicitly requested device.
Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
int* num_retvals) {
ScopedMemoryDebugAnnotation op_annotation(
op->op_name(), op->remote_func_params().has_value()
? op->remote_func_params().value().step_id.value_or(0)
: 0);
profiler::TraceMe activity(
[&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
profiler::TraceMeLevel::kInfo);
Status GetOrCreateKernelAndDevice(
EagerOperation* op, TensorHandle** retvals, int* num_retvals,
core::RefCountPtr<KernelAndDevice>* out_kernel) {
EagerContext& ctx = op->EagerContext();
auto& executor = op->Executor();
TF_RETURN_IF_ERROR(executor.status());
Device* device = absl::get<Device*>(op->Device());
Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
@@ -416,9 +394,10 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
TensorHandle* input = op->Inputs()[i];
if (!ctx.LazyCopyFunctionRemoteInputs() && input->IsRemote()) {
TensorHandle* handle = nullptr;
TF_RETURN_IF_ERROR(EagerCopyToDevice(
input, &ctx, &executor, device == nullptr ? ctx.HostCPU() : device,
/* mirror= */ true, &handle));
TF_RETURN_IF_ERROR(
EagerCopyToDevice(input, &ctx, &op->Executor(),
device == nullptr ? ctx.HostCPU() : device,
/*mirror=*/true, &handle));
op->UpdateInput(i, handle);
// Unref handle since it has a ref as an input now
handle->Unref();
@@ -569,6 +548,42 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
}
}
}
kernel->Ref(); // Ownership of reference is passed to out_kernel.
out_kernel->reset(kernel.get());
return Status::OK();
}
// There are a lot of references to devices in this function and around.
// Here is what they mean:
// EagerOperation::Device(): The device on which the user requested the op
// be executed, except if we had to change the device due to resource inputs
// or CPU pinning. If the user did not request a device, the op does not
// take resources, and we did not pin it to CPU, the device can be nullptr.
// KernelAndDevice::Device(): The first time we see an op (combined with
// its attributes), we need to create a KernelAndDevice object for it.
// If op->Device() is a nullptr, we select a device for the op when
// creating the KernelAndDevice. A concrete device will always be selected
// here except when `op` is a function to be executed using function library
// runtime. In this case, we don't select a device because running
// a function with explicitly requested device has different behavior than
// running without an explicitly requested device.
Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
int* num_retvals) {
ScopedMemoryDebugAnnotation op_annotation(
op->op_name(), op->remote_func_params().has_value()
? op->remote_func_params().value().step_id.value_or(0)
: 0);
profiler::TraceMe activity(
[&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
profiler::TraceMeLevel::kInfo);
EagerContext& ctx = op->EagerContext();
auto& executor = op->Executor();
TF_RETURN_IF_ERROR(executor.status());
core::RefCountPtr<KernelAndDevice> kernel;
TF_RETURN_IF_ERROR(
GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel));
int num_outputs = kernel->num_outputs();
if (num_outputs > *num_retvals) {
return errors::InvalidArgument("Expecting ", num_outputs,
@@ -986,6 +1001,61 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
return Status::OK();
}
Status GetKernelOutputs(std::vector<Tensor>* outputs, int num_outputs,
TensorHandle** retvals, EagerContext* ctx,
KernelAndDevice* kernel) {
for (int i = 0; i < num_outputs; ++i) {
if (retvals[i] == nullptr) {
retvals[i] = TensorHandle::CreateLocalHandle(
std::move((*outputs)[i]),
/* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)),
/* op_device= */ kernel->device(),
/* resource_device= */ kernel->OutputResourceDevice(i), ctx);
} else {
if (TF_PREDICT_FALSE(kernel->device() != retvals[i]->op_device())) {
return errors::Internal(
"Kernel output tensor handle has a different op device than the "
"kernel. This should never happen.");
}
if (TF_PREDICT_FALSE(ctx->CanonicalDevice(kernel->OutputDevice(i)) !=
absl::get<Device*>(retvals[i]->device()))) {
return errors::Internal(
"Kernel output tensor handle locates on a different device than "
"the specified kernel output device. This should never happen.");
}
TF_RETURN_IF_ERROR(
retvals[i]->SetTensor(std::move((*outputs)[i]),
ctx->CanonicalDevice(kernel->OutputDevice(i))));
}
}
return Status::OK();
}
void CollectGraphs(EagerContext* ctx) {
mutex_lock ml(*ctx->MetadataMu());
GraphCollector* collector = ctx->GetGraphCollector();
mutex_lock mll(collector->mu);
// Adding to partition graphs for backward compatibility.
for (const auto& graph : collector->partitioned_graphs) {
*ctx->RunMetadataProto()->add_partition_graphs() = graph;
}
if (collector->dirty) {
auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs();
*function_graphs->mutable_post_optimization_graph() =
collector->optimized_graph;
*function_graphs->mutable_pre_optimization_graph() = collector->raw_graph;
for (const auto& graph : collector->partitioned_graphs) {
*function_graphs->add_partition_graphs() = graph;
}
}
collector->ClearGraphs();
}
} // namespace
Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
@@ -1061,50 +1131,18 @@ Status EagerKernelExecute(
TF_RETURN_IF_ERROR(kernel->Run(container, inputs, &outputs,
cancellation_manager, remote_func_params));
if (graph_collector != nullptr) {
mutex_lock ml(*ctx->MetadataMu());
{
GraphCollector* collector = ctx->GetGraphCollector();
mutex_lock mll(collector->mu);
// Adding to partition graphs for backward compatibility.
for (const auto& graph : collector->partitioned_graphs) {
*ctx->RunMetadataProto()->add_partition_graphs() = graph;
}
if (collector->dirty) {
auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs();
*function_graphs->mutable_post_optimization_graph() =
collector->optimized_graph;
*function_graphs->mutable_pre_optimization_graph() =
collector->raw_graph;
for (const auto& graph : collector->partitioned_graphs) {
*function_graphs->add_partition_graphs() = graph;
}
}
collector->ClearGraphs();
}
CollectGraphs(ctx);
}
DCHECK_EQ(retvals.size(), outputs.size());
for (int i = 0; i < retvals.size(); ++i) {
if (retvals[i] == nullptr) {
retvals[i] = TensorHandle::CreateLocalHandle(
std::move(outputs[i]),
/* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)),
/* op_device= */ kernel->device(),
/* resource_device= */ kernel->OutputResourceDevice(i), ctx);
} else {
DCHECK_EQ(kernel->device(), retvals[i]->op_device());
DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
absl::get<Device*>(retvals[i]->device()));
TF_RETURN_IF_ERROR(
retvals[i]->SetTensor(std::move(outputs[i]),
ctx->CanonicalDevice(kernel->OutputDevice(i))));
}
if (TF_PREDICT_FALSE(retvals.size() != outputs.size())) {
return errors::Internal(
"EagerKernelExecute returns a list of ", outputs.size(),
" tensors but ", retvals.size(),
" is expected. This should never "
"happen. Please file a bug with the TensorFlow team.");
}
return Status::OK();
return GetKernelOutputs(&outputs, retvals.size(), retvals.data(), ctx,
kernel.get());
}
namespace {

tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
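
The service-side file gets the same treatment: EagerServiceImpl::ExecuteOp shrinks to three steps, with wire-format deserialization and response population factored into helpers. A condensed sketch of the result, assembled from the hunks below (num_retvals computation elided, as it is in the diff):

Status EagerServiceImpl::ExecuteOp(const Operation& operation,
                                   EagerContext* eager_context,
                                   EagerExecutor* eager_executor,
                                   QueueResponse* queue_response) {
  // Step 1: rebuild the EagerOperation (inputs, attrs, remote
  // function params) from the request proto via the new helper.
  tensorflow::EagerOperation op(eager_context);
  TF_RETURN_IF_ERROR(
      GetEagerOperation(operation, eager_context, eager_executor, &op));
  // Step 2: run it through the shared eager execute path.
  int num_retvals = 0;
  // ... num_retvals lookup unchanged and elided here ...
  absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
  TF_RETURN_IF_ERROR(EagerExecute(&op, retvals.data(), &num_retvals));
  // Step 3: serialize outputs through injected adder callbacks.
  return AddOpRetvalsToResponse(
      eager_context, operation.id(), num_retvals, retvals.data(),
      [queue_response] { return queue_response->add_tensor(); },
      [queue_response] { return queue_response->add_shape(); });
}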

@@ -44,8 +44,10 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
@@ -89,6 +91,94 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
return Status::OK();
}
Status GetEagerOperation(const Operation& operation,
EagerContext* eager_context,
EagerExecutor* eager_executor,
EagerOperation* eager_op) {
const char* name = operation.name().c_str(); // Shorthand
absl::optional<tensorflow::EagerRemoteFunctionParams> remote_func_params =
absl::nullopt;
if (operation.is_function()) {
if (operation.is_component_function()) {
remote_func_params = {operation.id(), operation.func_step_id()};
} else {
remote_func_params = {operation.id(), absl::nullopt};
}
}
TF_RETURN_IF_ERROR(eager_op->Reset(name, operation.device().c_str(), false,
eager_executor, remote_func_params));
{
profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal",
profiler::TraceMeLevel::kVerbose);
for (const auto& input : operation.op_inputs()) {
tensorflow::TensorHandle* handle;
if (input.has_remote_handle()) {
TF_RETURN_IF_ERROR(
eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
input.remote_handle(), &handle));
TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
} else {
Tensor tensor;
if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
return errors::InvalidArgument("Invalid TensorProto: ",
input.tensor().DebugString());
} else {
handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
nullptr, eager_context);
TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
}
}
// Unref handle since it has a ref as an input now.
handle->Unref();
}
}
for (const auto& attr : operation.attrs()) {
eager_op->MutableAttrs()->Set(attr.first, attr.second);
}
return Status::OK();
}
Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) {
const tensorflow::Tensor* t = nullptr;
TF_RETURN_IF_ERROR(handle->Tensor(&t));
t->AsProtoTensorContent(proto);
return Status::OK();
}
Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
const tensorflow::Tensor* t = nullptr;
// TODO(nareshmodi): This call makes async calls sync calls. Fix this.
TF_RETURN_IF_ERROR(handle->Tensor(&t));
t->shape().AsProto(proto);
return Status::OK();
}
Status AddOpRetvalsToResponse(
EagerContext* eager_context, int op_id, int num_retvals,
TensorHandle** retvals, std::function<TensorProto*()> add_tensor_proto_fn,
std::function<TensorShapeProto*()> add_shape_proto_fn) {
if (op_id == kInvalidRemoteOpId) {
// Copy the output tensors back along with the response, since the op id
// is invalid and the outputs cannot be added to RemoteMgr.
for (int i = 0; i < num_retvals; i++) {
TF_RETURN_IF_ERROR(TensorHandleProto(retvals[i], add_tensor_proto_fn()));
retvals[i]->Unref();
}
} else {
eager_context->RemoteMgr()->AddOperationOutputs(
absl::MakeSpan(retvals, num_retvals), op_id);
for (int i = 0; i < num_retvals; i++) {
TF_RETURN_IF_ERROR(TensorHandleShape(retvals[i], add_shape_proto_fn()));
}
}
return Status::OK();
}
} // namespace
Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
@@ -316,72 +406,13 @@ Status EagerServiceImpl::CreateMasterContext(
return Status::OK();
}
Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) {
const tensorflow::Tensor* t = nullptr;
TF_RETURN_IF_ERROR(handle->Tensor(&t));
t->AsProtoTensorContent(proto);
return Status::OK();
}
Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
const tensorflow::Tensor* t = nullptr;
// TODO(nareshmodi): This call makes async calls sync calls. Fix this.
TF_RETURN_IF_ERROR(handle->Tensor(&t));
t->shape().AsProto(proto);
return Status::OK();
}
Status EagerServiceImpl::ExecuteOp(const Operation& operation,
EagerContext* eager_context,
EagerExecutor* eager_executor,
QueueResponse* queue_response) {
std::unique_ptr<tensorflow::EagerOperation> op;
const char* name = operation.name().c_str(); // Shorthand
absl::optional<tensorflow::EagerRemoteFunctionParams> remote_func_params =
absl::nullopt;
if (operation.is_function()) {
if (operation.is_component_function()) {
remote_func_params = {operation.id(), operation.func_step_id()};
} else {
remote_func_params = {operation.id(), absl::nullopt};
}
}
op.reset(new tensorflow::EagerOperation(eager_context));
TF_RETURN_IF_ERROR(op->Reset(name, operation.device().c_str(), false,
eager_executor, remote_func_params));
{
profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal",
profiler::TraceMeLevel::kVerbose);
for (const auto& input : operation.op_inputs()) {
tensorflow::TensorHandle* handle;
if (input.has_remote_handle()) {
TF_RETURN_IF_ERROR(
eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
input.remote_handle(), &handle));
TF_RETURN_IF_ERROR(op->AddInput(handle));
} else {
Tensor tensor;
if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
return errors::InvalidArgument("Invalid TensorProto: ",
input.tensor().DebugString());
} else {
handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
nullptr, eager_context);
TF_RETURN_IF_ERROR(op->AddInput(handle));
}
}
// Unref handle since it has a ref as an input now.
handle->Unref();
}
}
for (const auto& attr : operation.attrs()) {
op->MutableAttrs()->Set(attr.first, attr.second);
}
tensorflow::EagerOperation op(eager_context);
TF_RETURN_IF_ERROR(
GetEagerOperation(operation, eager_context, eager_executor, &op));
int num_retvals = 0;
// TODO(nareshmodi): Consider caching this.
@@ -390,26 +421,12 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();
TF_RETURN_IF_ERROR(EagerExecute(op.get(), retvals.data(), &num_retvals));
TF_RETURN_IF_ERROR(EagerExecute(&op, retvals.data(), &num_retvals));
if (operation.id() == kInvalidRemoteOpId) {
// Copy the output tensors back along with the response, since the op id
// is invalid and the outputs cannot be added to RemoteMgr.
for (int i = 0; i < num_retvals; i++) {
TF_RETURN_IF_ERROR(
TensorHandleProto(retvals[i], queue_response->add_tensor()));
retvals[i]->Unref();
}
} else {
eager_context->RemoteMgr()->AddOperationOutputs(
absl::MakeSpan(retvals.data(), num_retvals), operation.id());
for (int i = 0; i < num_retvals; i++) {
TF_RETURN_IF_ERROR(
TensorHandleShape(retvals[i], queue_response->add_shape()));
}
}
return Status::OK();
return AddOpRetvalsToResponse(
eager_context, operation.id(), num_retvals, retvals.data(),
[queue_response] { return queue_response->add_tensor(); },
[queue_response] { return queue_response->add_shape(); });
}
Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
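
A note on the AddOpRetvalsToResponse signature introduced above: taking std::function adders instead of the QueueResponse itself keeps the helper response-agnostic, which is what lets a later non-blocking handler reuse it from a completion callback. A minimal sketch under that assumption (the handler below is hypothetical and not part of this commit; AddOpRetvalsToResponse, QueueResponse, and StatusCallback are existing names):

// Hypothetical: a completion callback for a future async execute path
// could bind the same adders to the response it owns and report the
// serialization status through `done`.
void OnAsyncExecuteDone(EagerContext* ctx, const Operation& operation,
                        TensorHandle** retvals, int num_retvals,
                        QueueResponse* response, StatusCallback done) {
  done(AddOpRetvalsToResponse(
      ctx, operation.id(), num_retvals, retvals,
      [response] { return response->add_tensor(); },
      [response] { return response->add_shape(); }));
}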