Add op sanity checks to the following TFLite ops:

SpaceToDepthOp
SparseConstOp
SparseQConstOp
SparseToDenseOp
SplitOp
SplitVOp
SquaredDifferenceOp
SqueezeOp
StridedSliceOp
SubOp
SumOp
TanhOp
TileOp
TransposeConvOp
TransposeOp
UniqueOp
UnpackOp
WhereOp
WhileOp
YieldOp
ZerosLikeOp

PiperOrigin-RevId: 314280789
Change-Id: I030c487d9a6363b8f06d103786c9d42bfc7381b3
Author: Jaesung Chung (committed by TensorFlower Gardener)
Date: 2020-06-02 00:34:47 -07:00
Commit: e2aa757a55 (parent: bc49458b14)
4 changed files with 190 additions and 131 deletions
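
The pattern applied across these ops is the same throughout the diff below: overly loose operand and result types (such as AnyTensor) are replaced with explicit element-type lists, and declarative ODS traits and attribute constraints are added so that malformed ops are rejected by the MLIR verifier instead of reaching the TFLite runtime. As a minimal sketch of that style (the op below is hypothetical and only illustrates the trait and constraint choices used in this change):

// Illustrative only: not an op added by this change.
def TFL_ExampleOp : TFL_Op<"example", [
    NoSideEffect,
    // Operand 0 may have rank at most 4.
    TFL_OperandHasRankAtMost<0, 4>,
    // Result 0 must have the same element type as operand 0
    // (quantization-aware variant of TCresVTEtIsSameAsOp).
    PredOpTrait<"input and output must have same element type",
        TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
  let summary = "Example operator";

  let arguments = (ins
    // Explicit element-type list instead of AnyTensor.
    TFL_TensorOf<[F32, I32, QI8, QUI8]>:$input,
    // Attribute value must be positive.
    Confined<I32Attr, [IntPositive]>:$block_size
  );

  let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8]>:$output);
}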


@@ -1966,9 +1966,9 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
 }
 
 static LogicalResult Verify(TransposeOp op) {
-  auto input_type = op.x().getType().cast<ShapedType>();
+  auto input_type = op.input().getType().cast<ShapedType>();
   auto perm_type = op.perm().getType().cast<ShapedType>();
-  auto output_type = op.y().getType().cast<ShapedType>();
+  auto output_type = op.output().getType().cast<ShapedType>();
   if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
     if (perm_type.getNumElements() != input_type.getRank()) {
       return op.emitOpError(


@@ -161,7 +161,6 @@ class TFL_VariadicTensorOf<list<Type> allowedRuntimeTypes,
     Variadic<TensorOf<allowedOpTypes>>,
     TFL_RuntimeType<Variadic<TensorOf<allowedRuntimeTypes>>>;
 
-def TFL_Uint8 : UI<8>;
 def TFL_Int32Or64 : SignlessIntOfWidths<[32, 64]>;
 
 def TFL_BoolTensor : TFL_TensorOf<[I1]>;
@@ -294,21 +293,33 @@ class TFL_OperandHasRankRange<int n, int x, int y> :
         "getRank() <= " # y>]>>;
 
 def TFL_FloatNonNegative : AttrConstraint<
-    CPred<"!$_self.cast<FloatAttr>().getValue().isNegative()">,
+    CPred<"$_self.isa<FloatAttr>() && "
+          "!$_self.cast<FloatAttr>().getValue().isNegative()">,
     "whose value is non-negative">;
 
 def TFL_BoolTrue : AttrConstraint<
-    CPred<"$_self.cast<BoolAttr>().getValue()">,
+    CPred<"$_self.isa<BoolAttr>() && $_self.cast<BoolAttr>().getValue()">,
     "whose value is true">;
 
 def TFL_BoolFalse : AttrConstraint<
-    CPred<"!$_self.cast<BoolAttr>().getValue()">,
+    CPred<"$_self.isa<BoolAttr>() && !$_self.cast<BoolAttr>().getValue()">,
    "whose value is false">;
 
 class TFL_StringEqualsTo<string value> : AttrConstraint<
     CPred<"$_self.cast<StringAttr>().getValue() == \"" # value # "\"">,
     "whose value equals to '" # value # "'">;
 
+// Ensures the array attribute's size is within the given maximum size.
+class TFL_ArrayMaxCount<int n> : AttrConstraint<
+    CPred<"$_self.isa<ArrayAttr>() && $_self.cast<ArrayAttr>().size() <= " # n>,
+    "whose size is at most " # n>;
+
+// Ensures the given integer attribute has the given value.
+class TFL_IntEqualsTo<int n> : AttrConstraint<
+    CPred<"$_self.isa<IntegerAttr>() && "
+          "$_self.cast<IntegerAttr>().getInt() == " # n>,
+    "whose value is " # n>;
+
 // This is a quantization-aware version of TCresVTEtIsSameAsOp
 class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
   TCOpResIsShapedTypePred<i, j>,
@@ -472,7 +483,10 @@ an output element, this operation computes \\(y = |x|\\).
 
 def TFL_AddOp : TFL_Op<"add", [
     TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
-    ResultsBroadcastableShape, NoSideEffect, Commutative, TFL_GpuTargetOp]> {
+    ResultsBroadcastableShape,
+    NoSideEffect,
+    Commutative,
+    TFL_GpuTargetOp]> {
   let summary = "Addition operator";
 
   let description = [{
@@ -540,8 +554,14 @@ retained with length 1.
   let customOption = "ReducerOptions";
 }
 
-def TFL_TransposeConvOp:
-  TFL_Op<"transpose_conv", [NoSideEffect, TFL_GpuTargetOp]> {
+def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
+    NoSideEffect,
+    TFL_OperandHasRank<0, 1>,
+    TFL_OperandHasRank<1, 4>,
+    TFL_OperandHasRank<2, 4>,
+    PredOpTrait<"input and output must have same element type",
+        TFL_TCresVTEtIsSameAsOp<0, 2>>,
+    TFL_GpuTargetOp]> {
   let summary = "Transpose convolution operator";
 
   let description = [{
@@ -549,16 +569,16 @@ def TFL_TransposeConvOp:
   }];
 
   let arguments = (ins
-    TFL_1DTensorOf<[I32]>:$output_shape,
-    TFL_TensorOf<[F32, TFL_Uint8, QI8, QUI8]>:$weights,
-    TFL_TensorOf<[F32, TFL_Uint8, QI8, QUI8]>:$input,
-    TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias,
+    TFL_I32Tensor:$output_shape,
+    TFL_TensorOf<[F32, QI8, QUI8]>:$weights,
+    TFL_TensorOf<[F32, QI8, QUI8]>:$input,
+    TFL_TensorOfOrNone<[F32, QI32]>:$bias,
     TFL_PaddingAttr:$padding,
-    I32Attr:$stride_h,
-    I32Attr:$stride_w
+    Confined<I32Attr, [IntPositive]>:$stride_h,
+    Confined<I32Attr, [IntPositive]>:$stride_w
   );
 
-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
 
   let hasOptions = 1;
@@ -600,7 +620,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
   }];
 
   let arguments = (
-    ins TFL_TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
+    ins TFL_TensorOf<[F32, I32, I8, UI8, QI8, QUI8]>:$input,
     TFL_I32OrI64Tensor:$dim
   );
@@ -630,7 +650,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
   }];
 
   let arguments = (
-    ins TFL_TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input,
+    ins TFL_TensorOf<[F32, I32, I8, UI8, QI8, QUI8]>:$input,
     TFL_I32OrI64Tensor:$dim
   );
@@ -677,14 +697,14 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
   let arguments = (
     ins TFL_VariadicTensorOf<
-      [F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>:$values,
+      [F32, I64, I32, I16, I8, QI8, QUI8, UI8]>:$values,
     I32Attr:$axis,
     TFL_AFAttr:$fused_activation_function
   );
 
   let results = (outs
     TFL_TensorOf<
-      [F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>:$output
+      [F32, I64, I32, I16, I8, QI8, QUI8, UI8]>:$output
   );
 
   let hasOptions = 1;
@@ -748,7 +768,8 @@ def SparsityParameterAttr : StructAttr<"SparsityParameterAttr", TFL_Dialect, [
   let storageType = [{ TFL::SparsityParameterAttr }];
 }
 
-def TFL_SparseConstOp : Op<TFL_Dialect, "pseudo_sparse_const", [NoSideEffect,
+def TFL_SparseConstOp : Op<TFL_Dialect, "pseudo_sparse_const", [
+    NoSideEffect,
     FirstAttrDerivedResultType]> {
   let summary = "Sparse constant pseudo op.";
@@ -959,12 +980,12 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8, TFL_Str]>:$params,
+    TFL_TensorOf<[F32, I8, I64, I32, UI8, TFL_Str]>:$params,
     TFL_I32OrI64Tensor:$indices
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8, TFL_Str]>:$output
+    TFL_TensorOf<[F32, I8, I64, I32, UI8, TFL_Str]>:$output
   );
 }
@@ -983,12 +1004,12 @@ def TFL_ScatterNdOp : TFL_Op<"scatter_nd", [
   let arguments = (ins
     TFL_TensorOf<[I32]>:$indices,
-    TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$updates,
+    TFL_TensorOf<[F32, I8, I64, I32, UI8]>:$updates,
     TFL_1DTensorOf<[I32]>:$shape
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output
+    TFL_TensorOf<[F32, I8, I64, I32, UI8]>:$output
   );
 
   let verifier = [{ return Verify(*this); }];
@@ -1103,11 +1124,11 @@ def TFL_MatrixDiagOp : TFL_Op<"matrix_diag", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I8, I16, I32, I64, TFL_Uint8, QUI8, QI8, TFL_Quint8]>:$diagonal
+    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QUI8, QI8, TFL_Quint8]>:$diagonal
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I8, I16, I32, I64, TFL_Uint8, QUI8, QI8, TFL_Quint8]>:$output
+    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QUI8, QI8, TFL_Quint8]>:$output
   );
 
   let hasOptions = 0;
@@ -1285,8 +1306,10 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
 }
 
 def TFL_DivOp : TFL_Op<"div", [
     // TODO(fengliuai): NoQuantizableResult is only correct for int8
     // quantization. update to handle Uint8 quantization.
+    BinaryOpSameElementTypeConstraint,
+    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
     ResultsBroadcastableShape,
     NoSideEffect,
     NoQuantizableResult,
@@ -1299,10 +1322,10 @@ def TFL_DivOp : TFL_Op<"div", [
   let arguments = (
     ins TFL_TensorOf<[F32, I32, QUI8]>:$lhs,
-        TFL_TensorOf<[F32, I32, TFL_Uint8]>:$rhs,
+        TFL_TensorOf<[F32, I32, QUI8]>:$rhs,
         TFL_AFAttr:$fused_activation_function);
 
-  let results = (outs TFL_TensorOf<[F32, I32, TFL_Uint8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, QUI8]>:$output);
 
   let builders = [TFL_FusedBroadcastableBinaryBuilder];
@@ -1345,10 +1368,10 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
   let arguments = (ins
     TFL_TensorOf<[I32]>:$lookup,
-    TFL_TensorOf<[F32, I8, TFL_Uint8]>:$value
+    TFL_TensorOf<[F32, I8, UI8]>:$value
   );
 
-  let results = (outs TFL_TensorOf<[F32, I8, TFL_Uint8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I8, UI8]>:$output);
 }
 
 def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape,
@@ -1364,8 +1387,8 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape,
   let arguments = (
     ins
-      TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8, TFL_Str]>:$x,
-      TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8, TFL_Str]>:$y
+      TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, UI8, TFL_Str]>:$x,
+      TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, UI8, TFL_Str]>:$y
   );
 
   let results = (outs TFL_BoolTensor:$output);
@@ -1445,7 +1468,7 @@ def TFL_SqueezeOp: TFL_Op<"squeeze", [NoSideEffect,
 Given a tensor `input`, this operation returns a tensor of the same type with
 all dimensions of size 1 removed. If you don't want to remove all size 1
 dimensions, you can remove specific size 1 dimensions by specifying
-`axis`.
+`squeeze_dims`.
 
 For example:
@@ -1464,7 +1487,7 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
   let arguments = (ins
     AnyTensor:$input,
-    DefaultValuedAttr<I64ArrayAttr, "{}">:$squeeze_dims
+    Confined<DefaultValuedAttr<I64ArrayAttr, "{}">, [TFL_ArrayMaxCount<8>]>:$squeeze_dims
   );
 
   let results = (outs
@@ -1899,13 +1922,13 @@ def TFL_MeanOp : TFL_Op<"mean", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$input,
     TFL_TensorOf<[I32, I64]>:$axis,
     BoolAttr:$keep_dims
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$output);
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$output);
 
   let hasOptions = 1;
   let customOption = "ReducerOptions";
@@ -1998,7 +2021,11 @@ equivalent to setting:
   let hasCanonicalizer = 1;
 }
 
-def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> {
+def TFL_SumOp: TFL_Op<"sum", [
+    PredOpTrait<"input and output must have same element type",
+        TFL_TCresVTEtIsSameAsOp<0, 0>>,
+    NoSideEffect]> {
   let summary = "Sum operator";
 
   let description = [{
@@ -2006,12 +2033,12 @@ def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> {
   }];
 
   let arguments = (ins
-    AnyTensor:$input,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
     TFL_I32Tensor:$axes,
     BoolAttr:$keep_dims
   );
 
-  let results = (outs AnyTensor);
+  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
 
   let hasOptions = 1;
   let customOption = "ReducerOptions";
@@ -2113,12 +2140,13 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
   let hasOptions = 0;
 }
 
-def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape,
-    NoSideEffect,
-    Commutative,
-    BinaryOpSameElementTypeConstraint,
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
-    TFL_GpuTargetOp]> {
+def TFL_MulOp : TFL_Op<"mul", [
+    ResultsBroadcastableShape,
+    NoSideEffect,
+    Commutative,
+    BinaryOpSameElementTypeConstraint,
+    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
+    TFL_GpuTargetOp]> {
   let summary = "Multiplication operator";
 
   let description = [{
@@ -2611,7 +2639,7 @@ def TFL_SelectOp : TFL_Op<"select", [
     NoSideEffect,
     PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
     PredOpTrait<"operands and result have same element type",
-                TCresVTEtIsSameAsOp<0, 1>>]> {
+                TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
   let summary = "Select operator";
 
   let description = [{
@@ -2647,7 +2675,7 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [
     TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<1, 2, 4>,
     PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
     PredOpTrait<"operands and result have same element type",
-                TCresVTEtIsSameAsOp<0, 1>>]> {
+                TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
   let summary = "SelectV2 operator";
 
   let description = [{
@@ -2760,7 +2788,11 @@ def TFL_SquareOp: TFL_Op<"square", [
   let hasFolder = 1;
 }
 
-def TFL_SubOp : TFL_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> {
+def TFL_SubOp : TFL_Op<"sub", [
+    ResultsBroadcastableShape,
+    BinaryOpSameElementTypeConstraint,
+    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
+    NoSideEffect]> {
   let summary = "Subtraction operator";
 
   let description = [{
@@ -2768,11 +2800,11 @@ def TFL_SubOp : TFL_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> {
   }];
 
   let arguments = (
-    ins AnyTensor:$lhs,
-        AnyTensor:$rhs,
+    ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
+        TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
         TFL_AFAttr:$fused_activation_function);
 
-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
 
   let hasFolder = 1;
@@ -2788,6 +2820,8 @@ def TFL_SubOp : TFL_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> {
 // TODO(jpienaar): Expand the kernel implementation to support all types besides
 // I32 and F32.
 def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
+    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    SameOperandsAndResultElementType,
     ResultsBroadcastableShape,
     NoSideEffect,
     NoQuantizableResult,
@@ -2799,10 +2833,10 @@ def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
   }];
 
   let arguments = (
-    ins AnyTensor:$lhs,
-        AnyTensor:$rhs);
+    ins TFL_TensorOf<[F32, I32]>:$lhs,
+        TFL_TensorOf<[F32, I32]>:$rhs);
 
-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[F32, I32]>:$output);
 
   let builders = [TFL_BroadcastableBinaryBuilder];
@@ -2814,6 +2848,8 @@ def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
 def TFL_TanhOp: TFL_Op<"tanh", [
     NoSideEffect,
     SameOperandsAndResultShape,
+    PredOpTrait<"input and output must have same element type",
+        TFL_TCresVTEtIsSameAsOp<0, 0>>,
     // central_value = min_value / 2 + (max_value - 1) / 2 + 1
     // zero_point = central_value
     // scale = 1. / (central_value - min_value)
@@ -2826,9 +2862,9 @@ def TFL_TanhOp: TFL_Op<"tanh", [
     Computes element-wise Hyperbolic tangent of input
   }];
 
-  let arguments = (ins TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$x);
+  let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$input);
 
-  let results = (outs TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$output);
 
   // This builder doesn't work with quantized type, so it can only be used by
   // non-quantization tablegen patterns. Currently, it is used by the
@@ -2842,9 +2878,11 @@ def TFL_TanhOp: TFL_Op<"tanh", [
   ];
 }
 
-def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
-    PredOpTrait<"resultant element type needs to match first operand type",
-                TFL_TCresVTEtIsSameAsOp<0,0>>]> {
+def TFL_TileOp: TFL_Op<"tile", [
+    NoSideEffect,
+    SameOperandsAndResultsScale,
+    PredOpTrait<"input and output must have same element type",
+        TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
   let summary = "Tile operator.";
   let description = [{
     Constructs a tensor by tiling a given tensor.
@@ -2857,11 +2895,11 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8, TFL_Str]>:$input,
+    TFL_TensorOf<[F32, I1, I32, I64, UI8, QUI8, TFL_Str]>:$input,
     TFL_I32OrI64Tensor:$multiples);
 
   let results = (outs
-    TFL_TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8, TFL_Str]>:$output);
+    TFL_TensorOf<[F32, I1, I32, I64, UI8, QUI8, TFL_Str]>:$output);
 
   let hasOptions = 0;
 }
@@ -2869,9 +2907,13 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
 // TODO(jpienaar): Maybe make it accept any single element tensor as `k`.
 // TODO(jpienaar): Check that input has one or more dimensions.
 // TODO(jpienaar): Check that k is less or equal the internal dimension
-def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
+def TFL_TopKV2Op: TFL_Op<"topk_v2", [
+    NoSideEffect,
+    TFL_OperandHasRankAtLeast<0, 1>,
+    TFL_OperandHasRank<1, 0>,
     PredOpTrait<"result and input element type match",
-                TCresVTEtIsSameAsOp<0,0>>, SameOperandsAndResultsScale]> {
+                TFL_TCresVTEtIsSameAsOp<0,0>>,
+    SameOperandsAndResultsScale]> {
   let summary = "TopK operator";
 
   let description = [{
@@ -2881,11 +2923,11 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$input,
+    TFL_TensorOf<[F32, I8, I32, I64, UI8, QI8, QUI8]>:$input,
     TFL_I32Tensor:$k);
 
   let results = (outs
-    TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$values,
+    TFL_TensorOf<[F32, I8, I32, I64, UI8, QI8, QUI8]>:$values,
     TFL_I32Tensor:$indices);
 
   let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, "
@@ -2895,29 +2937,27 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
   let hasOptions = 1;
 }
 
-def TFL_TransposeOp : TFL_Op<"transpose",
-    [NoSideEffect,
-     TFL_OperandHasRank<1,1>,
-     // TODO(jpienaar): these are only true dynamically, change so that it works
-     // with unknowns.
-     // TFL_OperandRankEquals1DimOfOperand<0, 1>,
-     PredOpTrait<"input and output must have same element type",
-                 TCresVTEtIsSameAsOp<0, 0>>,
-     SameOperandsAndResultsScale,
-     TFL_GpuTargetOp]> {
+def TFL_TransposeOp : TFL_Op<"transpose", [
+    NoSideEffect,
+    TFL_OperandHasRankAtMost<0, 5>,
+    TFL_OperandHasRank<1, 1>,
+    PredOpTrait<"input and output must have same element type",
+        TFL_TCresVTEtIsSameAsOp<0, 0>>,
+    SameOperandsAndResultsScale,
+    TFL_GpuTargetOp]> {
   let summary = "Transpose operator";
 
   let description = [{
     Returns the Transpose of x
   }];
 
-  let arguments = (
-    ins AnyTensor:$x,
+  let arguments = (ins
+    TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$input,
     TFL_TensorOf<[I32]>:$perm
   );
 
   let results = (outs
-    AnyTensor:$y
+    TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$output
   );
 
   let verifier = [{ return Verify(*this); }];
@@ -2925,7 +2965,10 @@ def TFL_TransposeOp : TFL_Op<"transpose",
   let hasFolder = 1;
 }
 
-def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect, SameOperandsAndResultsScale]> {
+def TFL_UnpackOp : TFL_Op<"unpack", [
+    NoSideEffect,
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultsScale]> {
   let summary = "Unpacks a tensor along a dimension into multiple tensors";
 
   let description = [{
@@ -2946,14 +2989,14 @@ def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect, SameOperandsAndResultsScale]>
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I1, I8, I32, QI8, QUI8]>:$input,
+    TFL_TensorOf<[F32, I1, I8, UI8, I32, QI8, QUI8, I16, QI16]>:$input,
     I32Attr:$num,
     I32Attr:$axis
   );
 
   let results = (outs
-    TFL_VariadicTensorOf<[F32, I1, I8, I32, QI8, QUI8]>:$outputs
+    TFL_VariadicTensorOf<[F32, I1, I8, UI8, I32, QI8, QUI8, I16, QI16]>:$outputs
   );
 
   let verifier = [{ return Verify(*this); }];
@@ -2961,16 +3004,19 @@ def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect, SameOperandsAndResultsScale]>
   let hasOptions = 1;
 }
 
-def TFL_ZerosLikeOp: TFL_Op<"zeros_like", [NoSideEffect]> {
+def TFL_ZerosLikeOp: TFL_Op<"zeros_like", [
+    PredOpTrait<"input and output must have same element type",
+        TFL_TCresVTEtIsSameAsOp<0, 0>>,
+    NoSideEffect]> {
   let summary = "ZerosLike operator";
 
   let description = [{
     Returns a tensor of zeros with the same shape and type as the input tensor.
   }];
 
-  let arguments = (ins AnyTensor:$input);
+  let arguments = (ins TFL_TensorOf<[I64, I32, F32]>:$input);
 
-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[I64, I32, F32]>:$output);
 
   let hasOptions = 1;
 }
@@ -3006,7 +3052,7 @@ def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [
     SameOperandsAndResultsScale,
     TFL_OperandHasRankRange<0, 3, 4>,
     PredOpTrait<"input and output must have same element type",
-                TCresVTEtIsSameAsOp<0, 0>>
+                TFL_TCresVTEtIsSameAsOp<0, 0>>
   ]> {
 
   let summary = "SpaceToBatchNd operator";
@@ -3029,7 +3075,8 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
     NoSideEffect,
     SameOperandsAndResultsScale,
     PredOpTrait<"input and output must have same element type",
-                TCresVTEtIsSameAsOp<0, 0>>,
+                TFL_TCresVTEtIsSameAsOp<0, 0>>,
+    TFL_OperandHasRankAtMost<0, 4>,
     TFL_GpuTargetOp
   ]> {
   let summary = "SpaceToDepth operator";
@@ -3042,12 +3089,12 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$input,
-    I32Attr:$block_size
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
+    Confined<I32Attr, [IntPositive]>:$block_size
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$output
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output
   );
 
   let hasOptions = 1;
@@ -3072,12 +3119,12 @@ def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I8, I32, I64, TFL_Quint8, TFL_Uint8, UI8, QI8, QUI8]>:$input,
+    TFL_TensorOf<[F32, I8, I32, I64, TFL_Quint8, UI8, QI8, QUI8]>:$input,
     Confined<I32Attr, [IntPositive]>:$block_size
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I8, I32, I64, TFL_Quint8, TFL_Uint8, UI8, QI8, QUI8]>:$output
+    TFL_TensorOf<[F32, I8, I32, I64, TFL_Quint8, UI8, QI8, QUI8]>:$output
   );
 
   let hasOptions = 1;
@@ -3097,12 +3144,12 @@ def TFL_SplitOp : TFL_Op<"split", [
   let arguments = (ins
     TFL_TensorOf<[I32]>:$split_dim,
-    TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value,
+    TFL_TensorOf<[F32, I16, I32, I64, I8, UI8, QI8, QUI8, QI16]>:$value,
     Confined<I32Attr, [IntPositive]>:$num_splits
   );
 
   let results = (outs
-    TFL_VariadicTensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$outputs
+    TFL_VariadicTensorOf<[F32, I16, I32, I64, I8, UI8, QI8, QUI8, QI16]>:$outputs
   );
 
   let verifier = [{ return Verify(*this); }];
@@ -3120,14 +3167,14 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale]
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value,
+    TFL_TensorOf<[F32, I16, I32, I64, I8, UI8, QI8, QUI8, QI16]>:$value,
     TFL_1DTensorOf<[I32], [I32]>:$size_splits,
     TFL_0DTensorOf<[I32], [I32]>:$split_dim,
     Confined<I32Attr, [IntPositive]>:$num_splits
   );
 
   let results = (outs
-    TFL_VariadicTensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$outputs
+    TFL_VariadicTensorOf<[F32, I16, I32, I64, I8, UI8, QI8, QUI8, QI16]>:$outputs
   );
 
   let verifier = [{ return Verify(*this); }];
@@ -3189,7 +3236,15 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", [
   let hasOptions = 1;
 }
 
-def TFL_SparseToDenseOp : TFL_Op<"sparse_to_dense", [NoSideEffect]> {
+def TFL_SparseToDenseOp : TFL_Op<"sparse_to_dense", [
+    NoSideEffect,
+    PredOpTrait<"sparse_values and dense must have same element type",
+        TFL_TCresVTEtIsSameAsOp<0, 2>>,
+    PredOpTrait<"default_value and dense must have same element type",
+        TFL_TCresVTEtIsSameAsOp<0, 3>>,
+    TFL_OperandHasRankAtMost<0, 2>,
+    TFL_OperandHasRankAtMost<1, 1>,
+    TFL_OperandHasRankAtMost<2, 1>]> {
   let summary = "Converts a sparse representation into a dense tensor.";
 
   let description = [{
@@ -3217,21 +3272,24 @@ are checked during execution.
   let arguments = (ins
     TFL_I32OrI64Tensor:$sparse_indices,
     TFL_I32OrI64Tensor:$output_shape,
-    TFL_TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$sparse_values,
-    TFL_TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$default_value
+    TFL_TensorOf<[I32, I64, I8, QI8, UI8, QUI8, TFL_Quint8, F32]>:$sparse_values,
+    TFL_TensorOf<[I32, I64, I8, QI8, UI8, QUI8, TFL_Quint8, F32]>:$default_value
   );
 
   let results = (outs
-    TFL_TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$dense
+    TFL_TensorOf<[I32, I64, I8, QI8, UI8, QUI8, TFL_Quint8, F32]>:$dense
   );
 }
 
-def TFL_StridedSliceOp: TFL_Op<"strided_slice",
-  [
+def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
     NoSideEffect,
     PredOpTrait<"input and output must have same element type",
         TFL_TCresVTEtIsSameAsOp<0, 0>>,
     SameOperandsAndResultsScale,
+    TFL_OperandHasRankAtMost<0, 5>,
+    TFL_OperandHasRank<1, 1>,
+    TFL_OperandHasRank<2, 1>,
+    TFL_OperandHasRank<3, 1>,
     TFL_GpuTargetOp
   ]> {
   let summary = "StridedSlice Op";
@@ -3241,20 +3299,20 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, I1, TFL_Quint8, TFL_Uint8]>:$input,
-    TFL_TensorOf<[I32]>:$begin,
-    TFL_TensorOf<[I32]>:$end,
-    TFL_TensorOf<[I32]>:$strides,
+    TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$input,
+    TFL_I32Tensor:$begin,
+    TFL_I32Tensor:$end,
+    TFL_I32Tensor:$strides,
     I32Attr:$begin_mask,
     I32Attr:$end_mask,
-    I32Attr:$ellipsis_mask,
-    I32Attr:$new_axis_mask,
+    Confined<I32Attr, [TFL_IntEqualsTo<0>]>:$ellipsis_mask,
+    Confined<I32Attr, [TFL_IntEqualsTo<0>]>:$new_axis_mask,
     I32Attr:$shrink_axis_mask
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, I1, TFL_Quint8, TFL_Uint8]>:$output
+    TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$output
   );
 
   let hasOptions = 1;
@@ -3269,17 +3327,16 @@ def TFL_CastOp : TFL_Op<"cast", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex<F<32>>]>:$input
+    TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
   );
 
-  let results = (outs TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex<F<32>>]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);
 
   // TFLite's cast op does not utilize CastOptions, instead derives types
   // from the TfLiteTensors.
   let hasOptions = 0;
 }
 
 def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
     NoSideEffect, TFL_OperandHasRank<1, 2>, TFL_GpuTargetOp]> {
   let summary = "MirrorPad Operator. Pads a tensor with mirrored values.";
@@ -3315,24 +3372,25 @@ def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
   let hasOptions = 1;
 }
 
-def TFL_UniqueOp: TFL_Op<"unique", [NoSideEffect]> {
+def TFL_UniqueOp: TFL_Op<"unique", [
+    TFL_OperandHasRank<0, 1>,
+    NoSideEffect]> {
   let summary = "Unique Op.";
 
   let description = [{
-    This operation returns a tensor `y` containing all of the unique elements of `x`
-    sorted in the same order that they occur in `x`. This operation also returns a
-    tensor `idx` the same size as `x` that contains the index of each value of `x`
-    in the unique output `y`. In other words:
+    This operation returns a tensor `output` containing all of the unique elements
+    of `input` sorted in the same order that they occur in `input`. This operation
+    also returns a tensor `idx` the same size as `x` that contains the index of each
+    value of `input` in the unique output `output`. In other words:
   }];
 
   let arguments = (ins
-    // TODO: add uint8 support after quantize support.
-    TFL_TensorOf<[I8, I16, I32, I64, F32]>:$input
+    TFL_TensorOf<[I8, QI8, UI8, QUI8, I16, QI16, I32, I64, F32]>:$input
   );
 
   let results = (outs
-    TFL_TensorOf<[I8, I16, I32, I64, F32]>:$output,
-    TFL_TensorOf<[I32, I64]>:$idx
+    TFL_TensorOf<[I8, QI8, UI8, QUI8, I16, QI16, I32, I64, F32]>:$output,
+    TFL_I32OrI64Tensor:$idx
   );
 
   DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{
@@ -3432,7 +3490,7 @@ def TFL_SparseQConstOp : Op<TFL_Dialect, "pseudo_sparse_qconst", [
     ElementsAttr:$compressed_data
   );
 
-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[QUI8, QI8, QI16, QUI16, TFL_Quint8]>:$output);
 
   let builders = [OpBuilder<
     "OpBuilder &, OperationState &state, TypeAttr qtype, "
@@ -4076,7 +4134,7 @@ def TFL_NumericVerifyOp : Op<TFL_Dialect, "NumericVerify", [
 def TFL_SVDFOp :
   TFL_Op<"svdf", [
     PredOpTrait<"the input and result tensor elemental types must be same",
-                TCresVTEtIsSameAsOp<0, 0>>,
+                TFL_TCresVTEtIsSameAsOp<0, 0>>,
     TFL_StatefulOp]> {
 
   let summary = "Single value decomposition filter operator";
@@ -4150,8 +4208,8 @@ def TFL_YieldOp : Op<TFL_Dialect, "yield", [Terminator]> {
 }
 
 def TFL_WhileOp : Op<TFL_Dialect, "while", [
     DeclareOpInterfaceMethods<LoopLikeOpInterface>,
     SingleBlockImplicitTerminator<"YieldOp">]> {
   let summary = [{While loop}];
 
   let description = [{


@@ -184,7 +184,7 @@ void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) {
   auto new_cell_tanh = builder->create<TanhOp>(loc, int16, new_cell);
   auto hidden_state = builder->create<MulOp>(
-      loc, int16, new_cell_tanh.y(), output_gate->getResult(0), none_af);
+      loc, int16, new_cell_tanh.output(), output_gate->getResult(0), none_af);
   auto act = builder->create<FullyConnectedOp>(
       loc, int8, hidden_state.output(), lstm.projection_weights(),
       lstm.projection_bias(), none_af, fc_format, keep_dims);


@@ -663,8 +663,8 @@ struct ConvertTrivialTransposeOpToReshapeOp
   LogicalResult matchAndRewrite(TFL::TransposeOp transpose_op,
                                 PatternRewriter &rewriter) const override {
-    auto input_type = transpose_op.x().getType().cast<ShapedType>();
-    auto output_type = transpose_op.y().getType().cast<ShapedType>();
+    auto input_type = transpose_op.input().getType().cast<ShapedType>();
+    auto output_type = transpose_op.output().getType().cast<ShapedType>();
     // It's possible to know if the transformation is safe only if the input
     // & output shapes are fully known and permutation is a constant.
     if (!input_type.hasStaticShape() || !output_type.hasStaticShape())
@@ -713,7 +713,8 @@ struct ConvertTrivialTransposeOpToReshapeOp
     auto new_shape = rewriter.create<TF::ConstOp>(loc, new_shape_attr);
 
     rewriter.replaceOpWithNewOp<TFL::ReshapeOp>(
-        transpose_op, transpose_op.y().getType(), transpose_op.x(), new_shape);
+        transpose_op, transpose_op.output().getType(), transpose_op.input(),
+        new_shape);
 
     return success();
   }
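
Usage note (illustrative, not part of the change above): the new attribute constraints added in tfl_ops.td follow the standard ODS AttrConstraint + CPred pattern, so further scalar checks could be written the same way and attached to attributes through Confined. A hypothetical example of the pattern:

// Hypothetical constraint, shown only to illustrate the pattern; it is not
// defined anywhere in this commit.
class TFL_IntAtLeast<int n> : AttrConstraint<
    CPred<"$_self.isa<IntegerAttr>() && "
          "$_self.cast<IntegerAttr>().getInt() >= " # n>,
    "whose value is at least " # n>;

// It would be attached to an attribute the same way as IntPositive or
// TFL_IntEqualsTo above, e.g.:
//   Confined<I32Attr, [TFL_IntAtLeast<1>]>:$num_splits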