Adds quantized input/output types in tfl.split op.

PiperOrigin-RevId: 261761611
This commit is contained in:
A. Unique TensorFlower 2019-08-05 13:52:23 -07:00 committed by TensorFlower Gardener
parent 038d0cdabc
commit 2b8b865613
2 changed files with 20 additions and 6 deletions

View File

@ -2178,7 +2178,7 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
let hasOptions = 1;
}
def TFL_SplitOp : TFL_Op<"split", [NoSideEffect]> {
def TFL_SplitOp : TFL_Op<"split", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
let summary = "Splits a tensor into `num_split` tensors along one dimension.";
let description = [{
@ -2189,18 +2189,18 @@ def TFL_SplitOp : TFL_Op<"split", [NoSideEffect]> {
let arguments = (ins
I32Tensor:$split_dim,
TensorOf<[F32, I16, I32, I64]>:$value,
TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>:$value,
I32Attr:$num_splits
);
let results = (outs
Variadic<TensorOf<[F32, I16, I32, I64]>>:$outputs
Variadic<TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>>:$outputs
);
let hasOptions = 1;
}
def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect]> {
def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
let summary = "Splits a tensor into `num_split` tensors along one dimension.";
let description = [{
@ -2210,14 +2210,14 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect]> {
}];
let arguments = (ins
TensorOf<[F32, I16, I32, I64]>:$value,
TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>:$value,
I32Tensor:$size_splits,
I32Tensor:$split_dim,
I32Attr:$num_splits
);
let results = (outs
Variadic<TensorOf<[F32, I16, I32, I64]>>:$outputs
Variadic<TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>>:$outputs
);
let hasOptions = 1;

View File

@ -1052,4 +1052,18 @@ func @testRoundInvalidInputType(%arg: tensor<?xi32>) -> tensor<?xi32> {
// expected-error @+1 {{'tfl.round' op operand #0 must be tensor of 32-bit float values}}
%0 = "tfl.round"(%arg) : (tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// -----
func @testSplitWithQuantizedTypes(%arg0 : tensor<i32>, %arg1 : tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
%0 = "tfl.split"(%arg0, %arg1) {num_splits = 1 : i32} : (tensor<i32>, tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
return %0 : tensor<10x!quant.uniform<u8:f32, 1.0>>
}
// -----
func @testSplitVWithQuantizedTypes(%arg0 : tensor<10x!quant.uniform<u8:f32, 1.0>>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
%0 = "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 1 : i32} : (tensor<10x!quant.uniform<u8:f32, 1.0>>, tensor<i32>, tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
return %0 : tensor<10x!quant.uniform<u8:f32, 1.0>>
}