Adds quantized input/output types in tfl.split op.
PiperOrigin-RevId: 261761611
This commit is contained in:
parent
038d0cdabc
commit
2b8b865613
@ -2178,7 +2178,7 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def TFL_SplitOp : TFL_Op<"split", [NoSideEffect]> {
|
||||
def TFL_SplitOp : TFL_Op<"split", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
|
||||
let summary = "Splits a tensor into `num_split` tensors along one dimension.";
|
||||
|
||||
let description = [{
|
||||
@ -2189,18 +2189,18 @@ def TFL_SplitOp : TFL_Op<"split", [NoSideEffect]> {
|
||||
|
||||
let arguments = (ins
|
||||
I32Tensor:$split_dim,
|
||||
TensorOf<[F32, I16, I32, I64]>:$value,
|
||||
TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>:$value,
|
||||
I32Attr:$num_splits
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TensorOf<[F32, I16, I32, I64]>>:$outputs
|
||||
Variadic<TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>>:$outputs
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect]> {
|
||||
def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
|
||||
let summary = "Splits a tensor into `num_split` tensors along one dimension.";
|
||||
|
||||
let description = [{
|
||||
@ -2210,14 +2210,14 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect]> {
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I16, I32, I64]>:$value,
|
||||
TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>:$value,
|
||||
I32Tensor:$size_splits,
|
||||
I32Tensor:$split_dim,
|
||||
I32Attr:$num_splits
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TensorOf<[F32, I16, I32, I64]>>:$outputs
|
||||
Variadic<TensorOf<[F32, I16, I32, I64, TFL_QI8, TFL_QUI8]>>:$outputs
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
|
@ -1052,4 +1052,18 @@ func @testRoundInvalidInputType(%arg: tensor<?xi32>) -> tensor<?xi32> {
|
||||
// expected-error @+1 {{'tfl.round' op operand #0 must be tensor of 32-bit float values}}
|
||||
%0 = "tfl.round"(%arg) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testSplitWithQuantizedTypes(%arg0 : tensor<i32>, %arg1 : tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
|
||||
%0 = "tfl.split"(%arg0, %arg1) {num_splits = 1 : i32} : (tensor<i32>, tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
|
||||
return %0 : tensor<10x!quant.uniform<u8:f32, 1.0>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testSplitVWithQuantizedTypes(%arg0 : tensor<10x!quant.uniform<u8:f32, 1.0>>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
|
||||
%0 = "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 1 : i32} : (tensor<10x!quant.uniform<u8:f32, 1.0>>, tensor<i32>, tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
|
||||
return %0 : tensor<10x!quant.uniform<u8:f32, 1.0>>
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user