Refactor eager placement logic into three util methods:

- MaybePinSmallOpsToCpu
- MaybePinToResourceDevice
- MaybePinToCustomDevice

We are going to reuse MaybePinSmallOpsToCpu in TFRT, but not the other two,
because TFRT has neither native resource handling nor custom devices.

PiperOrigin-RevId: 317766813
Change-Id: I43241b5786120ddf39dc4bfff6071239afdfd785

commit fe6e64b098
parent c692c45dae

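For orientation before the diff: EagerOperation::Execute now runs the three
utilities as a short cascade, trying custom-device placement first, then
resource-device placement, then small-op CPU pinning. Below is a minimal C++
sketch of that cascade, condensed from the EagerOperation::Execute hunk further
down; the wrapper name PlaceOperation is hypothetical, error handling is
trimmed, and the empty input span stands in for the op's real inputs.

// Minimal sketch of the placement cascade added in this commit. The function
// name PlaceOperation is hypothetical; the real logic lives inline in
// EagerOperation::Execute (see the diff below).
Status PlaceOperation(EagerOperation& op, EagerContext& ctx) {
  VariantDevice device = kVariantDeviceNull;
  // 1) An op already on (or fed entirely from) a custom device stays there.
  TF_RETURN_IF_ERROR(eager::MaybePinToCustomDevice(&device, op));
  // 2) Otherwise a DT_RESOURCE input pins the op to the resource's device.
  if (device == kVariantDeviceNull) {
    TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, op));
  }
  // 3) Otherwise small int32/int64 CPU inputs keep the op on the host CPU.
  if (device == kVariantDeviceNull) {
    bool pin_to_cpu = false;
    TF_RETURN_IF_ERROR(eager::MaybePinSmallOpsToCpu(
        &pin_to_cpu, op.Name(), /*args=*/{}, ctx));  // real code passes inputs
    if (pin_to_cpu) device = ctx.HostCPU();
  }
  if (device != kVariantDeviceNull) op.SetDevice(device);
  return Status::OK();
}

The order matters: a custom-device or resource-device decision short-circuits
the small-op CPU-pinning heuristic.
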
@@ -29,6 +29,7 @@ tf_cuda_library(
         ":context",
         ":eager_operation",
         ":execute",
+        ":placement_utils",
         ":tensor_handle",
         "//tensorflow/c:c_api_internal",
         "//tensorflow/c:tf_tensor_internal",

@@ -489,6 +490,7 @@ cc_library(
         ":eager_op_rewrite_registry",
        ":eager_operation",
         ":kernel_and_device",
+        ":placement_utils",
         ":tensor_handle",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:inlined_vector",

@@ -521,6 +523,35 @@ cc_library(
     }),
 )
 
+tf_cuda_library(
+    name = "placement_utils",
+    srcs = [
+        "placement_utils.cc",
+    ],
+    hdrs = [
+        "placement_utils.h",
+    ],
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":context",
+        ":attr_builder",
+        ":eager_operation",
+        "//tensorflow/c/eager:immediate_execution_tensor_handle",
+    ] + select({
+        "//tensorflow:android": [
+            "//tensorflow/core:portable_tensorflow_lib_lite",
+        ],
+        "//conditions:default": [
+            "//tensorflow/core:core_cpu_lib",
+            "//tensorflow/core:framework",
+            "//tensorflow/core:framework_internal",
+            "//tensorflow/core:lib",
+            "//tensorflow/core:lib_internal",
+            "//tensorflow/core:protos_all_cc",
+        ],
+    }),
+)
+
 tf_cuda_library(
     name = "attr_builder",
     srcs = ["attr_builder.cc"],

@@ -478,7 +478,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
   // On mobile, it just cleans the caches.
   void WaitForAndCloseRemoteContexts();
 
-  bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; }
+  bool PinSmallOpsToCPU() const { return pin_small_ops_to_cpu_; }
 
   tensorflow::Env* TFEnv() const { return env_; }
 

@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
 #include "tensorflow/core/common_runtime/eager/execute.h"
+#include "tensorflow/core/common_runtime/eager/placement_utils.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/platform/errors.h"
 

@@ -187,6 +188,27 @@ Status EagerContext::RegisterFunction(AbstractFunction* f) {
 // eager_operation.cc we can avoid a circular dependency between them.
 Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
                                int* num_retvals) {
+  // Run eager placement logic.
+  VariantDevice device;
+  TF_RETURN_IF_ERROR(eager::MaybePinToCustomDevice(&device, *this));
+  if (device == kVariantDeviceNull) {
+    TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, *this));
+  }
+  if (device == kVariantDeviceNull) {
+    bool pin_to_cpu;
+    TF_RETURN_IF_ERROR(eager::MaybePinSmallOpsToCpu(
+        &pin_to_cpu, op_name(),
+        absl::MakeSpan(
+            reinterpret_cast<ImmediateExecutionTensorHandle**>(inputs_.data()),
+            inputs_.size()),
+        ctx_));
+    if (pin_to_cpu) {
+      device = ctx_.HostCPU();
+    }
+  }
+  if (device != kVariantDeviceNull) {
+    SetDevice(device);
+  }
   return EagerExecute(
       this, reinterpret_cast<tensorflow::TensorHandle**>(retvals.data()),
       num_retvals);

@@ -126,7 +126,7 @@ class EagerOperation : public ImmediateExecutionOperation {
   bool is_function() const { return is_function_; }
   bool colocation_exempt() const { return colocation_exempt_; }
 
-  tensorflow::EagerContext& EagerContext() { return ctx_; }
+  tensorflow::EagerContext& EagerContext() const { return ctx_; }
 
   AttrBuilder* MutableAttrs() { return &attrs_; }
   const AttrBuilder& Attrs() const { return attrs_; }

@@ -870,173 +870,6 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
 }
 #endif  // IS_MOBILE_PLATFORM
 
-// These ops are not pinnable since they generate data. It can be slower to
-// generate and then copy the data instead of just generating the data on the
-// device directly.
-bool IsPinnableOp(const string& op_type) {
-  static const gtl::FlatSet<string>* unpinnable_ops = new gtl::FlatSet<string>({
-      "RandomUniform",
-      "RandomUniformInt",
-      "RandomStandardNormal",
-      "StatelessRandomUniform",
-      "StatelessRandomUniformInt",
-      "StatelessRandomUniformFullInt",
-      "StatelessRandomNormal",
-  });
-
-  // XRT ops refer to per-device handles that are not safe to move between
-  // devices.
-  return unpinnable_ops->find(op_type) == unpinnable_ops->end() &&
-         !absl::StartsWith(op_type, "XRT");
-}
-
-// Validate if the remote device with the given incarnation is valid in the
-// remote device manager of the current eager context.
-Status ValidateTensorHandleRemoteDevice(EagerContext* ctx,
-                                        int64 device_incarnation) {
-  if (ctx->remote_device_mgr()->ContainsDevice(device_incarnation)) {
-    return Status::OK();
-  }
-  return errors::InvalidArgument(
-      "Resource input tensor contains an invalid device. This might happen "
-      "when the client has connected to a different cluster, or some remote "
-      "workers have been restarted.");
-}
-
-// The Op device may be updated if:
-// - A resource touching input is specified: all resource-touching ops run in
-//   the device the resource is, regardless of anything else that has been
-//   specified. This is identical to the graph mode behavior.
-//
-// - All op inputs are on the CPU, small (<64 elements) and integers
-//   (int32/int64). This can be disabled by setting the environment variable
-//   "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
-//
-// TODO(b/154234908): Unify placement logic.
-Status MaybeUpdateOpDevice(EagerOperation* op) {
-  // If operation was already placed on a custom device, use it.
-  if (VariantDeviceIsCustom(op->Device())) {
-    return Status::OK();
-  }
-
-  // If all the inputs are on the same custom device, use that custom
-  // device. Otherwise, it is an error to have a custom device as an input.
-  if (!op->Inputs().empty()) {
-    // We keep track of what we've seen with devices instead of booleans to be
-    // able to provide a meaningful error message below.
-    VariantDevice first = op->Inputs()[0]->device();
-    VariantDevice different = first;  // A different input device, if any.
-    VariantDevice custom = first;  // The first custom device seen, or an
-                                   // arbitrary non-custom device otherwise.
-    for (size_t i = 1; first == different && i < op->Inputs().size(); ++i) {
-      VariantDevice device = op->Inputs()[i]->device();
-      if (device != first) {
-        different = device;
-      }
-      if (!VariantDeviceIsCustom(custom) && VariantDeviceIsCustom(device)) {
-        custom = device;
-      }
-      if (different != first && VariantDeviceIsCustom(custom)) {
-        return errors::InvalidArgument(absl::StrCat(
-            "If an operation has one of its inputs in a custom device, then "
-            "all inputs should be on that same device. Operation ",
-            op->Name(), " has one input in custom device ",
-            VariantDeviceName(custom),
-            " and at least one input in a different device ",
-            VariantDeviceName(custom == first ? different : first)));
-      }
-    }
-    if (different == first && VariantDeviceIsCustom(custom)) {
-      op->SetDevice(first);
-      return Status::OK();
-    }
-  }
-
-  if (op->colocation_exempt()) {
-    return Status::OK();
-  }
-  EagerContext& ctx = op->EagerContext();
-  bool all_inputs_eligible_for_cpu_pinning =
-      ctx.PinSmallOpsToCPU() && !op->is_function() && IsPinnableOp(op->Name());
-  Device* op_device = op->Device() == kVariantDeviceNull
-                          ? ctx.HostCPU()
-                          : absl::get<Device*>(op->Device());
-  for (int i = 0; i < op->Inputs().size(); ++i) {
-    TensorHandle* tensor_handle = op->Inputs()[i];
-    if (tensor_handle->dtype == DT_RESOURCE) {
-      if (tensor_handle->resource_remote_device_incarnation() != 0) {
-        TF_RETURN_IF_ERROR(ValidateTensorHandleRemoteDevice(
-            &ctx, tensor_handle->resource_remote_device_incarnation()));
-      }
-      Device* resource_device = tensor_handle->resource_device();
-      DVLOG(2) << "for op " << op->Name() << " input " << i << " "
-               << DataTypeString(tensor_handle->dtype)
-               << " input device = " << resource_device->name()
-               << ", op device = " << op_device->name();
-      // We check for `op->Device() == nullptr` because it can be later
-      // interpreted as unspecified device and a different device can
-      // be selected based on device priority. If any input to an op
-      // is a resource we must pin it to prevent different device selection.
-      // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
-      if (resource_device != op_device || op->Device() == kVariantDeviceNull) {
-        DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
-                 << "device of operation " << op->Name() << " to "
-                 << resource_device->name() << " because input #" << i
-                 << " is a resource in this device.";
-        op->SetDevice(resource_device);
-      }
-      all_inputs_eligible_for_cpu_pinning = false;
-      // No point in looking at other inputs. If there are other resources,
-      // they must have the same device and we already declared the op to be
-      // ineligible for CPU pinning.
-      break;
-    } else if (all_inputs_eligible_for_cpu_pinning) {
-      auto input_device_variant = tensor_handle->DeviceOrHostCPU(ctx);
-      if (VariantDeviceIsCustom(input_device_variant)) {
-        all_inputs_eligible_for_cpu_pinning = false;
-        continue;
-      }
-      Device* input_device = absl::get<Device*>(input_device_variant);
-      DVLOG(2) << "for op " << op->Name() << " input " << i << " "
-               << DataTypeString(tensor_handle->dtype)
-               << " input device = " << input_device->name()
-               << ", op device = " << op_device->name();
-
-      // Input is on CPU.
-      if (input_device != ctx.HostCPU()) {
-        all_inputs_eligible_for_cpu_pinning = false;
-        continue;
-      }
-
-      if (tensor_handle->dtype != DataType::DT_INT32 &&
-          tensor_handle->dtype != DataType::DT_INT64) {
-        all_inputs_eligible_for_cpu_pinning = false;
-        continue;
-      }
-
-      int64 num_elements;
-      TF_RETURN_IF_ERROR(tensor_handle->NumElements(&num_elements));
-      if (num_elements > 64) {
-        all_inputs_eligible_for_cpu_pinning = false;
-      }
-    }
-  }
-
-  // Ops without inputs are usually ops that generate a tensor in some way and
-  // usually require being present on whatever device they are scheduled on
-  // - for e.g. VarHandleOp or _Recv).
-  // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
-  // an op, but there is a GPU kernel?
-  if (!op->Inputs().empty() && all_inputs_eligible_for_cpu_pinning) {
-    DVLOG(1) << "Forcing op " << op->Name()
-             << " to be on the CPU since all input tensors have an "
-                "int32/int64 dtype, and are small (less than 64 elements).";
-    op->SetDevice(ctx.HostCPU());
-  }
-
-  return Status::OK();
-}
-
 Status GetKernelOutputs(std::vector<Tensor>* outputs, int num_outputs,
                         TensorHandle** retvals, EagerContext* ctx,
                         KernelAndDevice* kernel) {

@@ -1099,8 +932,6 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
       [&] { return absl::StrCat("EagerExecute: ", op->Name()); },
       profiler::TraceMeLevel::kInfo);
 
-  TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
-
   if (VariantDeviceIsCustom(op->Device())) {
     return absl::get<CustomDevice*>(op->Device())
         ->Execute(op, retvals, num_retvals);

@@ -0,0 +1,228 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/placement_utils.h"
+
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/core/common_runtime/eager/attr_builder.h"
+#include "tensorflow/core/common_runtime/eager/eager_operation.h"
+#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/errors.h"
+
+namespace tensorflow {
+namespace eager {
+
+// These ops are not pinnable since they generate data. It can be slower to
+// generate and then copy the data instead of just generating the data on the
+// device directly.
+static bool IsPinnableOp(StringPiece op_name) {
+  static const gtl::FlatSet<string>* unpinnable_ops = new gtl::FlatSet<string>({
+      "RandomUniform",
+      "RandomUniformInt",
+      "RandomStandardNormal",
+      "StatelessRandomUniform",
+      "StatelessRandomUniformInt",
+      "StatelessRandomUniformFullInt",
+      "StatelessRandomNormal",
+  });
+
+  // XRT ops refer to per-device handles that are not safe to move between
+  // devices.
+  return unpinnable_ops->find(string(op_name)) == unpinnable_ops->end() &&
+         !absl::StartsWith(op_name, "XRT");
+}
+// Validate if the remote device with the given incarnation is valid in the
+// remote device manager of the current eager context.
+static Status ValidateTensorHandleRemoteDevice(EagerContext* ctx,
+                                               int64 device_incarnation) {
+  if (ctx->remote_device_mgr()->ContainsDevice(device_incarnation)) {
+    return Status::OK();
+  }
+  return errors::InvalidArgument(
+      "Resource input tensor contains an invalid device. This might happen "
+      "when the client has connected to a different cluster, or some remote "
+      "workers have been restarted.");
+}
+
+bool IsColocationExempt(StringPiece op_name) {
+  const auto& exempt_ops = InputColocationExemptionRegistry::Global()->Get();
+  return exempt_ops.find(string(op_name)) != exempt_ops.end();
+}
+
+bool IsFunction(StringPiece op_name) {
+  const OpDef* op_def = nullptr;
+  Status s = OpDefForOp(string(op_name), &op_def);
+  if (!s.ok()) {
+    if (!errors::IsNotFound(s)) {
+      LOG(WARNING) << "Looking up OpDef failed with error: " << s.ToString();
+    }
+    // Cannot find OpDef, it is a function.
+    return true;
+  }
+  return false;
+}
+
+bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx) {
+  CustomDevice* custom_device;
+  return ctx.FindCustomDeviceFromName(string(device_name), &custom_device).ok();
+}
+
+Status MaybePinSmallOpsToCpu(bool* result, StringPiece op_name,
+                             absl::Span<ImmediateExecutionTensorHandle*> args,
+                             const EagerContext& ctx) {
+  if (!ctx.PinSmallOpsToCPU() || IsFunction(op_name) ||
+      IsColocationExempt(op_name) || !IsPinnableOp(op_name)) {
+    *result = false;
+    return Status::OK();
+  }
+
+  // Ops without inputs are usually ops that generate a tensor in some way and
+  // usually require being present on whatever device they are scheduled on
+  // - for e.g. VarHandleOp or _Recv).
+  if (args.empty()) {
+    *result = false;
+    return Status::OK();
+  }
+
+  int i = 0;
+  for (auto* arg : args) {
+    Status s;
+    const char* device_name = arg->DeviceName(&s);
+    DataType dtype = arg->DataType();
+    TF_RETURN_IF_ERROR(s);
+    if (IsCustomDevice(device_name, ctx)) {
+      *result = false;
+      return Status::OK();
+    }
+
+    DVLOG(2) << "for op " << op_name << " input " << i << " "
+             << DataTypeString(dtype) << " input device = " << device_name;
+
+    // Input is on CPU.
+    if (device_name != ctx.HostCPU()->name()) {
+      *result = false;
+      return Status::OK();
+    }
+
+    if (dtype != DataType::DT_INT32 && dtype != DataType::DT_INT64) {
+      *result = false;
+      return Status::OK();
+    }
+
+    int64 num_elements;
+    TF_RETURN_IF_ERROR(arg->NumElements(&num_elements));
+    if (num_elements > 64) {
+      *result = false;
+      return Status::OK();
+    }
+    i++;
+  }
+
+  // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
+  // an op, but there is a GPU kernel?
+  DVLOG(1) << "Forcing op " << op_name
+           << " to be on the CPU since all input tensors have an "
+              "int32/int64 dtype, and are small (less than 64 elements).";
+  *result = true;
+  return Status::OK();
+}
+
+Status MaybePinToResourceDevice(VariantDevice* device,
+                                const EagerOperation& op) {
+  if (op.colocation_exempt()) {
+    return Status::OK();
+  }
+  EagerContext& ctx = op.EagerContext();
+  Device* op_device = op.Device() == kVariantDeviceNull
+                          ? ctx.HostCPU()
+                          : absl::get<Device*>(op.Device());
+  for (int i = 0; i < op.Inputs().size(); ++i) {
+    TensorHandle* tensor_handle = op.Inputs()[i];
+    if (tensor_handle->dtype == DT_RESOURCE) {
+      if (tensor_handle->resource_remote_device_incarnation() != 0) {
+        TF_RETURN_IF_ERROR(ValidateTensorHandleRemoteDevice(
+            &ctx, tensor_handle->resource_remote_device_incarnation()));
+      }
+      Device* resource_device = tensor_handle->resource_device();
+      DVLOG(2) << "for op " << op.Name() << " input " << i << " "
+               << DataTypeString(tensor_handle->dtype)
+               << " input device = " << resource_device->name()
+               << ", op device = " << op_device->name();
+      // We check for `op->Device() == nullptr` because it can be later
+      // interpreted as unspecified device and a different device can
+      // be selected based on device priority. If any input to an op
+      // is a resource we must pin it to prevent different device selection.
+      // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
+      if (resource_device != op_device || op.Device() == kVariantDeviceNull) {
+        DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
+                 << "device of operation " << op.Name() << " to "
+                 << resource_device->name() << " because input #" << i
+                 << " is a resource in this device.";
+        *device = resource_device;
+        return Status::OK();
+        // No point in looking at other inputs. If there are other resources,
+        // they must have the same device and we already declared the op to be
+        // ineligible for CPU pinning.
+      }
+    }
+  }
+  return Status::OK();
+}
+
+Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
+  // If operation was already placed on a custom device, use it.
+  if (VariantDeviceIsCustom(op.Device())) {
+    *device = op.Device();
+    return Status::OK();
+  }
+
+  if (!op.Inputs().empty()) {
+    // We keep track of what we've seen with devices instead of booleans to be
+    // able to provide a meaningful error message below.
+    VariantDevice first = op.Inputs()[0]->device();
+    VariantDevice different = first;  // A different input device, if any.
+    VariantDevice custom = first;  // The first custom device seen, or an
+                                   // arbitrary non-custom device otherwise.
+    for (size_t i = 1; first == different && i < op.Inputs().size(); ++i) {
+      VariantDevice device = op.Inputs()[i]->device();
+      if (device != first) {
+        different = device;
+      }
+      if (!VariantDeviceIsCustom(custom) && VariantDeviceIsCustom(device)) {
+        custom = device;
+      }
+      if (different != first && VariantDeviceIsCustom(custom)) {
+        return errors::InvalidArgument(absl::StrCat(
+            "If an operation has one of its inputs in a custom device, then "
+            "all inputs should be on that same device. Operation ",
+            op.Name(), " has one input in custom device ",
+            VariantDeviceName(custom),
+            " and at least one input in a different device ",
+            VariantDeviceName(custom == first ? different : first)));
+      }
+    }
+    if (different == first && VariantDeviceIsCustom(custom)) {
+      *device = first;
+      return Status::OK();
+    }
+  }
+
+  return Status::OK();
+}
+
+}  // namespace eager
+}  // namespace tensorflow

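Since the commit message says only MaybePinSmallOpsToCpu will be reused in
TFRT, a call site outside EagerOperation::Execute would look roughly like the
sketch below. This is illustrative only, not TFRT code: PlaceOpForTfrt and
SelectDeviceSomehow are hypothetical names, and a real TFRT integration would
supply its own context and handle plumbing.

// Hypothetical TFRT-side call site (illustrative only): only the small-op
// CPU-pinning heuristic is shared; resource and custom devices do not exist
// natively in TFRT, so the other two utilities are not used there.
Status PlaceOpForTfrt(StringPiece op_name,
                      absl::Span<ImmediateExecutionTensorHandle*> inputs,
                      const EagerContext& ctx, Device** out_device) {
  bool pin_to_cpu = false;
  TF_RETURN_IF_ERROR(
      eager::MaybePinSmallOpsToCpu(&pin_to_cpu, op_name, inputs, ctx));
  if (pin_to_cpu) {
    *out_device = ctx.HostCPU();
    return Status::OK();
  }
  // Fall back to the runtime's own device selection (hypothetical helper).
  return SelectDeviceSomehow(op_name, inputs, ctx, out_device);
}
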
@@ -0,0 +1,55 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
+
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/stringpiece.h"
+
+namespace tensorflow {
+namespace eager {
+
+bool IsColocationExempt(StringPiece op_name);
+
+bool IsFunction(StringPiece op_name);
+
+bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx);
+
+// TODO(b/154234908): Unify placement logic.
+// TODO(b/159647422): Add C++ unit tests for placement logic.
+
+// Pin the op to cpu if all op inputs are on the CPU, small (<64 elements) and
+// integers (int32/int64). This can be disabled by setting the environment
+// variable "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
+Status MaybePinSmallOpsToCpu(bool* result, StringPiece op_name,
+                             absl::Span<ImmediateExecutionTensorHandle*> args,
+                             const EagerContext& ctx);
+
+// If a resource touching input is specified, all resource-touching ops run in
+// the device the resource is, regardless of anything else that has been
+// specified. This is identical to the graph mode behavior.
+Status MaybePinToResourceDevice(VariantDevice* device,
+                                const EagerOperation& op);
+
+// If all the inputs are on the same custom device, use that custom
+// device. Otherwise, it is an error to have a custom device as an input.
+Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op);
+
+}  // namespace eager
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_

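The header carries TODO(b/159647422) for C++ unit tests of the placement
logic. Below is a minimal sketch of the shape such a test could take, assuming
a hypothetical CreateTestEagerContext() fixture helper (constructing a real
EagerContext needs a DeviceMgr and SessionOptions that are out of scope here):

#include <memory>

#include "tensorflow/core/common_runtime/eager/placement_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace eager {

TEST(PlacementUtilsTest, MaybePinSmallOpsToCpuSkipsOpsWithoutInputs) {
  // CreateTestEagerContext is a hypothetical helper; a real test would build
  // an EagerContext from a StaticDeviceMgr and SessionOptions.
  std::unique_ptr<EagerContext> ctx = CreateTestEagerContext();
  bool pinned = true;
  // Ops with no inputs (e.g. VarHandleOp) are never pinned to the CPU.
  TF_ASSERT_OK(
      MaybePinSmallOpsToCpu(&pinned, "VarHandleOp", /*args=*/{}, *ctx));
  EXPECT_FALSE(pinned);
}

}  // namespace eager
}  // namespace tensorflow
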
@@ -491,7 +491,11 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
 
   absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
   VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();
-  TF_RETURN_IF_ERROR(EagerExecute(&op, retvals.data(), &num_retvals));
+  TF_RETURN_IF_ERROR(op.Execute(
+      absl::MakeSpan(
+          reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals.data()),
+          num_retvals),
+      &num_retvals));
 
   return AddOpRetvalsToResponse(
       eager_context, operation.id(), num_retvals, retvals.data(),

@@ -331,8 +331,11 @@ tensorflow::Status ExecuteFlexOp(TfLiteContext* context, BufferMap* buffer_map,
   node_data->mutable_outputs()->ResetTensorHandles();
   int num_retvals = node_data->NumOutputs();
   TF_RETURN_WITH_CONTEXT_IF_ERROR(
-      EagerExecute(node_data->op(),
-                   node_data->mutable_outputs()->GetTensorHandles()->data(),
-                   &num_retvals),
+      node_data->op()->Execute(
+          absl::MakeSpan(
+              reinterpret_cast<tensorflow::AbstractTensorHandle**>(
+                  node_data->mutable_outputs()->GetTensorHandles()->data()),
+              num_retvals),
+          &num_retvals),
       " (while executing '", node_data->name(), "' via Eager)");