Non-functional change in preparation for introducing non-blocking function execution.

Modularize components in eager execute and eager service implementation. They will be shared and reused in the upcoming async code path.

PiperOrigin-RevId: 308347746
Change-Id: Ida9cca1a1a88d3e6509c61950f4eaa4f18dbe864
parent 5f543e0fa1
commit 47995cbaf7
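At a high level, the change splits the monolithic execute paths into helpers that both the blocking path and the planned non-blocking path can call. A condensed overview of the new seams, with signatures taken from the hunks below and bodies elided:

    // tensorflow/core/common_runtime/eager/execute.cc
    // Resolves (or builds and caches) the kernel for an op; EagerLocalExecute
    // now calls this before running the kernel.
    Status GetOrCreateKernelAndDevice(
        EagerOperation* op, TensorHandle** retvals, int* num_retvals,
        core::RefCountPtr<KernelAndDevice>* out_kernel);
    // Packs kernel output tensors into TensorHandles after a run completes.
    Status GetKernelOutputs(std::vector<Tensor>* outputs, int num_outputs,
                            TensorHandle** retvals, EagerContext* ctx,
                            KernelAndDevice* kernel);
    // Moves collected graphs into RunMetadata.
    void CollectGraphs(EagerContext* ctx);

    // tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
    // Deserializes a remote Operation proto into an EagerOperation.
    Status GetEagerOperation(const Operation& operation,
                             EagerContext* eager_context,
                             EagerExecutor* eager_executor,
                             EagerOperation* eager_op);
    // Serializes op outputs into whatever response message the caller supplies.
    Status AddOpRetvalsToResponse(
        EagerContext* eager_context, int op_id, int num_retvals,
        TensorHandle** retvals,
        std::function<TensorProto*()> add_tensor_proto_fn,
        std::function<TensorShapeProto*()> add_shape_proto_fn);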
tensorflow/core/common_runtime/eager/execute.cc
@@ -357,32 +357,10 @@ Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx,
   return Status::OK();
 }
 
-// There are a lot of references to devices in this function and around.
-// Here is what they mean:
-// EagerOperation::Device(): The device on which the user requested the op
-// be executed, except if we had to change the device due to resource inputs
-// or CPU pinning. If the user did not request a device, the op does not
-// take resources, and we did not pin it to CPU, the device can be nullptr.
-// KernelAndDevice::Device(): The first time we see an op (combined with
-// its attributes), we need to create a KernelAndDevice object for it.
-// If op->Device() is a nullptr, we select a device for the op when
-// creating the KernelAndDevice. A concrete device will always be selected
-// here except when `op` is a function to be executed using function library
-// runtime. In this case, we don't select a device because running
-// a function with explicitly requested device has different behavior than
-// running without an explicitly requested device.
-Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
-                         int* num_retvals) {
-  ScopedMemoryDebugAnnotation op_annotation(
-      op->op_name(), op->remote_func_params().has_value()
-                         ? op->remote_func_params().value().step_id.value_or(0)
-                         : 0);
-  profiler::TraceMe activity(
-      [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
-      profiler::TraceMeLevel::kInfo);
+Status GetOrCreateKernelAndDevice(
+    EagerOperation* op, TensorHandle** retvals, int* num_retvals,
+    core::RefCountPtr<KernelAndDevice>* out_kernel) {
   EagerContext& ctx = op->EagerContext();
   auto& executor = op->Executor();
   TF_RETURN_IF_ERROR(executor.status());
   Device* device = absl::get<Device*>(op->Device());
 
   Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName());
@@ -416,9 +394,10 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
     TensorHandle* input = op->Inputs()[i];
     if (!ctx.LazyCopyFunctionRemoteInputs() && input->IsRemote()) {
       TensorHandle* handle = nullptr;
-      TF_RETURN_IF_ERROR(EagerCopyToDevice(
-          input, &ctx, &executor, device == nullptr ? ctx.HostCPU() : device,
-          /* mirror= */ true, &handle));
+      TF_RETURN_IF_ERROR(
+          EagerCopyToDevice(input, &ctx, &op->Executor(),
+                            device == nullptr ? ctx.HostCPU() : device,
+                            /*mirror=*/true, &handle));
       op->UpdateInput(i, handle);
       // Unref handle since it has a ref as an input now
       handle->Unref();
@@ -569,6 +548,42 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
       }
     }
   }
+  kernel->Ref();  // Ownership of reference is passed to out_kernel.
+  out_kernel->reset(kernel.get());
+  return Status::OK();
+}
+
+// There are a lot of references to devices in this function and around.
+// Here is what they mean:
+// EagerOperation::Device(): The device on which the user requested the op
+// be executed, except if we had to change the device due to resource inputs
+// or CPU pinning. If the user did not request a device, the op does not
+// take resources, and we did not pin it to CPU, the device can be nullptr.
+// KernelAndDevice::Device(): The first time we see an op (combined with
+// its attributes), we need to create a KernelAndDevice object for it.
+// If op->Device() is a nullptr, we select a device for the op when
+// creating the KernelAndDevice. A concrete device will always be selected
+// here except when `op` is a function to be executed using function library
+// runtime. In this case, we don't select a device because running
+// a function with explicitly requested device has different behavior than
+// running without an explicitly requested device.
+Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
+                         int* num_retvals) {
+  ScopedMemoryDebugAnnotation op_annotation(
+      op->op_name(), op->remote_func_params().has_value()
+                         ? op->remote_func_params().value().step_id.value_or(0)
+                         : 0);
+  profiler::TraceMe activity(
+      [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
+      profiler::TraceMeLevel::kInfo);
+  EagerContext& ctx = op->EagerContext();
+  auto& executor = op->Executor();
+  TF_RETURN_IF_ERROR(executor.status());
+
+  core::RefCountPtr<KernelAndDevice> kernel;
+  TF_RETURN_IF_ERROR(
+      GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel));
+
+  int num_outputs = kernel->num_outputs();
   if (num_outputs > *num_retvals) {
     return errors::InvalidArgument("Expecting ", num_outputs,
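Why this split matters for the upcoming async path: kernel resolution can stay synchronous while the run itself is deferred. A minimal sketch of how a non-blocking variant might reuse the helper; the function name and the scheduling step are hypothetical and not part of this commit:

    // Hypothetical illustration only. Resolve the kernel up front, then hand
    // the actual run to the op's EagerExecutor instead of executing inline.
    Status EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
                                  int* num_retvals) {
      core::RefCountPtr<KernelAndDevice> kernel;
      TF_RETURN_IF_ERROR(
          GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel));
      // ... create placeholder TensorHandles for retvals, wrap the kernel run
      // in an EagerNode, and enqueue it, e.g. via op->Executor().AddOrExecute().
      return Status::OK();
    }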
@@ -986,6 +1001,61 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
 
   return Status::OK();
 }
+
+Status GetKernelOutputs(std::vector<Tensor>* outputs, int num_outputs,
+                        TensorHandle** retvals, EagerContext* ctx,
+                        KernelAndDevice* kernel) {
+  for (int i = 0; i < num_outputs; ++i) {
+    if (retvals[i] == nullptr) {
+      retvals[i] = TensorHandle::CreateLocalHandle(
+          std::move((*outputs)[i]),
+          /* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)),
+          /* op_device= */ kernel->device(),
+          /* resource_device= */ kernel->OutputResourceDevice(i), ctx);
+    } else {
+      if (TF_PREDICT_FALSE(kernel->device() != retvals[i]->op_device())) {
+        return errors::Internal(
+            "Kernel output tensor handle has a different op device than the "
+            "kernel. This should never happen.");
+      }
+      if (TF_PREDICT_FALSE(ctx->CanonicalDevice(kernel->OutputDevice(i)) !=
+                           absl::get<Device*>(retvals[i]->device()))) {
+        return errors::Internal(
+            "Kernel output tensor handle locates on a different device than "
+            "the specified kernel output device. This should never happen.");
+      }
+
+      TF_RETURN_IF_ERROR(
+          retvals[i]->SetTensor(std::move((*outputs)[i]),
+                                ctx->CanonicalDevice(kernel->OutputDevice(i))));
+    }
+  }
+  return Status::OK();
+}
+
+void CollectGraphs(EagerContext* ctx) {
+  mutex_lock ml(*ctx->MetadataMu());
+
+  GraphCollector* collector = ctx->GetGraphCollector();
+  mutex_lock mll(collector->mu);
+
+  // Adding to partition graphs for backward compatibility.
+  for (const auto& graph : collector->partitioned_graphs) {
+    *ctx->RunMetadataProto()->add_partition_graphs() = graph;
+  }
+
+  if (collector->dirty) {
+    auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs();
+    *function_graphs->mutable_post_optimization_graph() =
+        collector->optimized_graph;
+    *function_graphs->mutable_pre_optimization_graph() = collector->raw_graph;
+    for (const auto& graph : collector->partitioned_graphs) {
+      *function_graphs->add_partition_graphs() = graph;
+    }
+  }
+
+  collector->ClearGraphs();
+}
 } // namespace
 
 Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
@@ -1061,50 +1131,18 @@ Status EagerKernelExecute(
   TF_RETURN_IF_ERROR(kernel->Run(container, inputs, &outputs,
                                  cancellation_manager, remote_func_params));
   if (graph_collector != nullptr) {
-    mutex_lock ml(*ctx->MetadataMu());
-    {
-      GraphCollector* collector = ctx->GetGraphCollector();
-      mutex_lock mll(collector->mu);
-
-      // Adding to partition graphs for backward compatibility.
-      for (const auto& graph : collector->partitioned_graphs) {
-        *ctx->RunMetadataProto()->add_partition_graphs() = graph;
-      }
-
-      if (collector->dirty) {
-        auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs();
-        *function_graphs->mutable_post_optimization_graph() =
-            collector->optimized_graph;
-        *function_graphs->mutable_pre_optimization_graph() =
-            collector->raw_graph;
-        for (const auto& graph : collector->partitioned_graphs) {
-          *function_graphs->add_partition_graphs() = graph;
-        }
-      }
-
-      collector->ClearGraphs();
-    }
+    CollectGraphs(ctx);
   }
-  DCHECK_EQ(retvals.size(), outputs.size());
-
-  for (int i = 0; i < retvals.size(); ++i) {
-    if (retvals[i] == nullptr) {
-      retvals[i] = TensorHandle::CreateLocalHandle(
-          std::move(outputs[i]),
-          /* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)),
-          /* op_device= */ kernel->device(),
-          /* resource_device= */ kernel->OutputResourceDevice(i), ctx);
-    } else {
-      DCHECK_EQ(kernel->device(), retvals[i]->op_device());
-      DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
-                absl::get<Device*>(retvals[i]->device()));
-
-      TF_RETURN_IF_ERROR(
-          retvals[i]->SetTensor(std::move(outputs[i]),
-                                ctx->CanonicalDevice(kernel->OutputDevice(i))));
-    }
-  }
+  if (TF_PREDICT_FALSE(retvals.size() != outputs.size())) {
+    return errors::Internal(
+        "EagerKernelExecute returns a list of ", outputs.size(),
+        " tensors but ", retvals.size(),
+        " is expected. This should never "
+        "happen. Please file a bug with the TensorFlow team.");
+  }
-  return Status::OK();
+  return GetKernelOutputs(&outputs, retvals.size(), retvals.data(), ctx,
+                          kernel.get());
 }
 
 namespace {
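Two things to note in this hunk: the graph-collection block collapses into the new CollectGraphs helper, and the output-handling DCHECKs become TF_PREDICT_FALSE checks returning errors::Internal, so mismatches surface as real errors in production builds rather than only under debug. The output packing itself now lives in GetKernelOutputs, which is exactly the piece a deferred run needs at completion time. A sketch of that reuse; the callback wiring is hypothetical, not from this commit:

    // Hypothetical illustration only: an async completion callback could call
    // the same helper to bind computed tensors to pre-created handles.
    auto done = [ctx, retvals, num_outputs](
                    core::RefCountPtr<KernelAndDevice> kernel,
                    std::vector<Tensor> outputs, Status s) {
      if (s.ok()) {
        s = GetKernelOutputs(&outputs, num_outputs, retvals, ctx, kernel.get());
      }
      // ... report `s` to the executor / waiting handles.
    };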
tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -44,8 +44,10 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/cpu_info.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/host_info.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/refcount.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
 
@@ -89,6 +91,94 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
 
   return Status::OK();
 }
+
+Status GetEagerOperation(const Operation& operation,
+                         EagerContext* eager_context,
+                         EagerExecutor* eager_executor,
+                         EagerOperation* eager_op) {
+  const char* name = operation.name().c_str();  // Shorthand
+  absl::optional<tensorflow::EagerRemoteFunctionParams> remote_func_params =
+      absl::nullopt;
+  if (operation.is_function()) {
+    if (operation.is_component_function()) {
+      remote_func_params = {operation.id(), operation.func_step_id()};
+    } else {
+      remote_func_params = {operation.id(), absl::nullopt};
+    }
+  }
+  TF_RETURN_IF_ERROR(eager_op->Reset(name, operation.device().c_str(), false,
+                                     eager_executor, remote_func_params));
+
+  {
+    profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal",
+                               profiler::TraceMeLevel::kVerbose);
+    for (const auto& input : operation.op_inputs()) {
+      tensorflow::TensorHandle* handle;
+      if (input.has_remote_handle()) {
+        TF_RETURN_IF_ERROR(
+            eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
+                input.remote_handle(), &handle));
+        TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
+      } else {
+        Tensor tensor;
+        if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
+          return errors::InvalidArgument("Invalid TensorProto: ",
+                                         input.tensor().DebugString());
+        } else {
+          handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
+                                                   nullptr, eager_context);
+          TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
+        }
+      }
+      // Unref handle since it has a ref as an input now.
+      handle->Unref();
+    }
+  }
+
+  for (const auto& attr : operation.attrs()) {
+    eager_op->MutableAttrs()->Set(attr.first, attr.second);
+  }
+  return Status::OK();
+}
+
+Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) {
+  const tensorflow::Tensor* t = nullptr;
+  TF_RETURN_IF_ERROR(handle->Tensor(&t));
+  t->AsProtoTensorContent(proto);
+  return Status::OK();
+}
+
+Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
+  const tensorflow::Tensor* t = nullptr;
+
+  // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
+  TF_RETURN_IF_ERROR(handle->Tensor(&t));
+
+  t->shape().AsProto(proto);
+
+  return Status::OK();
+}
+
+Status AddOpRetvalsToResponse(
+    EagerContext* eager_context, int op_id, int num_retvals,
+    TensorHandle** retvals, std::function<TensorProto*()> add_tensor_proto_fn,
+    std::function<TensorShapeProto*()> add_shape_proto_fn) {
+  if (op_id == kInvalidRemoteOpId) {
+    // Copy the output tensors back along with the response, since the op id
+    // is invalid which cannot be added to RemoteMgr.
+    for (int i = 0; i < num_retvals; i++) {
+      TF_RETURN_IF_ERROR(TensorHandleProto(retvals[i], add_tensor_proto_fn()));
+      retvals[i]->Unref();
+    }
+  } else {
+    eager_context->RemoteMgr()->AddOperationOutputs(
+        absl::MakeSpan(retvals, num_retvals), op_id);
+    for (int i = 0; i < num_retvals; i++) {
+      TF_RETURN_IF_ERROR(TensorHandleShape(retvals[i], add_shape_proto_fn()));
+    }
+  }
+  return Status::OK();
+}
 } // namespace
 
 Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
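AddOpRetvalsToResponse deliberately takes std::function adders instead of a concrete QueueResponse*, so the same serialization logic can later serve other response messages. The call shape, as used by ExecuteOp in the hunks below:

    // The helper never sees the response type, only the repeated-field adders.
    return AddOpRetvalsToResponse(
        eager_context, operation.id(), num_retvals, retvals.data(),
        [queue_response] { return queue_response->add_tensor(); },
        [queue_response] { return queue_response->add_shape(); });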
@@ -316,72 +406,13 @@ Status EagerServiceImpl::CreateMasterContext(
   return Status::OK();
 }
 
-Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) {
-  const tensorflow::Tensor* t = nullptr;
-  TF_RETURN_IF_ERROR(handle->Tensor(&t));
-  t->AsProtoTensorContent(proto);
-  return Status::OK();
-}
-
-Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
-  const tensorflow::Tensor* t = nullptr;
-
-  // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
-  TF_RETURN_IF_ERROR(handle->Tensor(&t));
-
-  t->shape().AsProto(proto);
-
-  return Status::OK();
-}
-
 Status EagerServiceImpl::ExecuteOp(const Operation& operation,
                                    EagerContext* eager_context,
                                    EagerExecutor* eager_executor,
                                    QueueResponse* queue_response) {
-  std::unique_ptr<tensorflow::EagerOperation> op;
-  const char* name = operation.name().c_str();  // Shorthand
-  absl::optional<tensorflow::EagerRemoteFunctionParams> remote_func_params =
-      absl::nullopt;
-  if (operation.is_function()) {
-    if (operation.is_component_function()) {
-      remote_func_params = {operation.id(), operation.func_step_id()};
-    } else {
-      remote_func_params = {operation.id(), absl::nullopt};
-    }
-  }
-  op.reset(new tensorflow::EagerOperation(eager_context));
-  TF_RETURN_IF_ERROR(op->Reset(name, operation.device().c_str(), false,
-                               eager_executor, remote_func_params));
-
-  {
-    profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal",
-                               profiler::TraceMeLevel::kVerbose);
-    for (const auto& input : operation.op_inputs()) {
-      tensorflow::TensorHandle* handle;
-      if (input.has_remote_handle()) {
-        TF_RETURN_IF_ERROR(
-            eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
-                input.remote_handle(), &handle));
-        TF_RETURN_IF_ERROR(op->AddInput(handle));
-      } else {
-        Tensor tensor;
-        if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
-          return errors::InvalidArgument("Invalid TensorProto: ",
-                                         input.tensor().DebugString());
-        } else {
-          handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
-                                                   nullptr, eager_context);
-          TF_RETURN_IF_ERROR(op->AddInput(handle));
-        }
-      }
-      // Unref handle since it has a ref as an input now.
-      handle->Unref();
-    }
-  }
-
-  for (const auto& attr : operation.attrs()) {
-    op->MutableAttrs()->Set(attr.first, attr.second);
-  }
+  tensorflow::EagerOperation op(eager_context);
+  TF_RETURN_IF_ERROR(
+      GetEagerOperation(operation, eager_context, eager_executor, &op));
 
   int num_retvals = 0;
   // TODO(nareshmodi): Consider caching this.
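With deserialization factored out, ExecuteOp shrinks to build-then-run, and the EagerOperation becomes a stack value reset in place rather than a per-call heap allocation. The new caller pattern, condensed from this and the next hunk:

    // Build the op from the wire format, then execute it; a future
    // non-blocking handler could perform these two steps separately.
    tensorflow::EagerOperation op(eager_context);
    TF_RETURN_IF_ERROR(
        GetEagerOperation(operation, eager_context, eager_executor, &op));
    // ...
    TF_RETURN_IF_ERROR(EagerExecute(&op, retvals.data(), &num_retvals));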
@@ -390,26 +421,12 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
 
   absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
   VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();
-  TF_RETURN_IF_ERROR(EagerExecute(op.get(), retvals.data(), &num_retvals));
+  TF_RETURN_IF_ERROR(EagerExecute(&op, retvals.data(), &num_retvals));
 
-  if (operation.id() == kInvalidRemoteOpId) {
-    // Copy the output tensors back along with the response, since the op id
-    // is invalid which cannot be added to RemoteMgr.
-    for (int i = 0; i < num_retvals; i++) {
-      TF_RETURN_IF_ERROR(
-          TensorHandleProto(retvals[i], queue_response->add_tensor()));
-      retvals[i]->Unref();
-    }
-  } else {
-    eager_context->RemoteMgr()->AddOperationOutputs(
-        absl::MakeSpan(retvals.data(), num_retvals), operation.id());
-    for (int i = 0; i < num_retvals; i++) {
-      TF_RETURN_IF_ERROR(
-          TensorHandleShape(retvals[i], queue_response->add_shape()));
-    }
-  }
-
-  return Status::OK();
+  return AddOpRetvalsToResponse(
+      eager_context, operation.id(), num_retvals, retvals.data(),
+      [queue_response] { return queue_response->add_tensor(); },
+      [queue_response] { return queue_response->add_shape(); });
 }
 
 Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,