Added Transpose support to Metal backend.

PiperOrigin-RevId: 352867205
Change-Id: I5161ac760ff887b10db01c1320237a109fe98926
This commit is contained in:
Raman Sarokin 2021-01-20 13:47:22 -08:00 committed by TensorFlower Gardener
parent 9a426abe81
commit 2263b49a6e
2 changed files with 15 additions and 1 deletions

View File

@ -36,6 +36,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common/tasks:prelu", "//tensorflow/lite/delegates/gpu/common/tasks:prelu",
"//tensorflow/lite/delegates/gpu/common/tasks:quantize_and_dequantize", "//tensorflow/lite/delegates/gpu/common/tasks:quantize_and_dequantize",
"//tensorflow/lite/delegates/gpu/common/tasks:relu", "//tensorflow/lite/delegates/gpu/common/tasks:relu",
"//tensorflow/lite/delegates/gpu/common/tasks:transpose",
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
"//tensorflow/lite/delegates/gpu/metal/kernels", "//tensorflow/lite/delegates/gpu/metal/kernels",
], ],

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/tasks/prelu.h" #include "tensorflow/lite/delegates/gpu/common/tasks/prelu.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/quantize_and_dequantize.h" #include "tensorflow/lite/delegates/gpu/common/tasks/quantize_and_dequantize.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/relu.h" #include "tensorflow/lite/delegates/gpu/common/tasks/relu.h"
#include "tensorflow/lite/delegates/gpu/common/tasks/transpose.h"
#include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/common/winograd_util.h" #include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
@ -135,6 +136,13 @@ std::unique_ptr<ComputeTaskDescriptor> SelectSpaceToDepth(
return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op)); return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
} }
void SelectTranspose(const TransposeAttributes& attr,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr) {
GPUOperation operation = CreateTranspose(op_def, attr);
*ptr = absl::make_unique<GPUOperation>(std::move(operation));
}
std::unique_ptr<ComputeTaskDescriptor> SelectWinograd4x4To36( std::unique_ptr<ComputeTaskDescriptor> SelectWinograd4x4To36(
const OperationDef& op_def, const Winograd4x4To36Attributes& attr, const OperationDef& op_def, const Winograd4x4To36Attributes& attr,
const GpuInfo& gpu_info) { const GpuInfo& gpu_info) {
@ -452,6 +460,12 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
op_def, op_def,
absl::any_cast<SpaceToDepthAttributes>(node.operation.attributes)); absl::any_cast<SpaceToDepthAttributes>(node.operation.attributes));
break; break;
case OperationType::TRANSPOSE: {
auto attr =
absl::any_cast<TransposeAttributes>(node.operation.attributes);
SelectTranspose(attr, op_def, &gpu_operation->operation);
return absl::OkStatus();
}
case OperationType::ABS: case OperationType::ABS:
case OperationType::COPY: case OperationType::COPY:
case OperationType::COS: case OperationType::COS:
@ -515,7 +529,6 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
case OperationType::REDUCE_PRODUCT: case OperationType::REDUCE_PRODUCT:
case OperationType::REDUCE_SUM: case OperationType::REDUCE_SUM:
case OperationType::SPACE_TO_BATCH: case OperationType::SPACE_TO_BATCH:
case OperationType::TRANSPOSE:
return absl::UnimplementedError("Unsupported op: " + node.operation.type); return absl::UnimplementedError("Unsupported op: " + node.operation.type);
default: default:
return SelectDefault(gpu_info, op_def, inputs, outputs, node, return SelectDefault(gpu_info, op_def, inputs, outputs, node,