Added reader for Transpose operation.
Added selector (OpenCL) for Transpose kernel. PiperOrigin-RevId: 274511551
This commit is contained in:
parent
e2403531aa
commit
acb32b90ef
@ -119,6 +119,7 @@ cc_library(
|
|||||||
"//tensorflow/lite/delegates/gpu/cl/kernels:softmax",
|
"//tensorflow/lite/delegates/gpu/cl/kernels:softmax",
|
||||||
"//tensorflow/lite/delegates/gpu/cl/kernels:softmax1x1",
|
"//tensorflow/lite/delegates/gpu/cl/kernels:softmax1x1",
|
||||||
"//tensorflow/lite/delegates/gpu/cl/kernels:strided_slice",
|
"//tensorflow/lite/delegates/gpu/cl/kernels:strided_slice",
|
||||||
|
"//tensorflow/lite/delegates/gpu/cl/kernels:transpose",
|
||||||
"//tensorflow/lite/delegates/gpu/cl/kernels:upsample",
|
"//tensorflow/lite/delegates/gpu/cl/kernels:upsample",
|
||||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||||
|
@ -156,6 +156,12 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
|
|||||||
SelectSoftmax(inputs[0]->tensor.shape, op_def, gpu_op);
|
SelectSoftmax(inputs[0]->tensor.shape, op_def, gpu_op);
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
case OperationType::TRANSPOSE: {
|
||||||
|
auto attr =
|
||||||
|
absl::any_cast<TransposeAttributes>(node.operation.attributes);
|
||||||
|
SelectTranspose(attr, op_def, gpu_op);
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
case OperationType::UPSAMPLE_2D: {
|
case OperationType::UPSAMPLE_2D: {
|
||||||
auto attr =
|
auto attr =
|
||||||
absl::any_cast<Upsample2DAttributes>(node.operation.attributes);
|
absl::any_cast<Upsample2DAttributes>(node.operation.attributes);
|
||||||
|
@ -36,6 +36,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/softmax.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/softmax.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/transpose.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/upsample.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/upsample.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -191,6 +192,13 @@ void SelectSoftmax(const BHWC& shape, const OperationDef& op_def,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SelectTranspose(const TransposeAttributes& attr,
|
||||||
|
const OperationDef& op_def,
|
||||||
|
std::unique_ptr<GPUOperation>* ptr) {
|
||||||
|
Transpose operation = CreateTranspose(op_def, attr);
|
||||||
|
*ptr = absl::make_unique<Transpose>(std::move(operation));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace cl
|
} // namespace cl
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -89,6 +89,10 @@ Status SelectBroadcastAdd(const AddAttributes& attr,
|
|||||||
void SelectSoftmax(const BHWC& shape, const OperationDef& op_def,
|
void SelectSoftmax(const BHWC& shape, const OperationDef& op_def,
|
||||||
std::unique_ptr<GPUOperation>* ptr);
|
std::unique_ptr<GPUOperation>* ptr);
|
||||||
|
|
||||||
|
void SelectTranspose(const TransposeAttributes& attr,
|
||||||
|
const OperationDef& op_def,
|
||||||
|
std::unique_ptr<GPUOperation>* ptr);
|
||||||
|
|
||||||
} // namespace cl
|
} // namespace cl
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -1983,6 +1983,43 @@ class TransposeConvOperationParser : public TFLiteOperationParser {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class TransposeOperationParser : public TFLiteOperationParser {
|
||||||
|
public:
|
||||||
|
Status IsSupported(const TfLiteContext* context,
|
||||||
|
const TfLiteNode* tflite_node,
|
||||||
|
const TfLiteRegistration* registration) final {
|
||||||
|
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
||||||
|
RETURN_IF_ERROR(
|
||||||
|
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Parse(const TfLiteNode* tflite_node,
|
||||||
|
const TfLiteRegistration* registration, GraphFloat32* graph,
|
||||||
|
ObjectReader* reader) final {
|
||||||
|
Node* node = graph->NewNode();
|
||||||
|
node->operation.type = ToString(OperationType::TRANSPOSE);
|
||||||
|
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||||
|
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||||
|
|
||||||
|
TransposeAttributes attr;
|
||||||
|
Tensor<Linear, DataType::INT32> perm;
|
||||||
|
RETURN_IF_ERROR(reader->ReadTensor(1, &perm));
|
||||||
|
if (perm.data.size() == 4) {
|
||||||
|
attr.perm = BHWC(perm.data[0], perm.data[1], perm.data[0], perm.data[3]);
|
||||||
|
} else if (perm.data.size() == 3) {
|
||||||
|
attr.perm = BHWC(0, perm.data[0] + 1, perm.data[1] + 1, perm.data[2] + 1);
|
||||||
|
} else if (perm.data.size() == 2) {
|
||||||
|
attr.perm = BHWC(0, 1, perm.data[0] + 2, perm.data[1] + 2);
|
||||||
|
} else {
|
||||||
|
return InvalidArgumentError("Permutation for transpose is invalid.");
|
||||||
|
}
|
||||||
|
|
||||||
|
node->operation.attributes = attr;
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class Unpooling2DOperationParser : public TFLiteOperationParser {
|
class Unpooling2DOperationParser : public TFLiteOperationParser {
|
||||||
public:
|
public:
|
||||||
Status IsSupported(const TfLiteContext* context,
|
Status IsSupported(const TfLiteContext* context,
|
||||||
@ -2200,6 +2237,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
|
|||||||
return absl::make_unique<ElementwiseOperationParser>(OperationType::SUB);
|
return absl::make_unique<ElementwiseOperationParser>(OperationType::SUB);
|
||||||
case kTfLiteBuiltinTanh:
|
case kTfLiteBuiltinTanh:
|
||||||
return absl::make_unique<ElementwiseOperationParser>(OperationType::TANH);
|
return absl::make_unique<ElementwiseOperationParser>(OperationType::TANH);
|
||||||
|
case kTfLiteBuiltinTranspose:
|
||||||
|
return absl::make_unique<TransposeOperationParser>();
|
||||||
case kTfLiteBuiltinTransposeConv:
|
case kTfLiteBuiltinTransposeConv:
|
||||||
return absl::make_unique<TransposeConvOperationParser>();
|
return absl::make_unique<TransposeConvOperationParser>();
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user