Added Transpose support to Metal backend.
PiperOrigin-RevId: 352867205 Change-Id: I5161ac760ff887b10db01c1320237a109fe98926
This commit is contained in:
parent
9a426abe81
commit
2263b49a6e
@ -36,6 +36,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:prelu",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:quantize_and_dequantize",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:relu",
|
||||
"//tensorflow/lite/delegates/gpu/common/tasks:transpose",
|
||||
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
|
||||
"//tensorflow/lite/delegates/gpu/metal/kernels",
|
||||
],
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/prelu.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/quantize_and_dequantize.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/relu.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tasks/transpose.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||
@ -135,6 +136,13 @@ std::unique_ptr<ComputeTaskDescriptor> SelectSpaceToDepth(
|
||||
return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
|
||||
}
|
||||
|
||||
void SelectTranspose(const TransposeAttributes& attr,
|
||||
const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr) {
|
||||
GPUOperation operation = CreateTranspose(op_def, attr);
|
||||
*ptr = absl::make_unique<GPUOperation>(std::move(operation));
|
||||
}
|
||||
|
||||
std::unique_ptr<ComputeTaskDescriptor> SelectWinograd4x4To36(
|
||||
const OperationDef& op_def, const Winograd4x4To36Attributes& attr,
|
||||
const GpuInfo& gpu_info) {
|
||||
@ -452,6 +460,12 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
|
||||
op_def,
|
||||
absl::any_cast<SpaceToDepthAttributes>(node.operation.attributes));
|
||||
break;
|
||||
case OperationType::TRANSPOSE: {
|
||||
auto attr =
|
||||
absl::any_cast<TransposeAttributes>(node.operation.attributes);
|
||||
SelectTranspose(attr, op_def, &gpu_operation->operation);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
case OperationType::ABS:
|
||||
case OperationType::COPY:
|
||||
case OperationType::COS:
|
||||
@ -515,7 +529,6 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
|
||||
case OperationType::REDUCE_PRODUCT:
|
||||
case OperationType::REDUCE_SUM:
|
||||
case OperationType::SPACE_TO_BATCH:
|
||||
case OperationType::TRANSPOSE:
|
||||
return absl::UnimplementedError("Unsupported op: " + node.operation.type);
|
||||
default:
|
||||
return SelectDefault(gpu_info, op_def, inputs, outputs, node,
|
||||
|
Loading…
Reference in New Issue
Block a user