Add xla_hlo.rsqrt operator and import support
PiperOrigin-RevId: 265808032
This commit is contained in:
parent
ee61057928
commit
bb9ce9acec
@ -368,6 +368,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
// If dimensions are non-default, the XLA builder implementes it as a
|
||||
// separate transpose.
|
||||
NoAttributeCase(kReshape, ReshapeOp);
|
||||
NoAttributeCase(kRsqrt, RsqrtOp);
|
||||
NoAttributeCase(kSelect, SelectOp);
|
||||
NoAttributeCase(kSubtract, SubOp);
|
||||
NoAttributeCase(kTanh, TanhOp);
|
||||
|
@ -125,6 +125,8 @@ def HLO_LogOp: HLO_UnaryElementwiseOp<"log", [NoSideEffect, SameOperandsAndResul
|
||||
|
||||
def HLO_NegOp: HLO_UnaryElementwiseOp<"neg", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_NegOp;
|
||||
|
||||
def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RsqrtOp;
|
||||
|
||||
def HLO_SignOp: HLO_UnaryElementwiseOp<"sign", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_SignOp;
|
||||
|
||||
def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
|
||||
|
@ -116,6 +116,17 @@ class BASE_HLO_NegOp {
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_RsqrtOp {
|
||||
string summary = "Reciprocal Square-root operator";
|
||||
|
||||
string description = [{
|
||||
Returns `1.0 / sqrt(operand)` element-wise.
|
||||
|
||||
See
|
||||
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_SignOp {
|
||||
string summary = "Sign operator";
|
||||
|
||||
|
@ -338,6 +338,14 @@ func @log_invalid_result_type(%arg0: tensor<1xf32>) -> tensor<1xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
func @rsqrt_invalid_result_type(%arg0: tensor<1xf32>) -> tensor<1xf32> {
|
||||
// expected-error@+1 {{'xla_hlo.rsqrt' op requires the same type for all operands and results}}
|
||||
%0 = "xla_hlo.rsqrt"(%arg0) : (tensor<1xf32>) -> tensor<1xi32>
|
||||
return %0: tensor<1xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @reshape_same_shape
|
||||
func @reshape_same_shape(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
|
13
tensorflow/compiler/mlir/xla/tests/translate/rsqrt.hlotxt
Normal file
13
tensorflow/compiler/mlir/xla/tests/translate/rsqrt.hlotxt
Normal file
@ -0,0 +1,13 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule foo
|
||||
|
||||
// CHECK-LABEL: func @main(
|
||||
// CHECK-SAME: [[ARG0:%.+]]: tensor<16xf32>) -> tensor<16xf32> {
|
||||
ENTRY %foo (arg0.1: f32[16]) -> f32[16] {
|
||||
%arg0.1 = f32[16] parameter(0)
|
||||
|
||||
// CHECK-NEXT: [[P0:%.+]] = "xla_hlo.rsqrt"([[ARG0]]) {name = "rsqrt.2"} : (tensor<16xf32>) -> tensor<16xf32>
|
||||
// CHECK-NEXT: return [[P0]] : tensor<16xf32>
|
||||
ROOT %rsqrt.2 = f32[16] rsqrt(f32[16] %arg0.1)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user