Add some more HLO to TF patterns
Also import TF clip by value op. PiperOrigin-RevId: 303545182 Change-Id: I87b3c0f061ad454f3af09bf799d0986eaf039a07
This commit is contained in:
parent
95be0349fb
commit
57a13f07a0
@ -1143,6 +1143,29 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Clips tensor values to a specified min and max.";
|
||||
|
||||
let description = [{
|
||||
Given a tensor `t`, this operation returns a tensor of the same type and
|
||||
shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`.
|
||||
Any values less than `clip_value_min` are set to `clip_value_min`. Any values
|
||||
greater than `clip_value_max` are set to `clip_value_max`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$t,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_min,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_max
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> {
|
||||
let summary = "Converts two real numbers to a complex number.";
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -40,6 +40,20 @@ class LegalizeHloToTf : public FunctionPass<LegalizeHloToTf> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
// Returns whether the two values are guaranteed to be broadcastable to the
|
||||
// same shape, this broadcasts size 1 tensors up to any rank.
|
||||
// TODO(jpienaar): Move this to more general location.
|
||||
static bool AreBroadcastCompatible(Value x, Value y) {
|
||||
auto x_ranked = x.getType().dyn_cast<RankedTensorType>();
|
||||
auto y_ranked = y.getType().dyn_cast<RankedTensorType>();
|
||||
if (!x_ranked || !y_ranked) {
|
||||
return true;
|
||||
}
|
||||
SmallVector<int64_t, 4> resultShape;
|
||||
return OpTrait::util::getBroadcastedShape(x_ranked.getShape(),
|
||||
y_ranked.getShape(), resultShape);
|
||||
}
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc"
|
||||
|
||||
/// Performs the lowering to XLA dialect.
|
||||
|
@ -20,14 +20,16 @@ include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
|
||||
|
||||
def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>;
|
||||
def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Binary op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class DirectBinaryPat<Op FromOp, Op ToOp>
|
||||
: Pat<(FromOp $l, $r, $_), (ToOp $l, $r)>;
|
||||
// Check that two values can be broadcasted together
|
||||
// TODO(jpienaar): Move somewhere more general
|
||||
def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">,
|
||||
"types must be broadcastable">;
|
||||
|
||||
foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op],
|
||||
[HLO_DivOp, TF_DivOp],
|
||||
@ -37,24 +39,41 @@ foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op],
|
||||
[HLO_MulOp, TF_MulOp],
|
||||
[HLO_PowOp, TF_PowOp],
|
||||
[HLO_DivOp, TF_RealDivOp],
|
||||
[HLO_SubOp, TF_SubOp]] in
|
||||
def : DirectBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
|
||||
[HLO_SubOp, TF_SubOp],
|
||||
[HLO_Atan2Op, TF_Atan2Op],
|
||||
[HLO_RemOp, TF_ModOp]] in
|
||||
def : Pat<(fromToBinPair[0] $l, $r, $_), (fromToBinPair[1] $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
||||
def LowerRightShiftSigned :
|
||||
Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
|
||||
[(SignedIntTensor $r)]>;
|
||||
foreach pair = [[HLO_AndOp, TF_BitwiseAndOp],
|
||||
[HLO_OrOp, TF_BitwiseOrOp],
|
||||
[HLO_XorOp, TF_BitwiseXorOp]] in
|
||||
def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[1] $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
||||
def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r)>;
|
||||
foreach pair = [[HLO_AndOp, TF_LogicalAndOp],
|
||||
[HLO_OrOp, TF_LogicalOrOp]] in
|
||||
def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r, $_), (pair[1] $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
||||
def : Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
||||
def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Unary op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
foreach Mapping = [
|
||||
[HLO_AbsOp, TF_AbsOp],
|
||||
foreach Mapping = [[HLO_AbsOp, TF_AbsOp],
|
||||
[HLO_BitcastConvertOp, TF_BitcastOp],
|
||||
[HLO_CeilOp, TF_CeilOp],
|
||||
[HLO_CosOp, TF_CosOp],
|
||||
[HLO_ExpOp, TF_ExpOp],
|
||||
[HLO_Expm1Op, TF_Expm1Op],
|
||||
[HLO_FloorOp, TF_FloorOp],
|
||||
[HLO_ImagOp, TF_ImagOp],
|
||||
[HLO_IsFiniteOp, TF_IsFiniteOp],
|
||||
@ -65,8 +84,46 @@ foreach Mapping = [
|
||||
[HLO_RealOp, TF_RealOp],
|
||||
[HLO_RsqrtOp, TF_RsqrtOp],
|
||||
[HLO_SinOp, TF_SinOp],
|
||||
[HLO_SignOp, TF_SignOp],
|
||||
[HLO_SqrtOp, TF_SqrtOp],
|
||||
[HLO_TanhOp, TF_TanhOp],
|
||||
] in {
|
||||
def : Pat<(Mapping[0] $input), (Mapping[1] $input)>;
|
||||
}
|
||||
[HLO_TanhOp, TF_TanhOp]] in
|
||||
def : Pat<(Mapping[0] TF_IntOrFpTensor:$input), (Mapping[1] $input)>;
|
||||
|
||||
def : Pat<(HLO_AbsOp TF_ComplexTensor:$arg), (TF_ComplexAbsOp $arg)>;
|
||||
|
||||
def : Pat<(HLO_BroadcastOp $arg, $shape),
|
||||
(TF_BroadcastToOp $arg, (TF_ConstOp $shape))>;
|
||||
def : Pat<(HLO_TransposeOp $arg, $permutation),
|
||||
(TF_TransposeOp $arg, (TF_ConstOp $permutation))>;
|
||||
def : Pat<(HLO_ReverseOp $op, $dims), (TF_ReverseV2Op $op, (TF_ConstOp $dims))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ternary op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def : Pat<(HLO_ClampOp $min, $arg, $max),
|
||||
(TF_MaximumOp (TF_MinimumOp $arg, $max), $min)>;
|
||||
def : Pat<(HLO_SelectOp $cond, $t, $e), (TF_SelectOp $cond, $t, $e)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Variadic op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def : Pat<(HLO_ConcatenateOp $inputs, $dim),
|
||||
(TF_ConcatV2Op $inputs, (TF_ConstOp $dim))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Compare op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
foreach p = [[TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ],
|
||||
[TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in
|
||||
def : Pat<(HLO_CompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
||||
foreach pair = [[TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE],
|
||||
[TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT],
|
||||
[TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE],
|
||||
[TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in
|
||||
def : Pat<(HLO_CompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
Loading…
x
Reference in New Issue
Block a user