Added reader for Transpose operation.
Added selector (OpenCL) for Transpose kernel. PiperOrigin-RevId: 274511551
This commit is contained in:
parent
e2403531aa
commit
acb32b90ef
@ -119,6 +119,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:softmax",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:softmax1x1",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:strided_slice",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:transpose",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:upsample",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||
|
@ -156,6 +156,12 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
|
||||
SelectSoftmax(inputs[0]->tensor.shape, op_def, gpu_op);
|
||||
return OkStatus();
|
||||
}
|
||||
case OperationType::TRANSPOSE: {
|
||||
auto attr =
|
||||
absl::any_cast<TransposeAttributes>(node.operation.attributes);
|
||||
SelectTranspose(attr, op_def, gpu_op);
|
||||
return OkStatus();
|
||||
}
|
||||
case OperationType::UPSAMPLE_2D: {
|
||||
auto attr =
|
||||
absl::any_cast<Upsample2DAttributes>(node.operation.attributes);
|
||||
|
@ -36,6 +36,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/softmax.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/transpose.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/upsample.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -191,6 +192,13 @@ void SelectSoftmax(const BHWC& shape, const OperationDef& op_def,
|
||||
}
|
||||
}
|
||||
|
||||
void SelectTranspose(const TransposeAttributes& attr,
|
||||
const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr) {
|
||||
Transpose operation = CreateTranspose(op_def, attr);
|
||||
*ptr = absl::make_unique<Transpose>(std::move(operation));
|
||||
}
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -89,6 +89,10 @@ Status SelectBroadcastAdd(const AddAttributes& attr,
|
||||
void SelectSoftmax(const BHWC& shape, const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr);
|
||||
|
||||
void SelectTranspose(const TransposeAttributes& attr,
|
||||
const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr);
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -1983,6 +1983,43 @@ class TransposeConvOperationParser : public TFLiteOperationParser {
|
||||
}
|
||||
};
|
||||
|
||||
class TransposeOperationParser : public TFLiteOperationParser {
|
||||
public:
|
||||
Status IsSupported(const TfLiteContext* context,
|
||||
const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration) final {
|
||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
||||
RETURN_IF_ERROR(
|
||||
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
Status Parse(const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration, GraphFloat32* graph,
|
||||
ObjectReader* reader) final {
|
||||
Node* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::TRANSPOSE);
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
|
||||
TransposeAttributes attr;
|
||||
Tensor<Linear, DataType::INT32> perm;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(1, &perm));
|
||||
if (perm.data.size() == 4) {
|
||||
attr.perm = BHWC(perm.data[0], perm.data[1], perm.data[0], perm.data[3]);
|
||||
} else if (perm.data.size() == 3) {
|
||||
attr.perm = BHWC(0, perm.data[0] + 1, perm.data[1] + 1, perm.data[2] + 1);
|
||||
} else if (perm.data.size() == 2) {
|
||||
attr.perm = BHWC(0, 1, perm.data[0] + 2, perm.data[1] + 2);
|
||||
} else {
|
||||
return InvalidArgumentError("Permutation for transpose is invalid.");
|
||||
}
|
||||
|
||||
node->operation.attributes = attr;
|
||||
return OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
class Unpooling2DOperationParser : public TFLiteOperationParser {
|
||||
public:
|
||||
Status IsSupported(const TfLiteContext* context,
|
||||
@ -2200,6 +2237,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
|
||||
return absl::make_unique<ElementwiseOperationParser>(OperationType::SUB);
|
||||
case kTfLiteBuiltinTanh:
|
||||
return absl::make_unique<ElementwiseOperationParser>(OperationType::TANH);
|
||||
case kTfLiteBuiltinTranspose:
|
||||
return absl::make_unique<TransposeOperationParser>();
|
||||
case kTfLiteBuiltinTransposeConv:
|
||||
return absl::make_unique<TransposeConvOperationParser>();
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user