Update shape constraints.

Also refactored the utility function that checks broadcastable inputs to
support ternary inputs for the SelectV2 op.

PiperOrigin-RevId: 315035592
Change-Id: I6ac8f20dbe99f27ca83a361e9b55024b8dc71bcd
parent c0e6ce2295
commit e879c4f8bd
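To make the shape rule concrete before reading the diff: the reworked helper accepts N operands when their static shapes either all match or broadcast together within a rank budget. Below is a minimal, self-contained C++ sketch of that check. All names here (BroadcastInto, ShapesBroadcastableWithinRank) are illustrative stand-ins, not the TensorFlow API, and the real helper additionally handles unknown ranks and non-static shapes.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Folds `shape` into `acc` under numpy-style broadcasting, aligning trailing
// dimensions. Returns false when a dimension pair is incompatible.
bool BroadcastInto(std::vector<int64_t>& acc,
                   const std::vector<int64_t>& shape) {
  std::vector<int64_t> result(std::max(acc.size(), shape.size()));
  auto a = acc.rbegin();
  auto b = shape.rbegin();
  for (auto r = result.rbegin(); r != result.rend(); ++r) {
    const int64_t da = (a != acc.rend()) ? *a++ : 1;
    const int64_t db = (b != shape.rend()) ? *b++ : 1;
    if (da != db && da != 1 && db != 1) return false;
    *r = std::max(da, db);
  }
  acc = std::move(result);
  return true;
}

// Mirrors the helper's contract for fully static shapes: all operands must
// broadcast together, and unless the shapes are all identical, the maximum
// rank must stay within `max_bcast_rank`.
bool ShapesBroadcastableWithinRank(
    const std::vector<std::vector<int64_t>>& shapes, int64_t max_bcast_rank) {
  if (shapes.empty()) return true;
  bool all_same = true;
  int64_t max_rank = 0;
  std::vector<int64_t> acc = shapes.front();
  for (const auto& s : shapes) {
    max_rank = std::max<int64_t>(max_rank, s.size());
    if (s != shapes.front()) all_same = false;
    if (!BroadcastInto(acc, s)) return false;
  }
  return all_same || max_rank <= max_bcast_rank;
}

int main() {
  // Broadcastable, but rank 6 exceeds the rank-5 budget, as in the
  // maximum_with_6d_broadcasting test added below: prints 0.
  std::cout << ShapesBroadcastableWithinRank({{1, 1, 1, 1, 8, 16}, {8, 16}}, 5)
            << "\n";
  // Identical rank-6 shapes pass regardless of the budget: prints 1.
  std::cout << ShapesBroadcastableWithinRank(
                   {{1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}}, 5)
            << "\n";
}

The two calls in main mirror the tests this commit adds: broadcastable rank-6 inputs fail a rank-5 budget, while identical shapes pass at any rank.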
tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -46,28 +46,68 @@ namespace mlir {
 #include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
 
 namespace TFL {
 
-// Returns true when the given two types have the same shape or broadcastable
-// shape within the given rank. If any given shapes are non-static, this method
-// returns true.
-bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
-                                                        int max_bcast_rank) {
-  // Ignore shape checking on the non-static shapes for model compatibility.
-  auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
-  if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
-  auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
-  if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
-
-  if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
-    return true;
-
-  SmallVector<int64_t, 4> result_shape;
-  if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
-                                          rhs_shaped_type.getShape(),
-                                          result_shape)) {
-    return false;
-  }
-  return lhs_shaped_type.getRank() <= max_bcast_rank &&
-         rhs_shaped_type.getRank() <= max_bcast_rank;
-}
+// Returns true when the given operand arguments have the same shape or
+// broadcastable shape within the given rank. If any given shapes are
+// non-static and maximum rank is within the given rank, this method returns
+// true.
+bool IsOperandsHaveSameShapesOrBroadcastableShape(Operation *op,
+                                                  ArrayRef<unsigned> indices,
+                                                  int max_bcast_rank) {
+  if (indices.empty()) return true;
+
+  // First, it checks there are any inputs that has unknown rank.
+  bool has_unknown_shape_input = false;
+  bool has_same_shape = true;
+  bool reach_first_known_shape = false;
+  int64_t max_rank = -1;
+
+  ArrayRef<int64_t> pivot_shape;
+  SmallVector<int64_t, 4> current_shape;
+  SmallVector<int64_t, 4> result_shape;
+
+  for (unsigned index : indices) {
+    ShapedType shaped_type =
+        op->getOperand(index).getType().dyn_cast<ShapedType>();
+    if (!shaped_type || !shaped_type.hasRank()) {
+      // Marks that we have an unknown rank input.
+      has_unknown_shape_input = true;
+      continue;
+    }
+    max_rank = std::max(max_rank, shaped_type.getRank());
+    if (!shaped_type.hasStaticShape()) {
+      // Marks that we have an unknown shape input.
+      has_unknown_shape_input = true;
+      continue;
+    }
+
+    ArrayRef<int64_t> shape = shaped_type.getShape();
+    if (!reach_first_known_shape) {
+      pivot_shape = shape;
+      current_shape.assign(shape.begin(), shape.end());
+      reach_first_known_shape = true;
+      continue;
+    }
+
+    if (!pivot_shape.equals(shape)) {
+      has_same_shape = false;
+    }
+    // Checks if all the inputs are broadcastable since they have not all the
+    // same shapes.
+    if (!OpTrait::util::getBroadcastedShape(current_shape, shape,
+                                            result_shape)) {
+      return false;
+    }
+    current_shape = result_shape;
+  }
+
+  // It will treat the unknown shape inputs as acceptable inputs for model
+  // compatibility unless there is an known rank that is bigger than the allowed
+  // broadcast maximum rank.
+  if (has_unknown_shape_input) return max_rank <= max_bcast_rank;
+
+  // If all the shape is known and same, CPU kernels are able to handle inputs
+  // regardless of dimension size.
+  return has_same_shape || max_rank <= max_bcast_rank;
+}
 
 //===----------------------------------------------------------------------===//
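The unknown-shape branch above can be restated as a tiny sketch (illustrative names, not the TensorFlow API): operands with unknown rank or non-static shape are accepted for model compatibility, unless some known rank already exceeds the broadcast budget.

#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

// Accept when every known rank fits the budget; std::nullopt models an
// operand whose rank is unknown at compile time (hypothetical helper).
bool AcceptGivenUnknownShapes(const std::vector<std::optional<int64_t>>& ranks,
                              int64_t max_bcast_rank) {
  int64_t max_known_rank = -1;
  for (const auto& r : ranks)
    if (r.has_value()) max_known_rank = std::max(max_known_rank, *r);
  return max_known_rank <= max_bcast_rank;
}
// AcceptGivenUnknownShapes({std::nullopt, 4}, 4) -> true
// AcceptGivenUnknownShapes({std::nullopt, 6}, 5) -> false: a rank-6 operand
// already breaks the rank-5 budget, so unknown inputs cannot rescue it.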
tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -123,14 +123,13 @@ class TFL_RuntimePredOpTrait<string desc, Pred pred> :
   string tflRuntimeDescription = desc;
 }
 
-class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<
-    int i, int j, int max_bcast_rank> :
-  TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j #
-    " have the same shape or broadcastable shapes within the rank " #
-    max_bcast_rank,
-    CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape("
-            "$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j #
-            ").getType(), " # max_bcast_rank # ")">>;
+class TFL_OperandsHaveSameShapesOrBroadcastableShape<
+    list<int> indices, int max_bcast_rank> :
+  TFL_RuntimePredOpTrait<"operands do not have the same shape or "
+      "broadcastable shapes within the rank " # max_bcast_rank,
+    CPred<"TFL::IsOperandsHaveSameShapesOrBroadcastableShape("
+            "$_op, llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result #
+            "}), " # max_bcast_rank # ")">>;
 
 // These additional types/type constraints here are used to decouple the ops
 // from runtime support for the ops. Prefer to use these types when defining
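To make the TableGen mechanics concrete: StrJoinInt<indices>.result joins the index list with commas, so instantiating the new trait as TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1, 2], 4> substitutes into the CPred roughly as below. This is a sketch of the generated check (it only compiles inside the MLIR/TFLite codebase), not copied from generated code; $_op becomes the operation being verified.

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Operation.h"

// Approximate predicate emitted for indices = [0, 1, 2], max_bcast_rank = 4.
static bool Verify(mlir::Operation *op) {
  return TFL::IsOperandsHaveSameShapesOrBroadcastableShape(
      op, llvm::ArrayRef<unsigned>({0, 1, 2}), /*max_bcast_rank=*/4);
}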
@@ -463,6 +462,7 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
 //===----------------------------------------------------------------------===//
 def TFL_AbsOp : TFL_Op<"abs", [
     NoSideEffect,
+    SameOperandsAndResultShape,
     SameOperandsAndResultType,
     NoQuantizableResult,
     TFL_GpuTargetOp]> {
@@ -482,7 +482,7 @@ an output element, this operation computes \\(y = |x|\\).
 }
 
 def TFL_AddOp : TFL_Op<"add", [
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
     ResultsBroadcastableShape,
     NoSideEffect,
     Commutative,
@@ -669,7 +669,10 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
   }]>;
 }
 
-def TFL_CeilOp: TFL_Op<"ceil", [NoSideEffect, SameOperandsAndResultType]> {
+def TFL_CeilOp: TFL_Op<"ceil", [
+    NoSideEffect,
+    SameOperandsAndResultShape,
+    SameOperandsAndResultType]> {
   let summary = "Ceil operator";
 
   let description = [{
@@ -818,6 +821,7 @@ def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
 
 def TFL_CosOp: TFL_Op<"cos", [
     NoSideEffect,
+    SameOperandsAndResultShape,
     SameOperandsAndResultType,
     NoQuantizableResult,
     TFL_GpuTargetOp]> {
@@ -1021,7 +1025,7 @@ def TFL_ScatterNdOp : TFL_Op<"scatter_nd", [
 def TFL_LessEqualOp : TFL_Op<"less_equal", [
     ResultsBroadcastableShape,
     BinaryOpSameElementTypeConstraint,
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
     NoSideEffect,
     NoQuantizableResult]> {
   let summary = "Less_equal operator";
@@ -1082,7 +1086,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
 }
 
 def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
     ResultsBroadcastableShape,
     NoSideEffect,
     NoQuantizableResult]> {
@@ -1150,12 +1154,12 @@ innermost matrices. These will be overwritten by the values in `diagonal`.
   }];
 
   let arguments = (ins
-    TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$input,
-    TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$diagonal
+    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$input,
+    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$diagonal
   );
 
   let results = (outs
-    TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$result
+    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$result
   );
 
   let hasOptions = 0;
@@ -1273,7 +1277,7 @@ larger than 0.
 }
 
 def TFL_NotEqualOp : TFL_Op<"not_equal", [
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
     BinaryOpSameElementTypeConstraint,
     ResultsBroadcastableShape,
     Commutative,
@@ -1309,7 +1313,7 @@ def TFL_DivOp : TFL_Op<"div", [
     // TODO(fengliuai): NoQuantizableResult is only correct for int8
     // quantization. update to handle Uint8 quantization.
     BinaryOpSameElementTypeConstraint,
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
     ResultsBroadcastableShape,
     NoSideEffect,
     NoQuantizableResult,
@@ -1338,7 +1342,10 @@ def TFL_DivOp : TFL_Op<"div", [
   let hasFolder = 1;
 }
 
-def TFL_EluOp: TFL_Op<"elu", [NoSideEffect, SameOperandsAndResultType]> {
+def TFL_EluOp: TFL_Op<"elu", [
+    NoSideEffect,
+    SameOperandsAndResultShape,
+    SameOperandsAndResultType]> {
   let summary = "Exponential Linear Unit operator";
   let description = [{
     Computes the exponential linear
@@ -1374,10 +1381,11 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
   let results = (outs TFL_TensorOf<[F32, I8, UI8]>:$output);
 }
 
-def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape,
-    NoQuantizableResult,
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+def TFL_EqualOp: TFL_Op<"equal", [
+    Commutative,
+    NoQuantizableResult,
+    ResultsBroadcastableShape,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
     PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> {
   let summary = "Equal operator";
 
@@ -1516,7 +1524,10 @@ def TFL_FillOp: TFL_Op<"fill", [
   let hasOptions = 0;
 }
 
-def TFL_FloorOp: TFL_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> {
+def TFL_FloorOp: TFL_Op<"floor", [
+    NoSideEffect,
+    SameOperandsAndResultShape,
+    SameOperandsAndResultType]> {
   let summary = "Floor operator";
 
   let description = [{
@@ -1534,7 +1545,7 @@ def TFL_FloorDivOp : TFL_Op<"floor_div", [
     BinaryOpSameElementTypeConstraint,
     PredOpTrait<"lhs and output must have same element type",
       TFL_TCresVTEtIsSameAsOp<0, 0>>,
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>]> {
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>]> {
   let summary = "Floor div operator";
 
   let description = [{
@@ -1559,7 +1570,7 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [
     BinaryOpSameElementTypeConstraint,
     PredOpTrait<"lhs and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>]> {
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>]> {
   let summary = "Division reminder";
 
   let description = [{
@@ -1578,7 +1589,7 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [
 def TFL_GreaterOp : TFL_Op<"greater", [
     ResultsBroadcastableShape,
     BinaryOpSameElementTypeConstraint,
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
     NoSideEffect,
     NoQuantizableResult]> {
   let summary = "Greater operator";
@@ -1670,7 +1681,7 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
 def TFL_LessOp : TFL_Op<"less", [
     ResultsBroadcastableShape,
     BinaryOpSameElementTypeConstraint,
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
     NoSideEffect,
     NoQuantizableResult]> {
   let summary = "Less operator";
@@ -1710,7 +1721,10 @@ def TFL_LogicalAndOp : TFL_Op<"logical_and", [NoSideEffect]> {
   let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
 }
 
-def TFL_LogicalNotOp : TFL_Op<"logical_not", [NoSideEffect, NoQuantizableResult]> {
+def TFL_LogicalNotOp : TFL_Op<"logical_not", [
+    NoSideEffect,
+    SameOperandsAndResultShape,
+    NoQuantizableResult]> {
   let summary = "Logical NOT operator";
 
   let description = [{
@@ -1794,6 +1808,7 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
 
 def TFL_LogOp: TFL_Op<"log", [
     NoSideEffect,
+    SameOperandsAndResultShape,
     SameOperandsAndResultType,
     NoQuantizableResult,
     TFL_GpuTargetOp]> {
@@ -1884,6 +1899,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
 def TFL_MaximumOp : TFL_Op<"maximum", [
     ResultsBroadcastableShape,
     NoSideEffect,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
     Commutative,
     SameOperandsAndResultsScale,
     TFL_GpuTargetOp]> {
@@ -2118,6 +2134,7 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [
 def TFL_MinimumOp : TFL_Op<"minimum", [
     ResultsBroadcastableShape,
     NoSideEffect,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
     Commutative,
     SameOperandsAndResultsScale,
     TFL_GpuTargetOp]> {
@@ -2145,7 +2162,7 @@ def TFL_MulOp : TFL_Op<"mul", [
     NoSideEffect,
     Commutative,
     BinaryOpSameElementTypeConstraint,
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
     TFL_GpuTargetOp]> {
   let summary = "Multiplication operator";
 
@@ -2171,7 +2188,10 @@ def TFL_MulOp : TFL_Op<"mul", [
   let hasOptions = 1;
 }
 
-def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
+def TFL_NegOp: TFL_Op<"neg", [
+    NoSideEffect,
+    SameOperandsAndResultShape,
+    SameOperandsAndResultType]> {
   let summary = "Negation operator";
 
   let description = [{
@@ -2333,10 +2353,12 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
   let hasOptions = 1;
 }
 
-def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape,
-                               NoSideEffect,
-                               NoQuantizableResult,
-                               TFL_GpuTargetOp]> {
+def TFL_PowOp : TFL_Op<"pow", [
+    ResultsBroadcastableShape,
+    NoSideEffect,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
+    NoQuantizableResult,
+    TFL_GpuTargetOp]> {
   let summary = "Power operator";
 
   let description = [{
@@ -2360,7 +2382,7 @@ def TFL_PReluOp : TFL_Op<"prelu", [
     NoSideEffect,
     ResultsBroadcastableShape,
     TFL_GpuTargetOp,
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
     BinaryOpSameElementTypeConstraint,
     PredOpTrait<"input and output must have the same element type",
       TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
@@ -2671,8 +2693,9 @@ def TFL_SelectOp : TFL_Op<"select", [
 }
 
 def TFL_SelectV2Op : TFL_Op<"select_v2", [
     ResultsBroadcastableShape,
     NoSideEffect,
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<1, 2, 4>,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1, 2], 4>,
+    PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
     PredOpTrait<"operands and result have same element type",
       TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
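The practical change for tfl.select_v2: the old binary trait compared only operands 1 and 2 (the two data inputs), while the new list trait folds operand 0 (the condition) into the same broadcastability check. Reusing the hypothetical ShapesBroadcastableWithinRank sketch from the note after the commit message, the shapes from the new tests behave like this:

// cond <1x1x3x1>, then <1x1x1x4>, else <1x2x1x1> broadcast to <1x2x3x4>
// within rank 4, so verification succeeds.
bool ok = ShapesBroadcastableWithinRank(
    {{1, 1, 3, 1}, {1, 1, 1, 4}, {1, 2, 1, 1}}, /*max_bcast_rank=*/4);  // true

// <2x3x4> and <4x3> are not broadcast-compatible, matching the
// expected-error test added to ops.mlir below.
bool bad = ShapesBroadcastableWithinRank(
    {{3, 4}, {2, 3, 4}, {4, 3}}, /*max_bcast_rank=*/4);  // false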
@@ -2705,6 +2728,7 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [
 
 def TFL_SinOp: TFL_Op<"sin", [
     NoSideEffect,
+    SameOperandsAndResultShape,
     SameOperandsAndResultType,
     NoQuantizableResult,
     TFL_GpuTargetOp]> {
@@ -2752,6 +2776,7 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
 
 def TFL_SqrtOp: TFL_Op<"sqrt", [
     NoSideEffect,
+    SameOperandsAndResultShape,
     SameOperandsAndResultType,
     NoQuantizableResult,
     TFL_GpuTargetOp]> {
@@ -2770,6 +2795,7 @@ def TFL_SqrtOp: TFL_Op<"sqrt", [
 
 def TFL_SquareOp: TFL_Op<"square", [
     NoSideEffect,
+    SameOperandsAndResultShape,
     SameOperandsAndResultType,
     NoQuantizableResult,
     TFL_GpuTargetOp]> {
@@ -2791,7 +2817,7 @@ def TFL_SquareOp: TFL_Op<"square", [
 def TFL_SubOp : TFL_Op<"sub", [
     ResultsBroadcastableShape,
     BinaryOpSameElementTypeConstraint,
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
     NoSideEffect]> {
   let summary = "Subtraction operator";
 
@@ -2820,7 +2846,7 @@ def TFL_SubOp : TFL_Op<"sub", [
 // TODO(jpienaar): Expand the kernel implementation to support all types besides
 // I32 and F32.
 def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
     SameOperandsAndResultElementType,
     ResultsBroadcastableShape,
     NoSideEffect,
@@ -3007,6 +3033,8 @@ def TFL_UnpackOp : TFL_Op<"unpack", [
 
 def TFL_ZerosLikeOp: TFL_Op<"zeros_like", [
-    SameOperandsAndResultType,
+    PredOpTrait<"input and output must have same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
+    SameOperandsAndResultShape,
     NoSideEffect]> {
   let summary = "ZerosLike operator";
 
@@ -3319,7 +3347,9 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
 }
 
 def TFL_CastOp : TFL_Op<"cast", [
-    NoSideEffect, SameOperandsAndResultShape, NoQuantizableResult]> {
+    NoSideEffect,
+    SameOperandsAndResultShape,
+    NoQuantizableResult]> {
   let summary = "Cast operator";
 
   let description = [{
tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -1529,3 +1529,22 @@ func @matmul_batchv2_unknown_dim(%arg0: tensor<?x10x15xf32>, %arg1: tensor<15x17
 // CHECK-LABEL: matmul_batchv2_unknown_dim
 // CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<?x10x15xf32>, tensor<15x17xf32>) -> tensor<?x10x17xf32>
 }
+
+// -----
+
+func @select_v2_with_6d_broadcasting(%arg0: tensor<1x1x1x1x3x1xi1>, %arg1 : tensor<1x1x1x1x1x4xf32>, %arg2 : tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32> {
+  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2): (tensor<1x1x1x1x3x1xi1>, tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32>
+  return %0 : tensor<1x1x1x2x3x4xf32>
+// CHECK-LABEL: select_v2_with_6d_broadcasting
+// CHECK: "tf.SelectV2"(%arg0, %arg1, %arg2)
+}
+
+// -----
+
+func @maximum_with_6d_broadcasting(%arg0: tensor<1x1x1x1x8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<1x1x1x1x8x16xf32> {
+  %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<1x1x1x1x8x16xf32>, tensor<8x16xf32>) -> tensor<1x1x1x1x8x16xf32>
+  return %0 : tensor<1x1x1x1x8x16xf32>
+
+// CHECK-LABEL: maximum_with_6d_broadcasting
+// CHECK: "tf.Maximum"(%arg0, %arg1)
+}
tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -794,6 +794,41 @@ func @testSelectWithUnsupportedType(%cond : tensor<?xi1>, %arg0 : tensor<?xi32>,
 
 // -----
 
+// CHECK-LABEL: testSelectV2
+func @testSelectV2(%cond : tensor<*xi1>, %arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: "tfl.select_v2"(%arg0, %arg1, %arg2)
+  %0 = "tfl.select_v2"(%cond, %arg0, %arg1): (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: testSelectV2WithHighDimInputs
+func @testSelectV2WithHighDimInputs(%cond : tensor<1x2x3x4x5x6xi1>, %arg0 : tensor<1x2x3x4x5x6xf32>, %arg1 : tensor<1x2x3x4x5x6xf32>) -> tensor<1x2x3x4x5x6xf32> {
+  // CHECK: "tfl.select_v2"(%arg0, %arg1, %arg2)
+  %0 = "tfl.select_v2"(%cond, %arg0, %arg1): (tensor<1x2x3x4x5x6xi1>, tensor<1x2x3x4x5x6xf32>, tensor<1x2x3x4x5x6xf32>) -> tensor<1x2x3x4x5x6xf32>
+  return %0 : tensor<1x2x3x4x5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: testSelectV2With4DBroadcasting
+func @testSelectV2With4DBroadcasting(%cond : tensor<1x1x3x1xi1>, %arg0 : tensor<1x1x1x4xf32>, %arg1 : tensor<1x2x1x1xf32>) -> tensor<1x2x3x4xf32> {
+  // CHECK: "tfl.select_v2"(%arg0, %arg1, %arg2)
+  %0 = "tfl.select_v2"(%cond, %arg0, %arg1): (tensor<1x1x3x1xi1>, tensor<1x1x1x4xf32>, tensor<1x2x1x1xf32>) -> tensor<1x2x3x4xf32>
+  return %0 : tensor<1x2x3x4xf32>
+}
+
+// -----
+
+func @testSelectV2WithWrongBroadcastableArguments(%cond : tensor<3x4xi1>, %arg0 : tensor<2x3x4xf32>, %arg1 : tensor<4x3xf32>) -> tensor<2x3x4xf32> {
+  // expected-error @+1 {{'tfl.select_v2' op operands don't have broadcast-compatible shapes}}
+  %0 = "tfl.select_v2"(%cond, %arg0, %arg1): (tensor<3x4xi1>, tensor<2x3x4xf32>, tensor<4x3xf32>) -> tensor<2x3x4xf32>
+  return %0 : tensor<2x3x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: topk
 func @topk(%arg0: tensor<8xf32>, %arg1: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
   %0, %1 = "tfl.topk_v2"(%arg0, %arg1) : (tensor<8xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)