Add some more HLO-to-TF patterns

Also import the TF ClipByValue op.

PiperOrigin-RevId: 303545182
Change-Id: I87b3c0f061ad454f3af09bf799d0986eaf039a07
Jacques Pienaar 2020-03-28 17:11:53 -07:00 committed by TensorFlower Gardener
parent 95be0349fb
commit 57a13f07a0
4 changed files with 1338 additions and 313 deletions


@@ -1143,6 +1143,29 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Clips tensor values to a specified min and max.";
let description = [{
Given a tensor `t`, this operation returns a tensor of the same type and
shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`.
Any values less than `clip_value_min` are set to `clip_value_min`. Any values
greater than `clip_value_max` are set to `clip_value_max`.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$t,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_min,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_max
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
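As an illustrative sketch of the imported op in MLIR assembly (SSA names and shapes are hypothetical; the SameOperandsAndResultType trait means all three operands must match the result type here):

    %result = "tf.ClipByValue"(%t, %min, %max)
        : (tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>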
def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> {
let summary = "Converts two real numbers to a complex number.";

File diff suppressed because it is too large.


@@ -40,6 +40,20 @@ class LegalizeHloToTf : public FunctionPass<LegalizeHloToTf> {
void runOnFunction() override;
};
// Returns whether the two values are guaranteed to be broadcastable to the
// same shape; this broadcasts size-1 tensors up to any rank. For example,
// tensor<1x4xf32> and tensor<3x4xf32> are compatible, while tensor<2xf32>
// and tensor<3xf32> are not.
// TODO(jpienaar): Move this to a more general location.
static bool AreBroadcastCompatible(Value x, Value y) {
auto x_ranked = x.getType().dyn_cast<RankedTensorType>();
auto y_ranked = y.getType().dyn_cast<RankedTensorType>();
if (!x_ranked || !y_ranked) {
return true;
}
SmallVector<int64_t, 4> resultShape;
return OpTrait::util::getBroadcastedShape(x_ranked.getShape(),
y_ranked.getShape(), resultShape);
}
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc"
/// Performs the lowering to XLA dialect.


@@ -20,14 +20,16 @@ include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>;
def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>;
//===----------------------------------------------------------------------===//
// Binary op patterns.
//===----------------------------------------------------------------------===//
class DirectBinaryPat<Op FromOp, Op ToOp>
: Pat<(FromOp $l, $r, $_), (ToOp $l, $r)>;
// Check that two values can be broadcast together.
// TODO(jpienaar): Move somewhere more general
def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">,
"types must be broadcastable">;
foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op],
[HLO_DivOp, TF_DivOp],
@@ -37,24 +39,41 @@ foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op],
[HLO_MulOp, TF_MulOp],
[HLO_PowOp, TF_PowOp],
[HLO_DivOp, TF_RealDivOp],
[HLO_SubOp, TF_SubOp]] in
def : DirectBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
[HLO_SubOp, TF_SubOp],
[HLO_Atan2Op, TF_Atan2Op],
[HLO_RemOp, TF_ModOp]] in
def : Pat<(fromToBinPair[0] $l, $r, $_), (fromToBinPair[1] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
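As a hedged sketch of the rewrites this foreach generates (hypothetical SSA names and shapes, assuming the xla_hlo op mnemonics of the time), a matched op such as

    %0 = "xla_hlo.add"(%l, %r) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

becomes

    %0 = "tf.AddV2"(%l, %r) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

where the AreBroadcastCompatible constraint rejects operand pairs whose shapes TF's implicit broadcasting could not reconcile.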
def LowerRightShiftSigned :
Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
[(SignedIntTensor $r)]>;
foreach pair = [[HLO_AndOp, TF_BitwiseAndOp],
[HLO_OrOp, TF_BitwiseOrOp],
[HLO_XorOp, TF_BitwiseXorOp]] in
def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[1] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r)>;
foreach pair = [[HLO_AndOp, TF_LogicalAndOp],
[HLO_OrOp, TF_LogicalOrOp]] in
def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r, $_), (pair[1] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
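This last pattern fuses a floor over a division into one TF op. Sketched on hypothetical values (mnemonics assumed):

    %d = "xla_hlo.divide"(%l, %r) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
    %f = "xla_hlo.floor"(%d) : (tensor<4xf32>) -> tensor<4xf32>

becomes

    %f = "tf.FloorDiv"(%l, %r) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

with the xla_hlo.divide erased once it has no remaining users.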
//===----------------------------------------------------------------------===//
// Unary op patterns.
//===----------------------------------------------------------------------===//
foreach Mapping = [
[HLO_AbsOp, TF_AbsOp],
foreach Mapping = [[HLO_AbsOp, TF_AbsOp],
[HLO_BitcastConvertOp, TF_BitcastOp],
[HLO_CeilOp, TF_CeilOp],
[HLO_CosOp, TF_CosOp],
[HLO_ExpOp, TF_ExpOp],
[HLO_Expm1Op, TF_Expm1Op],
[HLO_FloorOp, TF_FloorOp],
[HLO_ImagOp, TF_ImagOp],
[HLO_IsFiniteOp, TF_IsFiniteOp],
@@ -65,8 +84,46 @@ foreach Mapping = [
[HLO_RealOp, TF_RealOp],
[HLO_RsqrtOp, TF_RsqrtOp],
[HLO_SinOp, TF_SinOp],
[HLO_SignOp, TF_SignOp],
[HLO_SqrtOp, TF_SqrtOp],
[HLO_TanhOp, TF_TanhOp],
] in {
def : Pat<(Mapping[0] $input), (Mapping[1] $input)>;
}
[HLO_TanhOp, TF_TanhOp]] in
def : Pat<(Mapping[0] TF_IntOrFpTensor:$input), (Mapping[1] $input)>;
def : Pat<(HLO_AbsOp TF_ComplexTensor:$arg), (TF_ComplexAbsOp $arg)>;
def : Pat<(HLO_BroadcastOp $arg, $shape),
(TF_BroadcastToOp $arg, (TF_ConstOp $shape))>;
def : Pat<(HLO_TransposeOp $arg, $permutation),
(TF_TransposeOp $arg, (TF_ConstOp $permutation))>;
def : Pat<(HLO_ReverseOp $op, $dims), (TF_ReverseV2Op $op, (TF_ConstOp $dims))>;
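These three patterns materialize an HLO attribute as a tf.Const operand, since the TF ops take the permutation/shape/dims as a tensor operand rather than an attribute. A hedged sketch for transpose (hypothetical shapes; attribute names as in the xla_hlo dialect of the time):

    %0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}
        : (tensor<2x3xf32>) -> tensor<3x2xf32>

becomes

    %perm = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
    %0 = "tf.Transpose"(%arg, %perm) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32>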
//===----------------------------------------------------------------------===//
// Ternary op patterns.
//===----------------------------------------------------------------------===//
def : Pat<(HLO_ClampOp $min, $arg, $max),
(TF_MaximumOp (TF_MinimumOp $arg, $max), $min)>;
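This rewrite relies on the identity clamp(lo, x, hi) = max(min(x, hi), lo). Sketched on hypothetical values:

    %0 = "xla_hlo.clamp"(%lo, %x, %hi)
        : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

becomes

    %m = "tf.Minimum"(%x, %hi) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
    %0 = "tf.Maximum"(%m, %lo) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>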
def : Pat<(HLO_SelectOp $cond, $t, $e), (TF_SelectOp $cond, $t, $e)>;
//===----------------------------------------------------------------------===//
// Variadic op patterns.
//===----------------------------------------------------------------------===//
def : Pat<(HLO_ConcatenateOp $inputs, $dim),
(TF_ConcatV2Op $inputs, (TF_ConstOp $dim))>;
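Sketch (hypothetical operands; HLO carries the axis as a dimension attribute, while tf.ConcatV2 takes it as a constant operand):

    %0 = "xla_hlo.concatenate"(%a, %b) {dimension = 0 : i64}
        : (tensor<2xf32>, tensor<3xf32>) -> tensor<5xf32>

becomes

    %axis = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
    %0 = "tf.ConcatV2"(%a, %b, %axis) : (tensor<2xf32>, tensor<3xf32>, tensor<i64>) -> tensor<5xf32>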
//===----------------------------------------------------------------------===//
// Compare op patterns.
//===----------------------------------------------------------------------===//
foreach p = [[TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ],
[TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in
def : Pat<(HLO_CompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue),
[(AreBroadcastCompatible $l, $r)]>;
foreach pair = [[TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE],
[TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT],
[TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE],
[TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in
def : Pat<(HLO_CompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
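Sketch of the equality case (hypothetical values; ConstBoolAttrTrue fills tf.Equal's incompatible_shape_error attribute):

    %0 = "xla_hlo.compare"(%l, %r) {comparison_direction = "EQ"}
        : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>

becomes

    %0 = "tf.Equal"(%l, %r) {incompatible_shape_error = true}
        : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>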