Add Atan2, ShiftLeft, ShiftRightArithmetic, and ShiftRightLogical HLO ops.

PiperOrigin-RevId: 279160695
Change-Id: I4da2aa656b1088faba356eaa1eeb5f3cb9e6ba65
This commit is contained in:
Prakalp Srivastava 2019-11-07 13:54:32 -08:00 committed by TensorFlower Gardener
parent 0fd80986cb
commit 352c88a315
6 changed files with 136 additions and 5 deletions

View File

@ -438,6 +438,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
// builder API.
NoAttributeCase(kAdd, AddOp);
NoAttributeCase(kAnd, AndOp);
NoAttributeCase(kAtan2, Atan2Op);
NoAttributeCase(kConvert, ConvertOp);
NoAttributeCase(kClamp, ClampOp);
NoAttributeCase(kComplex, ComplexOp);
@ -461,6 +462,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
NoAttributeCase(kReshape, ReshapeOp);
NoAttributeCase(kRsqrt, RsqrtOp);
NoAttributeCase(kSelect, SelectOp);
NoAttributeCase(kShiftLeft, ShiftLeftOp);
NoAttributeCase(kShiftRightArithmetic, ShiftRightArithmeticOp);
NoAttributeCase(kShiftRightLogical, ShiftRightLogicalOp);
NoAttributeCase(kSin, SinOp);
NoAttributeCase(kSubtract, SubOp);
NoAttributeCase(kTanh, TanhOp);

View File

@ -773,16 +773,20 @@ static Type GetBroadcastType(Builder* builder, Type x, Type y,
}
BINARY_BUILDER(AddOp);
BINARY_BUILDER(SubOp);
BINARY_BUILDER(MulOp);
BINARY_BUILDER(AndOp);
BINARY_BUILDER(Atan2Op);
BINARY_BUILDER(DivOp);
BINARY_BUILDER(MaxOp);
BINARY_BUILDER(MinOp);
BINARY_BUILDER(AndOp);
BINARY_BUILDER(MulOp);
BINARY_BUILDER(OrOp);
BINARY_BUILDER(XorOp);
BINARY_BUILDER(RemOp);
BINARY_BUILDER(PowOp);
BINARY_BUILDER(RemOp);
BINARY_BUILDER(ShiftLeftOp);
BINARY_BUILDER(ShiftRightArithmeticOp);
BINARY_BUILDER(ShiftRightLogicalOp);
BINARY_BUILDER(SubOp);
BINARY_BUILDER(XorOp);
#undef BINARY_BUILDER

View File

