Adding tfl.prelu op, verifier, and tests
PiperOrigin-RevId: 270795743
This commit is contained in:
parent
953e80a3d1
commit
882b91d247
@ -625,6 +625,46 @@ static LogicalResult Verify(PackOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PReluOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Verifies the shape constraints of tfl.prelu:
//  * 'alpha' has exactly one less rank than 'input' (alpha carries no batch
//    dimension);
//  * each 'alpha' dimension is either 1 (broadcast) or equal to the
//    corresponding 'input' dimension (offset by one to skip input's leading
//    batch dimension);
//  * 'input' and 'output' have the same rank and the same shape.
// Each pairwise check runs only when both shapes involved are fully static;
// dynamic/unranked operands are accepted as-is.
static LogicalResult Verify(PReluOp op) {
  auto input_type = op.input()->getType().cast<ShapedType>();
  auto alpha_type = op.alpha()->getType().cast<ShapedType>();
  auto output_type = op.output()->getType().cast<ShapedType>();

  if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
    if (input_type.getRank() != alpha_type.getRank() + 1) {
      return op.emitOpError("'alpha' should have one less rank than 'input'.");
    }

    // Check if alpha is broadcastable: dimension i of alpha aligns with
    // dimension i+1 of input, and must match it or be 1.
    for (int i = 0; i < alpha_type.getRank(); i++) {
      if (alpha_type.getDimSize(i) != input_type.getDimSize(i + 1) &&
          alpha_type.getDimSize(i) != 1) {
        return op.emitOpError(
            llvm::formatv("'alpha' is not broadcastable at dimension {0}.", i));
      }
    }
  }

  if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
    if (input_type.getRank() != output_type.getRank()) {
      return op.emitOpError("'input' and 'output' should have the same rank.");
    }

    // With equal ranks established above, compare the shape arrays directly
    // instead of looping dimension-by-dimension: the original loop emitted
    // one shared diagnostic regardless of which dimension differed, so an
    // element-wise ArrayRef comparison is behaviorally identical.
    if (input_type.getShape() != output_type.getShape()) {
      return op.emitOpError(
          "'input' and 'output' should have the same shape.");
    }
  }
  return success();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReshapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1749,6 +1749,28 @@ def TFL_PowOp : TFL_Op<"pow", [Broadcastable, NoSideEffect, NoQuantizableResult]
|
||||
let builders = [TFL_BroadcastableBinaryBuilder];
|
||||
}
|
||||
|
||||
// PReLU: element-wise parameterized ReLU with a trainable slope tensor
// 'alpha' that is broadcast over the input's batch dimension. The rank and
// broadcast-compatibility constraints stated in the description are enforced
// by the C++ Verify(PReluOp) hook referenced below.
def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect]> {
  let summary = "Parameterized Relu operator";

  let description = [{
    Parameterized Relu operator
    x -> x >= 0 ? x : (alpha * x)
    where alpha is a trainable tensor.
    alpha should have one less rank than the input as it doesn't have the batch
    dimension, and the other dimensions either should be the same size as input
    or size 1, where it is broadcasted in the second case.
  }];

  // Both operands accept float32 or quantized-uint8 element types; the
  // result is constrained to the same set.
  let arguments = (
    ins TensorOf<[F32, QUI8]>:$input,
    TensorOf<[F32, QUI8]>:$alpha
  );

  let results = (outs TensorOf<[F32, QUI8]>:$output);

  // Shape/rank verification is implemented in C++ (static Verify(PReluOp)).
  let verifier = [{ return Verify(*this); }];
}
|
||||
|
||||
def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> {
|
||||
let summary = "Rank operator.";
|
||||
let description = [{
|
||||
|
@ -1418,6 +1418,66 @@ func @testDepthToSpaceInvalidOutputType(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: output rank (3) differs from input rank (4); the verifier
// must reject the op with the rank-mismatch diagnostic.
func @testPReluWrongOutputRank(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<1x1x10xf32>) -> tensor<10x10x10xf32> {
  // expected-error @+1 {{'input' and 'output' should have the same rank}}
  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<10x10x10x10xf32>, tensor<1x1x10xf32>) -> tensor<10x10x10xf32>
  return %0 : tensor<10x10x10xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: ranks match but the last output dimension (5) differs from
// the input's (4); the verifier must emit the shape-mismatch diagnostic.
func @testPReluWrongOutputShape(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32> {
  // expected-error @+1 {{'input' and 'output' should have the same shape}}
  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32>
  return %0 : tensor<1x2x3x5xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: alpha rank (5) is not input rank (4) minus one; the verifier
// must emit the alpha-rank diagnostic.
func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32> {
  // expected-error @+1 {{'alpha' should have one less rank than 'input'.}}
  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x3x2x14xf32>, tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32>
  return %0 : tensor<7x3x2x14xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: alpha dim 2 is 3, which is neither 1 nor the matching input
// dim (14); the verifier must report the failing dimension index.
func @testPReluInvalidBroadcast(%arg0: tensor<15x14x2x14xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32> {
  // expected-error @+1 {{'alpha' is not broadcastable at dimension 2.}}
  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<15x14x2x14xf32>, tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32>
  return %0 : tensor<15x14x2x14xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Positive test: alpha shape exactly matches the input's non-batch
// dimensions (no broadcasting); the op must verify cleanly.
func @testPReluValidSameSize(%arg0: tensor<16x20x20x13xf32>, %arg1: tensor<20x20x13xf32>) -> tensor<16x20x20x13xf32> {
  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<16x20x20x13xf32>, tensor<20x20x13xf32>) -> tensor<16x20x20x13xf32>
  return %0 : tensor<16x20x20x13xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Positive test: alpha's leading dims are 1 and broadcast against input dims
// 1 and 2; the op must verify cleanly.
func @testPReluValidBroadcast(%arg0: tensor<19x7x12x14xf32>, %arg1: tensor<1x1x14xf32>) -> tensor<19x7x12x14xf32> {
  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<19x7x12x14xf32>, tensor<1x1x14xf32>) -> tensor<19x7x12x14xf32>
  return %0 : tensor<19x7x12x14xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Positive test: every alpha dimension is 1, so alpha broadcasts over all
// non-batch input dimensions; the op must verify cleanly.
func @testPReluValidFullBroadcast(%arg0: tensor<7x8x9x10xf32>, %arg1: tensor<1x1x1xf32>) -> tensor<7x8x9x10xf32> {
  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x8x9x10xf32>, tensor<1x1x1xf32>) -> tensor<7x8x9x10xf32>
  return %0 : tensor<7x8x9x10xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Positive test: quantized-uint8 input and alpha (allowed by the op's QUI8
// element-type constraint) with broadcastable alpha; must verify cleanly.
func @testPReluValidQuantized(%arg0: tensor<1x96x96x16x!quant.uniform<u8:f32, 0.00784:128>>, %arg1: tensor<1x1x16x!quant.uniform<u8<1:255>:f32, 0.004846:14>>) -> tensor<1x96x96x16x!quant.uniform<u8:f32, 0.00784:128>> {
  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<1x96x96x16x!quant.uniform<u8:f32, 0.00784:128>>, tensor<1x1x16x!quant.uniform<u8<1:255>:f32, 0.004846:14>>) -> tensor<1x96x96x16x!quant.uniform<u8:f32, 0.00784:128>>
  return %0 : tensor<1x96x96x16x!quant.uniform<u8:f32, 0.00784:128>>
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testSlice(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
|
||||
%0 = "tfl.slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
|
||||
return %0 : tensor<?x3x5xf32>
|
||||
|
Loading…
Reference in New Issue
Block a user