Adding tfl.prelu op, verifier, and tests

PiperOrigin-RevId: 270795743
This commit is contained in:
Abdurrahman Akkas 2019-09-23 17:17:57 -07:00 committed by TensorFlower Gardener
parent 953e80a3d1
commit 882b91d247
3 changed files with 122 additions and 0 deletions

View File

@ -625,6 +625,46 @@ static LogicalResult Verify(PackOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// PReluOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(PReluOp op) {
auto input_type = op.input()->getType().cast<ShapedType>();
auto alpha_type = op.alpha()->getType().cast<ShapedType>();
auto output_type = op.output()->getType().cast<ShapedType>();
if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
if (input_type.getRank() != alpha_type.getRank() + 1) {
return op.emitOpError("'alpha' should have one less rank than 'input'.");
}
// Check if alpha is broadcastable
for (int i = 0; i < alpha_type.getRank(); i++) {
if (alpha_type.getDimSize(i) != input_type.getDimSize(i + 1) &&
alpha_type.getDimSize(i) != 1) {
return op.emitOpError(
llvm::formatv("'alpha' is not broadcastable at dimension {0}.", i));
}
}
}
if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
if (input_type.getRank() != output_type.getRank()) {
return op.emitOpError("'input' and 'output' should have the same rank.");
}
// Check if input and output shapes are same
for (int i = 0; i < input_type.getRank(); i++) {
if (input_type.getDimSize(i) != output_type.getDimSize(i)) {
return op.emitOpError(
"'input' and 'output' should have the same shape.");
}
}
}
return success();
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

View File

@ -1749,6 +1749,28 @@ def TFL_PowOp : TFL_Op<"pow", [Broadcastable, NoSideEffect, NoQuantizableResult]
let builders = [TFL_BroadcastableBinaryBuilder];
}
def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect]> {
let summary = "Parameterized Relu operator";
let description = [{
Parameterized Relu operator
x -> x >= 0 ? x : (alpha * x)
where alpha is a trainable tensor.
alpha should have one less rank than the input as it doesn't have the batch
dimension, and the other dimensions either should be the same size as input
or size 1, where it is broadcasted in the second case.
}];
let arguments = (
ins TensorOf<[F32, QUI8]>:$input,
TensorOf<[F32, QUI8]>:$alpha
);
let results = (outs TensorOf<[F32, QUI8]>:$output);
let verifier = [{ return Verify(*this); }];
}
def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> {
let summary = "Rank operator.";
let description = [{

View File

@ -1418,6 +1418,66 @@ func @testDepthToSpaceInvalidOutputType(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x
// -----
func @testPReluWrongOutputRank(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<1x1x10xf32>) -> tensor<10x10x10xf32> {
// expected-error @+1 {{'input' and 'output' should have the same rank}}
%0 = "tfl.prelu"(%arg0, %arg1) : (tensor<10x10x10x10xf32>, tensor<1x1x10xf32>) -> tensor<10x10x10xf32>
return %0 : tensor<10x10x10xf32>
}
// -----
func @testPReluWrongOutputShape(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32> {
// expected-error @+1 {{'input' and 'output' should have the same shape}}
%0 = "tfl.prelu"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32>
return %0 : tensor<1x2x3x5xf32>
}
// -----
func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32> {
// expected-error @+1 {{'alpha' should have one less rank than 'input'.}}
%0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x3x2x14xf32>, tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32>
return %0 : tensor<7x3x2x14xf32>
}
// -----
func @testPReluInvalidBroadcast(%arg0: tensor<15x14x2x14xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32> {
// expected-error @+1 {{'alpha' is not broadcastable at dimension 2.}}
%0 = "tfl.prelu"(%arg0, %arg1) : (tensor<15x14x2x14xf32>, tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32>
return %0 : tensor<15x14x2x14xf32>
}
// -----
func @testPReluValidSameSize(%arg0: tensor<16x20x20x13xf32>, %arg1: tensor<20x20x13xf32>) -> tensor<16x20x20x13xf32> {
%0 = "tfl.prelu"(%arg0, %arg1) : (tensor<16x20x20x13xf32>, tensor<20x20x13xf32>) -> tensor<16x20x20x13xf32>
return %0 : tensor<16x20x20x13xf32>
}
// -----
func @testPReluValidBroadcast(%arg0: tensor<19x7x12x14xf32>, %arg1: tensor<1x1x14xf32>) -> tensor<19x7x12x14xf32> {
%0 = "tfl.prelu"(%arg0, %arg1) : (tensor<19x7x12x14xf32>, tensor<1x1x14xf32>) -> tensor<19x7x12x14xf32>
return %0 : tensor<19x7x12x14xf32>
}
// -----
func @testPReluValidFullBroadcast(%arg0: tensor<7x8x9x10xf32>, %arg1: tensor<1x1x1xf32>) -> tensor<7x8x9x10xf32> {
%0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x8x9x10xf32>, tensor<1x1x1xf32>) -> tensor<7x8x9x10xf32>
return %0 : tensor<7x8x9x10xf32>
}
// -----
func @testPReluValidQuantized(%arg0: tensor<1x96x96x16x!quant.uniform<u8:f32, 0.00784:128>>, %arg1: tensor<1x1x16x!quant.uniform<u8<1:255>:f32, 0.004846:14>>) -> tensor<1x96x96x16x!quant.uniform<u8:f32, 0.00784:128>> {
%0 = "tfl.prelu"(%arg0, %arg1) : (tensor<1x96x96x16x!quant.uniform<u8:f32, 0.00784:128>>, tensor<1x1x16x!quant.uniform<u8<1:255>:f32, 0.004846:14>>) -> tensor<1x96x96x16x!quant.uniform<u8:f32, 0.00784:128>>
return %0 : tensor<1x96x96x16x!quant.uniform<u8:f32, 0.00784:128>>
}
// -----
func @testSlice(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
%0 = "tfl.slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
return %0 : tensor<?x3x5xf32>