Refactor eager placement logic into three util methods:
- MaybePinSmallOpsToCpu
- MaybePinToResourceDevice
- MaybePinToCustomDevice

We are going to reuse MaybePinSmallOpsToCpu in TFRT but not the other two, because TFRT has neither native resource devices nor custom devices.

PiperOrigin-RevId: 317766813
Change-Id: I43241b5786120ddf39dc4bfff6071239afdfd785
This commit is contained in:
parent c692c45dae
commit fe6e64b098
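The diff below moves the placement helpers into placement_utils so a second runtime can pick up only the pieces it needs. As a minimal sketch of the intended TFRT-style reuse of just the small-op heuristic (the wrapper name MaybePlaceOnHostCpu and its signature are illustrative assumptions, not part of this change):

// Hypothetical TFRT-side wrapper: reuses only MaybePinSmallOpsToCpu, since
// resource-device and custom-device pinning stay in the classic eager runtime.
Status MaybePlaceOnHostCpu(StringPiece op_name,
                           absl::Span<ImmediateExecutionTensorHandle*> args,
                           const EagerContext& ctx, Device** out_device) {
  bool pin_to_cpu = false;
  TF_RETURN_IF_ERROR(
      eager::MaybePinSmallOpsToCpu(&pin_to_cpu, op_name, args, ctx));
  if (pin_to_cpu) {
    *out_device = ctx.HostCPU();  // all inputs are small host int32/int64 tensors
  }
  return Status::OK();
}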
@@ -29,6 +29,7 @@ tf_cuda_library(
         ":context",
         ":eager_operation",
         ":execute",
+        ":placement_utils",
         ":tensor_handle",
         "//tensorflow/c:c_api_internal",
         "//tensorflow/c:tf_tensor_internal",
@@ -489,6 +490,7 @@ cc_library(
         ":eager_op_rewrite_registry",
         ":eager_operation",
         ":kernel_and_device",
+        ":placement_utils",
        ":tensor_handle",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:inlined_vector",
@@ -521,6 +523,35 @@ cc_library(
     }),
 )
 
+tf_cuda_library(
+    name = "placement_utils",
+    srcs = [
+        "placement_utils.cc",
+    ],
+    hdrs = [
+        "placement_utils.h",
+    ],
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":context",
+        ":attr_builder",
+        ":eager_operation",
+        "//tensorflow/c/eager:immediate_execution_tensor_handle",
+    ] + select({
+        "//tensorflow:android": [
+            "//tensorflow/core:portable_tensorflow_lib_lite",
+        ],
+        "//conditions:default": [
+            "//tensorflow/core:core_cpu_lib",
+            "//tensorflow/core:framework",
+            "//tensorflow/core:framework_internal",
+            "//tensorflow/core:lib",
+            "//tensorflow/core:lib_internal",
+            "//tensorflow/core:protos_all_cc",
+        ],
+    }),
+)
+
 tf_cuda_library(
     name = "attr_builder",
     srcs = ["attr_builder.cc"],
@@ -478,7 +478,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
   // On mobile, it just cleans the caches.
   void WaitForAndCloseRemoteContexts();
 
-  bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; }
+  bool PinSmallOpsToCPU() const { return pin_small_ops_to_cpu_; }
 
   tensorflow::Env* TFEnv() const { return env_; }
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
 #include "tensorflow/core/common_runtime/eager/execute.h"
+#include "tensorflow/core/common_runtime/eager/placement_utils.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/platform/errors.h"
@@ -187,6 +188,27 @@ Status EagerContext::RegisterFunction(AbstractFunction* f) {
 // eager_operation.cc we can avoid a circular dependency between them.
 Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
                                int* num_retvals) {
+  // Run eager placement logic.
+  VariantDevice device;
+  TF_RETURN_IF_ERROR(eager::MaybePinToCustomDevice(&device, *this));
+  if (device == kVariantDeviceNull) {
+    TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, *this));
+  }
+  if (device == kVariantDeviceNull) {
+    bool pin_to_cpu;
+    TF_RETURN_IF_ERROR(eager::MaybePinSmallOpsToCpu(
+        &pin_to_cpu, op_name(),
+        absl::MakeSpan(
+            reinterpret_cast<ImmediateExecutionTensorHandle**>(inputs_.data()),
+            inputs_.size()),
+        ctx_));
+    if (pin_to_cpu) {
+      device = ctx_.HostCPU();
+    }
+  }
+  if (device != kVariantDeviceNull) {
+    SetDevice(device);
+  }
   return EagerExecute(
       this, reinterpret_cast<tensorflow::TensorHandle**>(retvals.data()),
       num_retvals);
@@ -126,7 +126,7 @@ class EagerOperation : public ImmediateExecutionOperation {
   bool is_function() const { return is_function_; }
   bool colocation_exempt() const { return colocation_exempt_; }
 
-  tensorflow::EagerContext& EagerContext() { return ctx_; }
+  tensorflow::EagerContext& EagerContext() const { return ctx_; }
 
   AttrBuilder* MutableAttrs() { return &attrs_; }
   const AttrBuilder& Attrs() const { return attrs_; }
@@ -870,173 +870,6 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
 }
 #endif  // IS_MOBILE_PLATFORM
 
-// These ops are not pinnable since they generate data. It can be slower to
-// generate and then copy the data instead of just generating the data on the
-// device directly.
-bool IsPinnableOp(const string& op_type) {
-  static const gtl::FlatSet<string>* unpinnable_ops = new gtl::FlatSet<string>({
-      "RandomUniform",
-      "RandomUniformInt",
-      "RandomStandardNormal",
-      "StatelessRandomUniform",
-      "StatelessRandomUniformInt",
-      "StatelessRandomUniformFullInt",
-      "StatelessRandomNormal",
-  });
-
-  // XRT ops refer to per-device handles that are not safe to move between
-  // devices.
-  return unpinnable_ops->find(op_type) == unpinnable_ops->end() &&
-         !absl::StartsWith(op_type, "XRT");
-}
-
-// Validate if the remote device with the given incarnation is valid in the
-// remote device manager of the current eager context.
-Status ValidateTensorHandleRemoteDevice(EagerContext* ctx,
-                                        int64 device_incarnation) {
-  if (ctx->remote_device_mgr()->ContainsDevice(device_incarnation)) {
-    return Status::OK();
-  }
-  return errors::InvalidArgument(
-      "Resource input tensor contains an invalid device. This might happen "
-      "when the client has connected to a different cluster, or some remote "
-      "workers have been restarted.");
-}
-
-// The Op device may be updated if:
-// - A resource touching input is specified: all resource-touching ops run in
-//   the device the resource is, regardless of anything else that has been
-//   specified. This is identical to the graph mode behavior.
-//
-// - All op inputs are on the CPU, small (<64 elements) and integers
-//   (int32/int64). This can be disabled by setting the environment variable
-//   "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
-//
-// TODO(b/154234908): Unify placement logic.
-Status MaybeUpdateOpDevice(EagerOperation* op) {
-  // If operation was already placed on a custom device, use it.
-  if (VariantDeviceIsCustom(op->Device())) {
-    return Status::OK();
-  }
-
-  // If all the inputs are on the same custom device, use that custom
-  // device. Otherwise, it is an error to have a custom device as an input.
-  if (!op->Inputs().empty()) {
-    // We keep track of what we've seen with devices instead of booleans to be
-    // able to provide a meaningful error message below.
-    VariantDevice first = op->Inputs()[0]->device();
-    VariantDevice different = first;  // A different input device, if any.
-    VariantDevice custom = first;  // The first custom device seen, or an
-                                   // arbitrary non-custom device otherwise.
-    for (size_t i = 1; first == different && i < op->Inputs().size(); ++i) {
-      VariantDevice device = op->Inputs()[i]->device();
-      if (device != first) {
-        different = device;
-      }
-      if (!VariantDeviceIsCustom(custom) && VariantDeviceIsCustom(device)) {
-        custom = device;
-      }
-      if (different != first && VariantDeviceIsCustom(custom)) {
-        return errors::InvalidArgument(absl::StrCat(
-            "If an operation has one of its inputs in a custom device, then "
-            "all inputs should be on that same device. Operation ",
-            op->Name(), " has one input in custom device ",
-            VariantDeviceName(custom),
-            " and at least one input in a different device ",
-            VariantDeviceName(custom == first ? different : first)));
-      }
-    }
-    if (different == first && VariantDeviceIsCustom(custom)) {
-      op->SetDevice(first);
-      return Status::OK();
-    }
-  }
-
-  if (op->colocation_exempt()) {
-    return Status::OK();
-  }
-  EagerContext& ctx = op->EagerContext();
-  bool all_inputs_eligible_for_cpu_pinning =
-      ctx.PinSmallOpsToCPU() && !op->is_function() && IsPinnableOp(op->Name());
-  Device* op_device = op->Device() == kVariantDeviceNull
-                          ? ctx.HostCPU()
-                          : absl::get<Device*>(op->Device());
-  for (int i = 0; i < op->Inputs().size(); ++i) {
-    TensorHandle* tensor_handle = op->Inputs()[i];
-    if (tensor_handle->dtype == DT_RESOURCE) {
-      if (tensor_handle->resource_remote_device_incarnation() != 0) {
-        TF_RETURN_IF_ERROR(ValidateTensorHandleRemoteDevice(
-            &ctx, tensor_handle->resource_remote_device_incarnation()));
-      }
-      Device* resource_device = tensor_handle->resource_device();
-      DVLOG(2) << "for op " << op->Name() << " input " << i << " "
-               << DataTypeString(tensor_handle->dtype)
-               << " input device = " << resource_device->name()
-               << ", op device = " << op_device->name();
-      // We check for `op->Device() == nullptr` because it can be later
-      // interpreted as unspecified device and a different device can
-      // be selected based on device priority. If any input to an op
-      // is a resource we must pin it to prevent different device selection.
-      // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
-      if (resource_device != op_device || op->Device() == kVariantDeviceNull) {
-        DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
-                 << "device of operation " << op->Name() << " to "
-                 << resource_device->name() << " because input #" << i
-                 << " is a resource in this device.";
-        op->SetDevice(resource_device);
-      }
-      all_inputs_eligible_for_cpu_pinning = false;
-      // No point in looking at other inputs. If there are other resources,
-      // they must have the same device and we already declared the op to be
-      // ineligible for CPU pinning.
-      break;
-    } else if (all_inputs_eligible_for_cpu_pinning) {
-      auto input_device_variant = tensor_handle->DeviceOrHostCPU(ctx);
-      if (VariantDeviceIsCustom(input_device_variant)) {
-        all_inputs_eligible_for_cpu_pinning = false;
-        continue;
-      }
-      Device* input_device = absl::get<Device*>(input_device_variant);
-      DVLOG(2) << "for op " << op->Name() << " input " << i << " "
-               << DataTypeString(tensor_handle->dtype)
-               << " input device = " << input_device->name()
-               << ", op device = " << op_device->name();
-
-      // Input is on CPU.
-      if (input_device != ctx.HostCPU()) {
-        all_inputs_eligible_for_cpu_pinning = false;
-        continue;
-      }
-
-      if (tensor_handle->dtype != DataType::DT_INT32 &&
-          tensor_handle->dtype != DataType::DT_INT64) {
-        all_inputs_eligible_for_cpu_pinning = false;
-        continue;
-      }
-
-      int64 num_elements;
-      TF_RETURN_IF_ERROR(tensor_handle->NumElements(&num_elements));
-      if (num_elements > 64) {
-        all_inputs_eligible_for_cpu_pinning = false;
-      }
-    }
-  }
-
-  // Ops without inputs are usually ops that generate a tensor in some way and
-  // usually require being present on whatever device they are scheduled on
-  // - for e.g. VarHandleOp or _Recv).
-  // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
-  // an op, but there is a GPU kernel?
-  if (!op->Inputs().empty() && all_inputs_eligible_for_cpu_pinning) {
-    DVLOG(1) << "Forcing op " << op->Name()
-             << " to be on the CPU since all input tensors have an "
-                "int32/int64 dtype, and are small (less than 64 elements).";
-    op->SetDevice(ctx.HostCPU());
-  }
-
-  return Status::OK();
-}
-
 Status GetKernelOutputs(std::vector<Tensor>* outputs, int num_outputs,
                         TensorHandle** retvals, EagerContext* ctx,
                         KernelAndDevice* kernel) {
@@ -1099,8 +932,6 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
       [&] { return absl::StrCat("EagerExecute: ", op->Name()); },
       profiler::TraceMeLevel::kInfo);
 
-  TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
-
   if (VariantDeviceIsCustom(op->Device())) {
     return absl::get<CustomDevice*>(op->Device())
         ->Execute(op, retvals, num_retvals);
@@ -0,0 +1,228 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/placement_utils.h"
+
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/core/common_runtime/eager/attr_builder.h"
+#include "tensorflow/core/common_runtime/eager/eager_operation.h"
+#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/errors.h"
+
+namespace tensorflow {
+namespace eager {
+
+// These ops are not pinnable since they generate data. It can be slower to
+// generate and then copy the data instead of just generating the data on the
+// device directly.
+static bool IsPinnableOp(StringPiece op_name) {
+  static const gtl::FlatSet<string>* unpinnable_ops = new gtl::FlatSet<string>({
+      "RandomUniform",
+      "RandomUniformInt",
+      "RandomStandardNormal",
+      "StatelessRandomUniform",
+      "StatelessRandomUniformInt",
+      "StatelessRandomUniformFullInt",
+      "StatelessRandomNormal",
+  });
+
+  // XRT ops refer to per-device handles that are not safe to move between
+  // devices.
+  return unpinnable_ops->find(string(op_name)) == unpinnable_ops->end() &&
+         !absl::StartsWith(op_name, "XRT");
+}
+// Validate if the remote device with the given incarnation is valid in the
+// remote device manager of the current eager context.
+static Status ValidateTensorHandleRemoteDevice(EagerContext* ctx,
+                                               int64 device_incarnation) {
+  if (ctx->remote_device_mgr()->ContainsDevice(device_incarnation)) {
+    return Status::OK();
+  }
+  return errors::InvalidArgument(
+      "Resource input tensor contains an invalid device. This might happen "
+      "when the client has connected to a different cluster, or some remote "
+      "workers have been restarted.");
+}
+
+bool IsColocationExempt(StringPiece op_name) {
+  const auto& exempt_ops = InputColocationExemptionRegistry::Global()->Get();
+  return exempt_ops.find(string(op_name)) != exempt_ops.end();
+}
+
+bool IsFunction(StringPiece op_name) {
+  const OpDef* op_def = nullptr;
+  Status s = OpDefForOp(string(op_name), &op_def);
+  if (!s.ok()) {
+    if (!errors::IsNotFound(s)) {
+      LOG(WARNING) << "Looking up OpDef failed with error: " << s.ToString();
+    }
+    // Cannot find OpDef, it is a function.
+    return true;
+  }
+  return false;
+}
+
+bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx) {
+  CustomDevice* custom_device;
+  return ctx.FindCustomDeviceFromName(string(device_name), &custom_device).ok();
+}
+
+Status MaybePinSmallOpsToCpu(bool* result, StringPiece op_name,
+                             absl::Span<ImmediateExecutionTensorHandle*> args,
+                             const EagerContext& ctx) {
+  if (!ctx.PinSmallOpsToCPU() || IsFunction(op_name) ||
+      IsColocationExempt(op_name) || !IsPinnableOp(op_name)) {
+    *result = false;
+    return Status::OK();
+  }
+
+  // Ops without inputs are usually ops that generate a tensor in some way and
+  // usually require being present on whatever device they are scheduled on
+  // - for e.g. VarHandleOp or _Recv).
+  if (args.empty()) {
+    *result = false;
+    return Status::OK();
+  }
+
+  int i = 0;
+  for (auto* arg : args) {
+    Status s;
+    const char* device_name = arg->DeviceName(&s);
+    DataType dtype = arg->DataType();
+    TF_RETURN_IF_ERROR(s);
+    if (IsCustomDevice(device_name, ctx)) {
+      *result = false;
+      return Status::OK();
+    }
+
+    DVLOG(2) << "for op " << op_name << " input " << i << " "
+             << DataTypeString(dtype) << " input device = " << device_name;
+
+    // Input is on CPU.
+    if (device_name != ctx.HostCPU()->name()) {
+      *result = false;
+      return Status::OK();
+    }
+
+    if (dtype != DataType::DT_INT32 && dtype != DataType::DT_INT64) {
+      *result = false;
+      return Status::OK();
+    }
+
+    int64 num_elements;
+    TF_RETURN_IF_ERROR(arg->NumElements(&num_elements));
+    if (num_elements > 64) {
+      *result = false;
+      return Status::OK();
+    }
+    i++;
+  }
+
+  // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
+  // an op, but there is a GPU kernel?
+  DVLOG(1) << "Forcing op " << op_name
+           << " to be on the CPU since all input tensors have an "
+              "int32/int64 dtype, and are small (less than 64 elements).";
+  *result = true;
+  return Status::OK();
+}
+
+Status MaybePinToResourceDevice(VariantDevice* device,
+                                const EagerOperation& op) {
+  if (op.colocation_exempt()) {
+    return Status::OK();
+  }
+  EagerContext& ctx = op.EagerContext();
+  Device* op_device = op.Device() == kVariantDeviceNull
+                          ? ctx.HostCPU()
+                          : absl::get<Device*>(op.Device());
+  for (int i = 0; i < op.Inputs().size(); ++i) {
+    TensorHandle* tensor_handle = op.Inputs()[i];
+    if (tensor_handle->dtype == DT_RESOURCE) {
+      if (tensor_handle->resource_remote_device_incarnation() != 0) {
+        TF_RETURN_IF_ERROR(ValidateTensorHandleRemoteDevice(
+            &ctx, tensor_handle->resource_remote_device_incarnation()));
+      }
+      Device* resource_device = tensor_handle->resource_device();
+      DVLOG(2) << "for op " << op.Name() << " input " << i << " "
+               << DataTypeString(tensor_handle->dtype)
+               << " input device = " << resource_device->name()
+               << ", op device = " << op_device->name();
+      // We check for `op->Device() == nullptr` because it can be later
+      // interpreted as unspecified device and a different device can
+      // be selected based on device priority. If any input to an op
+      // is a resource we must pin it to prevent different device selection.
+      // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
+      if (resource_device != op_device || op.Device() == kVariantDeviceNull) {
+        DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
+                 << "device of operation " << op.Name() << " to "
+                 << resource_device->name() << " because input #" << i
+                 << " is a resource in this device.";
+        *device = resource_device;
+        return Status::OK();
+        // No point in looking at other inputs. If there are other resources,
+        // they must have the same device and we already declared the op to be
+        // ineligible for CPU pinning.
+      }
+    }
+  }
+  return Status::OK();
+}
+
+Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
+  // If operation was already placed on a custom device, use it.
+  if (VariantDeviceIsCustom(op.Device())) {
+    *device = op.Device();
+    return Status::OK();
+  }
+
+  if (!op.Inputs().empty()) {
+    // We keep track of what we've seen with devices instead of booleans to be
+    // able to provide a meaningful error message below.
+    VariantDevice first = op.Inputs()[0]->device();
+    VariantDevice different = first;  // A different input device, if any.
+    VariantDevice custom = first;  // The first custom device seen, or an
+                                   // arbitrary non-custom device otherwise.
+    for (size_t i = 1; first == different && i < op.Inputs().size(); ++i) {
+      VariantDevice device = op.Inputs()[i]->device();
+      if (device != first) {
+        different = device;
+      }
+      if (!VariantDeviceIsCustom(custom) && VariantDeviceIsCustom(device)) {
+        custom = device;
+      }
+      if (different != first && VariantDeviceIsCustom(custom)) {
+        return errors::InvalidArgument(absl::StrCat(
+            "If an operation has one of its inputs in a custom device, then "
+            "all inputs should be on that same device. Operation ",
+            op.Name(), " has one input in custom device ",
+            VariantDeviceName(custom),
+            " and at least one input in a different device ",
+            VariantDeviceName(custom == first ? different : first)));
+      }
+    }
+    if (different == first && VariantDeviceIsCustom(custom)) {
+      *device = first;
+      return Status::OK();
+    }
+  }
+
+  return Status::OK();
+}
+
+}  // namespace eager
+}  // namespace tensorflow
@@ -0,0 +1,55 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
+
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/stringpiece.h"
+
+namespace tensorflow {
+namespace eager {
+
+bool IsColocationExempt(StringPiece op_name);
+
+bool IsFunction(StringPiece op_name);
+
+bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx);
+
+// TODO(b/154234908): Unify placement logic.
+// TODO(b/159647422): Add C++ unit tests for placement logic.
+
+// Pin the op to cpu if all op inputs are on the CPU, small (<64 elements) and
+// integers (int32/int64). This can be disabled by setting the environment
+// variable "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
+Status MaybePinSmallOpsToCpu(bool* result, StringPiece op_name,
+                             absl::Span<ImmediateExecutionTensorHandle*> args,
+                             const EagerContext& ctx);
+
+// If a resource touching input is specified, all resource-touching ops run in
+// the device the resource is, regardless of anything else that has been
+// specified. This is identical to the graph mode behavior.
+Status MaybePinToResourceDevice(VariantDevice* device,
+                                const EagerOperation& op);
+
+// If all the inputs are on the same custom device, use that custom
+// device. Otherwise, it is an error to have a custom device as an input.
+Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op);
+
+}  // namespace eager
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
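The two device-pinning helpers leave *device untouched when they have no opinion, so a caller distinguishes "pinned" from "no decision" via the kVariantDeviceNull sentinel, as EagerOperation::Execute does above. A compact sketch of that chaining contract (the wrapper name RunPinningCascade is illustrative, not part of this change):

// Chains the pinning helpers in priority order: custom device first, then
// resource device. *device stays kVariantDeviceNull when neither helper pins.
Status RunPinningCascade(const EagerOperation& op, VariantDevice* device) {
  *device = kVariantDeviceNull;  // helpers only write on a definite decision
  TF_RETURN_IF_ERROR(eager::MaybePinToCustomDevice(device, op));
  if (*device == kVariantDeviceNull) {
    TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(device, op));
  }
  return Status::OK();
}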
@@ -491,7 +491,11 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
 
   absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
   VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();
-  TF_RETURN_IF_ERROR(EagerExecute(&op, retvals.data(), &num_retvals));
+  TF_RETURN_IF_ERROR(op.Execute(
+      absl::MakeSpan(
+          reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals.data()),
+          num_retvals),
+      &num_retvals));
 
   return AddOpRetvalsToResponse(
       eager_context, operation.id(), num_retvals, retvals.data(),
@@ -331,9 +331,12 @@ tensorflow::Status ExecuteFlexOp(TfLiteContext* context, BufferMap* buffer_map,
   node_data->mutable_outputs()->ResetTensorHandles();
   int num_retvals = node_data->NumOutputs();
   TF_RETURN_WITH_CONTEXT_IF_ERROR(
-      EagerExecute(node_data->op(),
-                   node_data->mutable_outputs()->GetTensorHandles()->data(),
-                   &num_retvals),
+      node_data->op()->Execute(
+          absl::MakeSpan(
+              reinterpret_cast<tensorflow::AbstractTensorHandle**>(
+                  node_data->mutable_outputs()->GetTensorHandles()->data()),
+              num_retvals),
+          &num_retvals),
       " (while executing '", node_data->name(), "' via Eager)");
 
   if (num_retvals != node_data->NumOutputs()) {