Added reader for Transpose operation.

Added selector (OpenCL) for Transpose kernel.

PiperOrigin-RevId: 274511551
This commit is contained in:
A. Unique TensorFlower 2019-10-13 23:02:59 -07:00 committed by TensorFlower Gardener
parent e2403531aa
commit acb32b90ef
5 changed files with 58 additions and 0 deletions

View File

@ -119,6 +119,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/cl/kernels:softmax",
"//tensorflow/lite/delegates/gpu/cl/kernels:softmax1x1",
"//tensorflow/lite/delegates/gpu/cl/kernels:strided_slice",
"//tensorflow/lite/delegates/gpu/cl/kernels:transpose",
"//tensorflow/lite/delegates/gpu/cl/kernels:upsample",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:shape",

View File

@ -156,6 +156,12 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
SelectSoftmax(inputs[0]->tensor.shape, op_def, gpu_op);
return OkStatus();
}
case OperationType::TRANSPOSE: {
auto attr =
absl::any_cast<TransposeAttributes>(node.operation.attributes);
SelectTranspose(attr, op_def, gpu_op);
return OkStatus();
}
case OperationType::UPSAMPLE_2D: {
auto attr =
absl::any_cast<Upsample2DAttributes>(node.operation.attributes);

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/kernels/softmax.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/transpose.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/upsample.h"
namespace tflite {
@ -191,6 +192,13 @@ void SelectSoftmax(const BHWC& shape, const OperationDef& op_def,
}
}
void SelectTranspose(const TransposeAttributes& attr,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr) {
Transpose operation = CreateTranspose(op_def, attr);
*ptr = absl::make_unique<Transpose>(std::move(operation));
}
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@ -89,6 +89,10 @@ Status SelectBroadcastAdd(const AddAttributes& attr,
void SelectSoftmax(const BHWC& shape, const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);
void SelectTranspose(const TransposeAttributes& attr,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@ -1983,6 +1983,43 @@ class TransposeConvOperationParser : public TFLiteOperationParser {
}
};
class TransposeOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
return OkStatus();
}
Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration, GraphFloat32* graph,
ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::TRANSPOSE);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
TransposeAttributes attr;
Tensor<Linear, DataType::INT32> perm;
RETURN_IF_ERROR(reader->ReadTensor(1, &perm));
if (perm.data.size() == 4) {
attr.perm = BHWC(perm.data[0], perm.data[1], perm.data[0], perm.data[3]);
} else if (perm.data.size() == 3) {
attr.perm = BHWC(0, perm.data[0] + 1, perm.data[1] + 1, perm.data[2] + 1);
} else if (perm.data.size() == 2) {
attr.perm = BHWC(0, 1, perm.data[0] + 2, perm.data[1] + 2);
} else {
return InvalidArgumentError("Permutation for transpose is invalid.");
}
node->operation.attributes = attr;
return OkStatus();
}
};
class Unpooling2DOperationParser : public TFLiteOperationParser {
public:
Status IsSupported(const TfLiteContext* context,
@ -2200,6 +2237,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
return absl::make_unique<ElementwiseOperationParser>(OperationType::SUB);
case kTfLiteBuiltinTanh:
return absl::make_unique<ElementwiseOperationParser>(OperationType::TANH);
case kTfLiteBuiltinTranspose:
return absl::make_unique<TransposeOperationParser>();
case kTfLiteBuiltinTransposeConv:
return absl::make_unique<TransposeConvOperationParser>();