Refactor eager placement logic into three util methods:

- MaybePinSmallOpsToCpu
- MaybePinToResourceDevice
- MaybePinToCustomDevice

We are going to reuse MaybePinSmallOpsToCpu in TFRT, but not the other two, because TFRT has neither native resource handles nor custom devices.

PiperOrigin-RevId: 317766813
Change-Id: I43241b5786120ddf39dc4bfff6071239afdfd785
Author: Xiao Yu (2020-06-22 17:15:42 -07:00), committed by TensorFlower Gardener
Commit: fe6e64b098 (parent: c692c45dae)
9 changed files with 349 additions and 175 deletions
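
As a sketch of the intended TFRT-style reuse (hypothetical caller, not part of this commit): a runtime with neither resource nor custom devices only needs the first helper, using the signature declared in placement_utils.h below.

#include "tensorflow/core/common_runtime/eager/placement_utils.h"

// Hypothetical sketch, not from this commit: a runtime without resource or
// custom devices can drive placement with MaybePinSmallOpsToCpu alone.
// (Assumes namespace tensorflow.)
Status PlaceOpForSimpleRuntime(StringPiece op_name,
                               absl::Span<ImmediateExecutionTensorHandle*> args,
                               const EagerContext& ctx, Device** out_device) {
  bool pin_to_cpu = false;
  TF_RETURN_IF_ERROR(
      eager::MaybePinSmallOpsToCpu(&pin_to_cpu, op_name, args, ctx));
  // nullptr here means "let the runtime's default placement decide".
  *out_device = pin_to_cpu ? ctx.HostCPU() : nullptr;
  return Status::OK();
}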

tensorflow/core/common_runtime/eager/BUILD

@@ -29,6 +29,7 @@ tf_cuda_library(
        ":context",
        ":eager_operation",
        ":execute",
+       ":placement_utils",
        ":tensor_handle",
        "//tensorflow/c:c_api_internal",
        "//tensorflow/c:tf_tensor_internal",
@@ -489,6 +490,7 @@ cc_library(
        ":eager_op_rewrite_registry",
        ":eager_operation",
        ":kernel_and_device",
+       ":placement_utils",
        ":tensor_handle",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",
@@ -521,6 +523,35 @@ cc_library(
    }),
)

+tf_cuda_library(
+    name = "placement_utils",
+    srcs = [
+        "placement_utils.cc",
+    ],
+    hdrs = [
+        "placement_utils.h",
+    ],
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":context",
+        ":attr_builder",
+        ":eager_operation",
+        "//tensorflow/c/eager:immediate_execution_tensor_handle",
+    ] + select({
+        "//tensorflow:android": [
+            "//tensorflow/core:portable_tensorflow_lib_lite",
+        ],
+        "//conditions:default": [
+            "//tensorflow/core:core_cpu_lib",
+            "//tensorflow/core:framework",
+            "//tensorflow/core:framework_internal",
+            "//tensorflow/core:lib",
+            "//tensorflow/core:lib_internal",
+            "//tensorflow/core:protos_all_cc",
+        ],
+    }),
+)
+
tf_cuda_library(
    name = "attr_builder",
    srcs = ["attr_builder.cc"],

tensorflow/core/common_runtime/eager/context.h

@@ -478,7 +478,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
  // On mobile, it just cleans the caches.
  void WaitForAndCloseRemoteContexts();

-  bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; }
+  bool PinSmallOpsToCPU() const { return pin_small_ops_to_cpu_; }

  tensorflow::Env* TFEnv() const { return env_; }

tensorflow/core/common_runtime/eager/core.cc

@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
 #include "tensorflow/core/common_runtime/eager/execute.h"
+#include "tensorflow/core/common_runtime/eager/placement_utils.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/platform/errors.h"
@@ -187,6 +188,27 @@ Status EagerContext::RegisterFunction(AbstractFunction* f) {
// eager_operation.cc we can avoid a circular dependency between them.
Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
                               int* num_retvals) {
+  // Run eager placement logic.
+  VariantDevice device;
+  TF_RETURN_IF_ERROR(eager::MaybePinToCustomDevice(&device, *this));
+  if (device == kVariantDeviceNull) {
+    TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, *this));
+  }
+  if (device == kVariantDeviceNull) {
+    bool pin_to_cpu;
+    TF_RETURN_IF_ERROR(eager::MaybePinSmallOpsToCpu(
+        &pin_to_cpu, op_name(),
+        absl::MakeSpan(
+            reinterpret_cast<ImmediateExecutionTensorHandle**>(inputs_.data()),
+            inputs_.size()),
+        ctx_));
+    if (pin_to_cpu) {
+      device = ctx_.HostCPU();
+    }
+  }
+  if (device != kVariantDeviceNull) {
+    SetDevice(device);
+  }
  return EagerExecute(
      this, reinterpret_cast<tensorflow::TensorHandle**>(retvals.data()),
      num_retvals);
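
The helpers above are tried in strict priority order; a hypothetical trace of the dispatch (op and device names invented, not from the commit):

// Hypothetical dispatch traces for the block above (names invented):
//
// AssignVariableOp, resource input on /device:GPU:0:
//   MaybePinToCustomDevice   -> device stays kVariantDeviceNull
//   MaybePinToResourceDevice -> device = /device:GPU:0 (resource wins)
//   MaybePinSmallOpsToCpu    -> skipped, device already chosen
//
// AddV2, two scalar int32 inputs on the host:
//   MaybePinToCustomDevice   -> device stays kVariantDeviceNull
//   MaybePinToResourceDevice -> device stays kVariantDeviceNull
//   MaybePinSmallOpsToCpu    -> pin_to_cpu = true, device = ctx_.HostCPU()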

tensorflow/core/common_runtime/eager/eager_operation.h

@@ -126,7 +126,7 @@ class EagerOperation : public ImmediateExecutionOperation {
  bool is_function() const { return is_function_; }
  bool colocation_exempt() const { return colocation_exempt_; }

-  tensorflow::EagerContext& EagerContext() { return ctx_; }
+  tensorflow::EagerContext& EagerContext() const { return ctx_; }

  AttrBuilder* MutableAttrs() { return &attrs_; }
  const AttrBuilder& Attrs() const { return attrs_; }

tensorflow/core/common_runtime/eager/execute.cc

@@ -870,173 +870,6 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
}
#endif  // IS_MOBILE_PLATFORM

-// These ops are not pinnable since they generate data. It can be slower to
-// generate and then copy the data instead of just generating the data on the
-// device directly.
-bool IsPinnableOp(const string& op_type) {
-  static const gtl::FlatSet<string>* unpinnable_ops = new gtl::FlatSet<string>({
-      "RandomUniform",
-      "RandomUniformInt",
-      "RandomStandardNormal",
-      "StatelessRandomUniform",
-      "StatelessRandomUniformInt",
-      "StatelessRandomUniformFullInt",
-      "StatelessRandomNormal",
-  });
-
-  // XRT ops refer to per-device handles that are not safe to move between
-  // devices.
-  return unpinnable_ops->find(op_type) == unpinnable_ops->end() &&
-         !absl::StartsWith(op_type, "XRT");
-}
-
-// Validate if the remote device with the given incarnation is valid in the
-// remote device manager of the current eager context.
-Status ValidateTensorHandleRemoteDevice(EagerContext* ctx,
-                                        int64 device_incarnation) {
-  if (ctx->remote_device_mgr()->ContainsDevice(device_incarnation)) {
-    return Status::OK();
-  }
-  return errors::InvalidArgument(
-      "Resource input tensor contains an invalid device. This might happen "
-      "when the client has connected to a different cluster, or some remote "
-      "workers have been restarted.");
-}
-
-// The Op device may be updated if:
-// - A resource touching input is specified: all resource-touching ops run in
-//   the device the resource is, regardless of anything else that has been
-//   specified. This is identical to the graph mode behavior.
-//
-// - All op inputs are on the CPU, small (<64 elements) and integers
-//   (int32/int64). This can be disabled by setting the environment variable
-//   "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
-//
-// TODO(b/154234908): Unify placement logic.
-Status MaybeUpdateOpDevice(EagerOperation* op) {
-  // If operation was already placed on a custom device, use it.
-  if (VariantDeviceIsCustom(op->Device())) {
-    return Status::OK();
-  }
-
-  // If all the inputs are on the same custom device, use that custom
-  // device. Otherwise, it is an error to have a custom device as an input.
-  if (!op->Inputs().empty()) {
-    // We keep track of what we've seen with devices instead of booleans to be
-    // able to provide a meaningful error message below.
-    VariantDevice first = op->Inputs()[0]->device();
-    VariantDevice different = first;  // A different input device, if any.
-    VariantDevice custom = first;     // The first custom device seen, or an
-                                      // arbitrary non-custom device otherwise.
-    for (size_t i = 1; first == different && i < op->Inputs().size(); ++i) {
-      VariantDevice device = op->Inputs()[i]->device();
-      if (device != first) {
-        different = device;
-      }
-      if (!VariantDeviceIsCustom(custom) && VariantDeviceIsCustom(device)) {
-        custom = device;
-      }
-      if (different != first && VariantDeviceIsCustom(custom)) {
-        return errors::InvalidArgument(absl::StrCat(
-            "If an operation has one of its inputs in a custom device, then "
-            "all inputs should be on that same device. Operation ",
-            op->Name(), " has one input in custom device ",
-            VariantDeviceName(custom),
-            " and at least one input in a different device ",
-            VariantDeviceName(custom == first ? different : first)));
-      }
-    }
-    if (different == first && VariantDeviceIsCustom(custom)) {
-      op->SetDevice(first);
-      return Status::OK();
-    }
-  }
-
-  if (op->colocation_exempt()) {
-    return Status::OK();
-  }
-
-  EagerContext& ctx = op->EagerContext();
-  bool all_inputs_eligible_for_cpu_pinning =
-      ctx.PinSmallOpsToCPU() && !op->is_function() && IsPinnableOp(op->Name());
-  Device* op_device = op->Device() == kVariantDeviceNull
-                          ? ctx.HostCPU()
-                          : absl::get<Device*>(op->Device());
-  for (int i = 0; i < op->Inputs().size(); ++i) {
-    TensorHandle* tensor_handle = op->Inputs()[i];
-    if (tensor_handle->dtype == DT_RESOURCE) {
-      if (tensor_handle->resource_remote_device_incarnation() != 0) {
-        TF_RETURN_IF_ERROR(ValidateTensorHandleRemoteDevice(
-            &ctx, tensor_handle->resource_remote_device_incarnation()));
-      }
-      Device* resource_device = tensor_handle->resource_device();
-      DVLOG(2) << "for op " << op->Name() << " input " << i << " "
-               << DataTypeString(tensor_handle->dtype)
-               << " input device = " << resource_device->name()
-               << ", op device = " << op_device->name();
-      // We check for `op->Device() == nullptr` because it can be later
-      // interpreted as unspecified device and a different device can
-      // be selected based on device priority. If any input to an op
-      // is a resource we must pin it to prevent different device selection.
-      // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
-      if (resource_device != op_device || op->Device() == kVariantDeviceNull) {
-        DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
-                 << "device of operation " << op->Name() << " to "
-                 << resource_device->name() << " because input #" << i
-                 << " is a resource in this device.";
-        op->SetDevice(resource_device);
-      }
-      all_inputs_eligible_for_cpu_pinning = false;
-      // No point in looking at other inputs. If there are other resources,
-      // they must have the same device and we already declared the op to be
-      // ineligible for CPU pinning.
-      break;
-    } else if (all_inputs_eligible_for_cpu_pinning) {
-      auto input_device_variant = tensor_handle->DeviceOrHostCPU(ctx);
-      if (VariantDeviceIsCustom(input_device_variant)) {
-        all_inputs_eligible_for_cpu_pinning = false;
-        continue;
-      }
-      Device* input_device = absl::get<Device*>(input_device_variant);
-      DVLOG(2) << "for op " << op->Name() << " input " << i << " "
-               << DataTypeString(tensor_handle->dtype)
-               << " input device = " << input_device->name()
-               << ", op device = " << op_device->name();
-
-      // Input is on CPU.
-      if (input_device != ctx.HostCPU()) {
-        all_inputs_eligible_for_cpu_pinning = false;
-        continue;
-      }
-
-      if (tensor_handle->dtype != DataType::DT_INT32 &&
-          tensor_handle->dtype != DataType::DT_INT64) {
-        all_inputs_eligible_for_cpu_pinning = false;
-        continue;
-      }
-
-      int64 num_elements;
-      TF_RETURN_IF_ERROR(tensor_handle->NumElements(&num_elements));
-      if (num_elements > 64) {
-        all_inputs_eligible_for_cpu_pinning = false;
-      }
-    }
-  }
-
-  // Ops without inputs are usually ops that generate a tensor in some way and
-  // usually require being present on whatever device they are scheduled on
-  // - for e.g. VarHandleOp or _Recv).
-  // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
-  // an op, but there is a GPU kernel?
-  if (!op->Inputs().empty() && all_inputs_eligible_for_cpu_pinning) {
-    DVLOG(1) << "Forcing op " << op->Name()
-             << " to be on the CPU since all input tensors have an "
-                "int32/int64 dtype, and are small (less than 64 elements).";
-    op->SetDevice(ctx.HostCPU());
-  }
-
-  return Status::OK();
-}
-
Status GetKernelOutputs(std::vector<Tensor>* outputs, int num_outputs,
                        TensorHandle** retvals, EagerContext* ctx,
                        KernelAndDevice* kernel) {
@@ -1099,8 +932,6 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
      [&] { return absl::StrCat("EagerExecute: ", op->Name()); },
      profiler::TraceMeLevel::kInfo);

-  TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
-
  if (VariantDeviceIsCustom(op->Device())) {
    return absl::get<CustomDevice*>(op->Device())
        ->Execute(op, retvals, num_retvals);

tensorflow/core/common_runtime/eager/placement_utils.cc

@@ -0,0 +1,228 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/placement_utils.h"
+
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/core/common_runtime/eager/attr_builder.h"
+#include "tensorflow/core/common_runtime/eager/eager_operation.h"
+#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/errors.h"
+
+namespace tensorflow {
+namespace eager {
+
+// These ops are not pinnable since they generate data. It can be slower to
+// generate and then copy the data instead of just generating the data on the
+// device directly.
+static bool IsPinnableOp(StringPiece op_name) {
+  static const gtl::FlatSet<string>* unpinnable_ops = new gtl::FlatSet<string>({
+      "RandomUniform",
+      "RandomUniformInt",
+      "RandomStandardNormal",
+      "StatelessRandomUniform",
+      "StatelessRandomUniformInt",
+      "StatelessRandomUniformFullInt",
+      "StatelessRandomNormal",
+  });
+
+  // XRT ops refer to per-device handles that are not safe to move between
+  // devices.
+  return unpinnable_ops->find(string(op_name)) == unpinnable_ops->end() &&
+         !absl::StartsWith(op_name, "XRT");
+}
+
+// Validate if the remote device with the given incarnation is valid in the
+// remote device manager of the current eager context.
+static Status ValidateTensorHandleRemoteDevice(EagerContext* ctx,
+                                               int64 device_incarnation) {
+  if (ctx->remote_device_mgr()->ContainsDevice(device_incarnation)) {
+    return Status::OK();
+  }
+  return errors::InvalidArgument(
+      "Resource input tensor contains an invalid device. This might happen "
+      "when the client has connected to a different cluster, or some remote "
+      "workers have been restarted.");
+}
+
+bool IsColocationExempt(StringPiece op_name) {
+  const auto& exempt_ops = InputColocationExemptionRegistry::Global()->Get();
+  return exempt_ops.find(string(op_name)) != exempt_ops.end();
+}
+
+bool IsFunction(StringPiece op_name) {
+  const OpDef* op_def = nullptr;
+  Status s = OpDefForOp(string(op_name), &op_def);
+  if (!s.ok()) {
+    if (!errors::IsNotFound(s)) {
+      LOG(WARNING) << "Looking up OpDef failed with error: " << s.ToString();
+    }
+    // Cannot find OpDef, it is a function.
+    return true;
+  }
+  return false;
+}
+
+bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx) {
+  CustomDevice* custom_device;
+  return ctx.FindCustomDeviceFromName(string(device_name), &custom_device).ok();
+}
+
+Status MaybePinSmallOpsToCpu(bool* result, StringPiece op_name,
+                             absl::Span<ImmediateExecutionTensorHandle*> args,
+                             const EagerContext& ctx) {
+  if (!ctx.PinSmallOpsToCPU() || IsFunction(op_name) ||
+      IsColocationExempt(op_name) || !IsPinnableOp(op_name)) {
+    *result = false;
+    return Status::OK();
+  }
+
+  // Ops without inputs are usually ops that generate a tensor in some way and
+  // usually require being present on whatever device they are scheduled on
+  // - for e.g. VarHandleOp or _Recv).
+  if (args.empty()) {
+    *result = false;
+    return Status::OK();
+  }
+
+  int i = 0;
+  for (auto* arg : args) {
+    Status s;
+    const char* device_name = arg->DeviceName(&s);
+    DataType dtype = arg->DataType();
+    TF_RETURN_IF_ERROR(s);
+
+    if (IsCustomDevice(device_name, ctx)) {
+      *result = false;
+      return Status::OK();
+    }
+
+    DVLOG(2) << "for op " << op_name << " input " << i << " "
+             << DataTypeString(dtype) << " input device = " << device_name;
+
+    // Input is on CPU.
+    if (device_name != ctx.HostCPU()->name()) {
+      *result = false;
+      return Status::OK();
+    }
+
+    if (dtype != DataType::DT_INT32 && dtype != DataType::DT_INT64) {
+      *result = false;
+      return Status::OK();
+    }
+
+    int64 num_elements;
+    TF_RETURN_IF_ERROR(arg->NumElements(&num_elements));
+    if (num_elements > 64) {
+      *result = false;
+      return Status::OK();
+    }
+    i++;
+  }
+
+  // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
+  // an op, but there is a GPU kernel?
+  DVLOG(1) << "Forcing op " << op_name
+           << " to be on the CPU since all input tensors have an "
+              "int32/int64 dtype, and are small (less than 64 elements).";
+  *result = true;
+  return Status::OK();
+}
+
+Status MaybePinToResourceDevice(VariantDevice* device,
+                                const EagerOperation& op) {
+  if (op.colocation_exempt()) {
+    return Status::OK();
+  }
+  EagerContext& ctx = op.EagerContext();
+  Device* op_device = op.Device() == kVariantDeviceNull
+                          ? ctx.HostCPU()
+                          : absl::get<Device*>(op.Device());
+  for (int i = 0; i < op.Inputs().size(); ++i) {
+    TensorHandle* tensor_handle = op.Inputs()[i];
+    if (tensor_handle->dtype == DT_RESOURCE) {
+      if (tensor_handle->resource_remote_device_incarnation() != 0) {
+        TF_RETURN_IF_ERROR(ValidateTensorHandleRemoteDevice(
+            &ctx, tensor_handle->resource_remote_device_incarnation()));
+      }
+      Device* resource_device = tensor_handle->resource_device();
+      DVLOG(2) << "for op " << op.Name() << " input " << i << " "
+               << DataTypeString(tensor_handle->dtype)
+               << " input device = " << resource_device->name()
+               << ", op device = " << op_device->name();
+      // We check for `op->Device() == nullptr` because it can be later
+      // interpreted as unspecified device and a different device can
+      // be selected based on device priority. If any input to an op
+      // is a resource we must pin it to prevent different device selection.
+      // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
+      if (resource_device != op_device || op.Device() == kVariantDeviceNull) {
+        DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
+                 << "device of operation " << op.Name() << " to "
+                 << resource_device->name() << " because input #" << i
+                 << " is a resource in this device.";
+        *device = resource_device;
+        return Status::OK();
+        // No point in looking at other inputs. If there are other resources,
+        // they must have the same device and we already declared the op to be
+        // ineligible for CPU pinning.
+      }
+    }
+  }
+  return Status::OK();
+}
+
+Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
+  // If operation was already placed on a custom device, use it.
+  if (VariantDeviceIsCustom(op.Device())) {
+    *device = op.Device();
+    return Status::OK();
+  }
+
+  if (!op.Inputs().empty()) {
+    // We keep track of what we've seen with devices instead of booleans to be
+    // able to provide a meaningful error message below.
+    VariantDevice first = op.Inputs()[0]->device();
+    VariantDevice different = first;  // A different input device, if any.
+    VariantDevice custom = first;     // The first custom device seen, or an
+                                      // arbitrary non-custom device otherwise.
+    for (size_t i = 1; first == different && i < op.Inputs().size(); ++i) {
+      VariantDevice device = op.Inputs()[i]->device();
+      if (device != first) {
+        different = device;
+      }
+      if (!VariantDeviceIsCustom(custom) && VariantDeviceIsCustom(device)) {
+        custom = device;
+      }
+      if (different != first && VariantDeviceIsCustom(custom)) {
+        return errors::InvalidArgument(absl::StrCat(
+            "If an operation has one of its inputs in a custom device, then "
+            "all inputs should be on that same device. Operation ",
+            op.Name(), " has one input in custom device ",
+            VariantDeviceName(custom),
+            " and at least one input in a different device ",
+            VariantDeviceName(custom == first ? different : first)));
+      }
+    }
+    if (different == first && VariantDeviceIsCustom(custom)) {
+      *device = first;
+      return Status::OK();
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace eager
+}  // namespace tensorflow
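
To make the custom-device rule above concrete, hypothetical traces of MaybePinToCustomDevice (device names invented, not from the commit); the single pass remembers the first device, the first differing device, and the first custom device, so it can both detect the uniform-custom-device case and report a conflict:

// Hypothetical traces (device names invented):
//
//   inputs: [CUSTOM:0, CUSTOM:0] -> *device = CUSTOM:0 (op pinned there)
//   inputs: [CUSTOM:0, CPU:0]    -> InvalidArgument: "If an operation has one
//                                   of its inputs in a custom device, then all
//                                   inputs should be on that same device..."
//   inputs: [CPU:0, GPU:0]       -> *device left unchanged (no custom device
//                                   involved; later placement logic decides)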

tensorflow/core/common_runtime/eager/placement_utils.h

@@ -0,0 +1,55 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
+
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/stringpiece.h"
+
+namespace tensorflow {
+namespace eager {
+
+bool IsColocationExempt(StringPiece op_name);
+
+bool IsFunction(StringPiece op_name);
+
+bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx);
+
+// TODO(b/154234908): Unify placement logic.
+// TODO(b/159647422): Add C++ unit tests for placement logic.
+
+// Pin the op to cpu if all op inputs are on the CPU, small (<64 elements) and
+// integers (int32/int64). This can be disabled by setting the environment
+// variable "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
+Status MaybePinSmallOpsToCpu(bool* result, StringPiece op_name,
+                             absl::Span<ImmediateExecutionTensorHandle*> args,
+                             const EagerContext& ctx);
+
+// If a resource touching input is specified, all resource-touching ops run in
+// the device the resource is, regardless of anything else that has been
+// specified. This is identical to the graph mode behavior.
+Status MaybePinToResourceDevice(VariantDevice* device,
+                                const EagerOperation& op);
+
+// If all the inputs are on the same custom device, use that custom
+// device. Otherwise, it is an error to have a custom device as an input.
+Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op);
+
+}  // namespace eager
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
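
The pinning heuristic declared above, summarized as an informal decision table (derived from the placement_utils.cc implementation, not part of the commit):

// Decision summary for MaybePinSmallOpsToCpu ("pin" means *result = true):
//
//   pinning disabled, function, colocation-exempt op,
//   RandomUniform-style generator, or XRT op          -> no pin
//   no inputs                                         -> no pin
//   any input on a custom or non-CPU device           -> no pin
//   any input not DT_INT32/DT_INT64                   -> no pin
//   any input with more than 64 elements              -> no pin
//   otherwise                                         -> pin to HostCPU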

tensorflow/core/distributed_runtime/eager/eager_service_impl.cc

@@ -491,7 +491,11 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
  absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
  VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();

-  TF_RETURN_IF_ERROR(EagerExecute(&op, retvals.data(), &num_retvals));
+  TF_RETURN_IF_ERROR(op.Execute(
+      absl::MakeSpan(
+          reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals.data()),
+          num_retvals),
+      &num_retvals));

  return AddOpRetvalsToResponse(
      eager_context, operation.id(), num_retvals, retvals.data(),

tensorflow/lite/delegates/flex/kernel.cc

@@ -331,9 +331,12 @@ tensorflow::Status ExecuteFlexOp(TfLiteContext* context, BufferMap* buffer_map,
  node_data->mutable_outputs()->ResetTensorHandles();
  int num_retvals = node_data->NumOutputs();
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
-      EagerExecute(node_data->op(),
-                   node_data->mutable_outputs()->GetTensorHandles()->data(),
-                   &num_retvals),
+      node_data->op()->Execute(
+          absl::MakeSpan(
+              reinterpret_cast<tensorflow::AbstractTensorHandle**>(
+                  node_data->mutable_outputs()->GetTensorHandles()->data()),
+              num_retvals),
+          &num_retvals),
      " (while executing '", node_data->name(), "' via Eager)");
  if (num_retvals != node_data->NumOutputs()) {