Add xla_hlo.rsqrt operator and import support

PiperOrigin-RevId: 265808032
This commit is contained in:
A. Unique TensorFlower 2019-08-27 17:55:10 -07:00 committed by TensorFlower Gardener
parent ee61057928
commit bb9ce9acec
5 changed files with 35 additions and 0 deletions

View File

@ -368,6 +368,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
// If dimensions are non-default, the XLA builder implementes it as a
// separate transpose.
NoAttributeCase(kReshape, ReshapeOp);
NoAttributeCase(kRsqrt, RsqrtOp);
NoAttributeCase(kSelect, SelectOp);
NoAttributeCase(kSubtract, SubOp);
NoAttributeCase(kTanh, TanhOp);

View File

@ -125,6 +125,8 @@ def HLO_LogOp: HLO_UnaryElementwiseOp<"log", [NoSideEffect, SameOperandsAndResul
def HLO_NegOp: HLO_UnaryElementwiseOp<"neg", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_NegOp;
def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RsqrtOp;
def HLO_SignOp: HLO_UnaryElementwiseOp<"sign", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_SignOp;
def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",

View File

@ -116,6 +116,17 @@ class BASE_HLO_NegOp {
}];
}
class BASE_HLO_RsqrtOp {
string summary = "Reciprocal Square-root operator";
string description = [{
Returns `1.0 / sqrt(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
class BASE_HLO_SignOp {
string summary = "Sign operator";

View File

@ -338,6 +338,14 @@ func @log_invalid_result_type(%arg0: tensor<1xf32>) -> tensor<1xf32> {
// -----
func @rsqrt_invalid_result_type(%arg0: tensor<1xf32>) -> tensor<1xf32> {
// expected-error@+1 {{'xla_hlo.rsqrt' op requires the same type for all operands and results}}
%0 = "xla_hlo.rsqrt"(%arg0) : (tensor<1xf32>) -> tensor<1xi32>
return %0: tensor<1xi32>
}
// -----
// CHECK-LABEL: func @reshape_same_shape
func @reshape_same_shape(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>

View File

@ -0,0 +1,13 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule foo
// CHECK-LABEL: func @main(
// CHECK-SAME: [[ARG0:%.+]]: tensor<16xf32>) -> tensor<16xf32> {
ENTRY %foo (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK-NEXT: [[P0:%.+]] = "xla_hlo.rsqrt"([[ARG0]]) {name = "rsqrt.2"} : (tensor<16xf32>) -> tensor<16xf32>
// CHECK-NEXT: return [[P0]] : tensor<16xf32>
ROOT %rsqrt.2 = f32[16] rsqrt(f32[16] %arg0.1)
}