@ -228,6 +228,9 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
def HLO_AddOp : HLO_BinaryElementwiseOp<"add",
[Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_AddOp;
def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_Atan2Op;
def HLO_DivOp : HLO_BinaryElementwiseOp<"div",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp;
@ -246,6 +249,15 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"pow",
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_RemOp;
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftLeftOp;
def HLO_ShiftRightArithmeticOp : HLO_BinaryElementwiseOp<"shift_right_arithmetic",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightArithmeticOp;
def HLO_ShiftRightLogicalOp : HLO_BinaryElementwiseOp<"shift_right_logical",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightLogicalOp;
def HLO_SubOp : HLO_BinaryElementwiseOp<"sub",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp;

View File

@ -321,6 +321,50 @@ class BASE_HLO_SubOp {
}];
}
class BASE_HLO_ShiftLeftOp {
string summary = "Shift Left operator";
string description = [{
Returns `lhs << rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
class BASE_HLO_ShiftRightArithmeticOp {
string summary = "Shift right arithmetic operator";
string description = [{
Returns arithmetic `lhs >> rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
class BASE_HLO_ShiftRightLogicalOp {
string summary = "Shift right logical operator";
string description = [{
Returns logical `lhs >> rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
class BASE_HLO_Atan2Op {
string summary = "Atan2 operator";
string description = [{
Returns `atan2(lhs/rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
class BASE_HLO_AndOp {
string summary = "Logical and";

View File

@ -0,0 +1,23 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
module {
func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
// CHECK: [[VAL_1:%.*]] = s32[4] parameter(0)
// CHECK: [[VAL_2:%.*]] = s32[4] parameter(1)
// CHECK: [[ATAN2:%.*]] = s32[4] atan2(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%0 = xla_hlo.atan2 %arg0, %arg1 : tensor<4xi32>
// CHECK: [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%1 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
// CHECK: [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%2 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
// CHECK: [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%3 = xla_hlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>
// CHECK-LABEL: ROOT
// CHECK-SAME: [[VAL_7:%.*]] = (s32[4], s32[4], s32[4], s32[4]) tuple(s32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]])
return %0, %1, %2, %3 : tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
}
}

View File

@ -53,6 +53,17 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
ROOT %and.3 = pred[4] and(pred[4] %Arg_0.1, pred[4] %Arg_1.2)
}
// CHECK-LABEL: func @test_atan2
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xi32>, [[VAL_1:%.*]]: tensor<4xi32>) -> tensor<4xi32>
%test_atan2 (Arg_0.1: s32[4], Arg_1.2: s32[4]) -> s32[4] {
%Arg_0.1 = s32[4] parameter(0)
%Arg_1.2 = s32[4] parameter(1)
// CHECK: [[VAL_2:%.*]] = xla_hlo.atan2 [[VAL_0]], [[VAL_1]]
// CHECK: return [[VAL_2]] : tensor<4xi32>
ROOT %atan2 = s32[4] atan2(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
}
// CHECK-LABEL: func @test_broadcast_in_dim
%test_broadcast_in_dim {
%Arg_0.1 = f32[1, 2] parameter(0)
@ -632,3 +643,36 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
// CHECK: return [[VAL_2]] : tensor<4xi1>
ROOT %xor.3 = pred[4] xor(pred[4] %Arg_0.1, pred[4] %Arg_1.2)
}
// CHECK-LABEL: func @test_shiftleft
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xi32>, [[VAL_1:%.*]]: tensor<4xi32>) -> tensor<4xi32>
%test_shiftleft (Arg_0.1: s32[4], Arg_1.2: s32[4]) -> s32[4] {
%Arg_0.1 = s32[4] parameter(0)
%Arg_1.2 = s32[4] parameter(1)
// CHECK: [[VAL_2:%.*]] = xla_hlo.shift_left [[VAL_0]], [[VAL_1]]
// CHECK: return [[VAL_2]] : tensor<4xi32>
ROOT %shiftleft = s32[4] shift-left(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
}
// CHECK-LABEL: func @test_shiftright_arithmetic
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xi32>, [[VAL_1:%.*]]: tensor<4xi32>) -> tensor<4xi32>
%test_shiftright_arithmetic (Arg_0.1: s32[4], Arg_1.2: s32[4]) -> s32[4] {
%Arg_0.1 = s32[4] parameter(0)
%Arg_1.2 = s32[4] parameter(1)
// CHECK: [[VAL_2:%.*]] = xla_hlo.shift_right_arithmetic [[VAL_0]], [[VAL_1]]
// CHECK: return [[VAL_2]] : tensor<4xi32>
ROOT %shiftright.arithmetic = s32[4] shift-right-arithmetic(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
}
// CHECK-LABEL: func @test_shiftright_logical
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xi32>, [[VAL_1:%.*]]: tensor<4xi32>) -> tensor<4xi32>
%test_shiftright_logical (Arg_0.1: s32[4], Arg_1.2: s32[4]) -> s32[4] {
%Arg_0.1 = s32[4] parameter(0)
%Arg_1.2 = s32[4] parameter(1)
// CHECK: [[VAL_2:%.*]] = xla_hlo.shift_right_logical [[VAL_0]], [[VAL_1]]
// CHECK: return [[VAL_2]] : tensor<4xi32>
ROOT %shiftright.logical = s32[4] shift-right-logical(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
}