Add Atan2, ShiftLeft, ShiftRightArithmetic, and ShiftRightLogical HLO ops.
PiperOrigin-RevId: 279160695 Change-Id: I4da2aa656b1088faba356eaa1eeb5f3cb9e6ba65
This commit is contained in:
parent
0fd80986cb
commit
352c88a315
@ -438,6 +438,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
// builder API.
|
||||
NoAttributeCase(kAdd, AddOp);
|
||||
NoAttributeCase(kAnd, AndOp);
|
||||
NoAttributeCase(kAtan2, Atan2Op);
|
||||
NoAttributeCase(kConvert, ConvertOp);
|
||||
NoAttributeCase(kClamp, ClampOp);
|
||||
NoAttributeCase(kComplex, ComplexOp);
|
||||
@ -461,6 +462,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
NoAttributeCase(kReshape, ReshapeOp);
|
||||
NoAttributeCase(kRsqrt, RsqrtOp);
|
||||
NoAttributeCase(kSelect, SelectOp);
|
||||
NoAttributeCase(kShiftLeft, ShiftLeftOp);
|
||||
NoAttributeCase(kShiftRightArithmetic, ShiftRightArithmeticOp);
|
||||
NoAttributeCase(kShiftRightLogical, ShiftRightLogicalOp);
|
||||
NoAttributeCase(kSin, SinOp);
|
||||
NoAttributeCase(kSubtract, SubOp);
|
||||
NoAttributeCase(kTanh, TanhOp);
|
||||
|
@ -773,16 +773,20 @@ static Type GetBroadcastType(Builder* builder, Type x, Type y,
|
||||
}
|
||||
|
||||
BINARY_BUILDER(AddOp);
|
||||
BINARY_BUILDER(SubOp);
|
||||
BINARY_BUILDER(MulOp);
|
||||
BINARY_BUILDER(AndOp);
|
||||
BINARY_BUILDER(Atan2Op);
|
||||
BINARY_BUILDER(DivOp);
|
||||
BINARY_BUILDER(MaxOp);
|
||||
BINARY_BUILDER(MinOp);
|
||||
BINARY_BUILDER(AndOp);
|
||||
BINARY_BUILDER(MulOp);
|
||||
BINARY_BUILDER(OrOp);
|
||||
BINARY_BUILDER(XorOp);
|
||||
BINARY_BUILDER(RemOp);
|
||||
BINARY_BUILDER(PowOp);
|
||||
BINARY_BUILDER(RemOp);
|
||||
BINARY_BUILDER(ShiftLeftOp);
|
||||
BINARY_BUILDER(ShiftRightArithmeticOp);
|
||||
BINARY_BUILDER(ShiftRightLogicalOp);
|
||||
BINARY_BUILDER(SubOp);
|
||||
BINARY_BUILDER(XorOp);
|
||||
|
||||
#undef BINARY_BUILDER
|
||||
|
||||
|
@ -228,6 +228,9 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
|
||||
def HLO_AddOp : HLO_BinaryElementwiseOp<"add",
|
||||
[Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_AddOp;
|
||||
|
||||
def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
|
||||
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_Atan2Op;
|
||||
|
||||
def HLO_DivOp : HLO_BinaryElementwiseOp<"div",
|
||||
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp;
|
||||
|
||||
@ -246,6 +249,15 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"pow",
|
||||
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
|
||||
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_RemOp;
|
||||
|
||||
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
|
||||
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftLeftOp;
|
||||
|
||||
def HLO_ShiftRightArithmeticOp : HLO_BinaryElementwiseOp<"shift_right_arithmetic",
|
||||
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightArithmeticOp;
|
||||
|
||||
def HLO_ShiftRightLogicalOp : HLO_BinaryElementwiseOp<"shift_right_logical",
|
||||
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightLogicalOp;
|
||||
|
||||
def HLO_SubOp : HLO_BinaryElementwiseOp<"sub",
|
||||
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp;
|
||||
|
||||
|
@ -321,6 +321,50 @@ class BASE_HLO_SubOp {
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_ShiftLeftOp {
|
||||
string summary = "Shift Left operator";
|
||||
|
||||
string description = [{
|
||||
Returns `lhs << rhs` element-wise.
|
||||
|
||||
See
|
||||
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_ShiftRightArithmeticOp {
|
||||
string summary = "Shift right arithmetic operator";
|
||||
|
||||
string description = [{
|
||||
Returns arithmetic `lhs >> rhs` element-wise.
|
||||
|
||||
See
|
||||
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_ShiftRightLogicalOp {
|
||||
string summary = "Shift right logical operator";
|
||||
|
||||
string description = [{
|
||||
Returns logical `lhs >> rhs` element-wise.
|
||||
|
||||
See
|
||||
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_Atan2Op {
|
||||
string summary = "Atan2 operator";
|
||||
|
||||
string description = [{
|
||||
Returns `atan2(lhs/rhs)` element-wise.
|
||||
|
||||
See
|
||||
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_AndOp {
|
||||
string summary = "Logical and";
|
||||
|
||||
|
@ -0,0 +1,23 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
module {
|
||||
func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
||||
// CHECK: [[VAL_1:%.*]] = s32[4] parameter(0)
|
||||
// CHECK: [[VAL_2:%.*]] = s32[4] parameter(1)
|
||||
// CHECK: [[ATAN2:%.*]] = s32[4] atan2(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||
%0 = xla_hlo.atan2 %arg0, %arg1 : tensor<4xi32>
|
||||
|
||||
// CHECK: [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||
%1 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
|
||||
|
||||
// CHECK: [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||
%2 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
|
||||
|
||||
// CHECK: [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
|
||||
%3 = xla_hlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>
|
||||
|
||||
// CHECK-LABEL: ROOT
|
||||
// CHECK-SAME: [[VAL_7:%.*]] = (s32[4], s32[4], s32[4], s32[4]) tuple(s32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]])
|
||||
return %0, %1, %2, %3 : tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
|
||||
}
|
||||
}
|
@ -53,6 +53,17 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
|
||||
ROOT %and.3 = pred[4] and(pred[4] %Arg_0.1, pred[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_atan2
|
||||
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xi32>, [[VAL_1:%.*]]: tensor<4xi32>) -> tensor<4xi32>
|
||||
%test_atan2 (Arg_0.1: s32[4], Arg_1.2: s32[4]) -> s32[4] {
|
||||
%Arg_0.1 = s32[4] parameter(0)
|
||||
%Arg_1.2 = s32[4] parameter(1)
|
||||
|
||||
// CHECK: [[VAL_2:%.*]] = xla_hlo.atan2 [[VAL_0]], [[VAL_1]]
|
||||
// CHECK: return [[VAL_2]] : tensor<4xi32>
|
||||
ROOT %atan2 = s32[4] atan2(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_broadcast_in_dim
|
||||
%test_broadcast_in_dim {
|
||||
%Arg_0.1 = f32[1, 2] parameter(0)
|
||||
@ -632,3 +643,36 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
|
||||
// CHECK: return [[VAL_2]] : tensor<4xi1>
|
||||
ROOT %xor.3 = pred[4] xor(pred[4] %Arg_0.1, pred[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_shiftleft
|
||||
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xi32>, [[VAL_1:%.*]]: tensor<4xi32>) -> tensor<4xi32>
|
||||
%test_shiftleft (Arg_0.1: s32[4], Arg_1.2: s32[4]) -> s32[4] {
|
||||
%Arg_0.1 = s32[4] parameter(0)
|
||||
%Arg_1.2 = s32[4] parameter(1)
|
||||
|
||||
// CHECK: [[VAL_2:%.*]] = xla_hlo.shift_left [[VAL_0]], [[VAL_1]]
|
||||
// CHECK: return [[VAL_2]] : tensor<4xi32>
|
||||
ROOT %shiftleft = s32[4] shift-left(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_shiftright_arithmetic
|
||||
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xi32>, [[VAL_1:%.*]]: tensor<4xi32>) -> tensor<4xi32>
|
||||
%test_shiftright_arithmetic (Arg_0.1: s32[4], Arg_1.2: s32[4]) -> s32[4] {
|
||||
%Arg_0.1 = s32[4] parameter(0)
|
||||
%Arg_1.2 = s32[4] parameter(1)
|
||||
|
||||
// CHECK: [[VAL_2:%.*]] = xla_hlo.shift_right_arithmetic [[VAL_0]], [[VAL_1]]
|
||||
// CHECK: return [[VAL_2]] : tensor<4xi32>
|
||||
ROOT %shiftright.arithmetic = s32[4] shift-right-arithmetic(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_shiftright_logical
|
||||
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xi32>, [[VAL_1:%.*]]: tensor<4xi32>) -> tensor<4xi32>
|
||||
%test_shiftright_logical (Arg_0.1: s32[4], Arg_1.2: s32[4]) -> s32[4] {
|
||||
%Arg_0.1 = s32[4] parameter(0)
|
||||
%Arg_1.2 = s32[4] parameter(1)
|
||||
|
||||
// CHECK: [[VAL_2:%.*]] = xla_hlo.shift_right_logical [[VAL_0]], [[VAL_1]]
|
||||
// CHECK: return [[VAL_2]] : tensor<4xi32>
|
||||
ROOT %shiftright.logical = s32[4] shift-right-logical(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user