From bb9ce9acecfc38d7cd6d14d9896ef21823b48641 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 27 Aug 2019 17:55:10 -0700 Subject: [PATCH] Add xla_hlo.rsqrt operator and import support PiperOrigin-RevId: 265808032 --- .../compiler/mlir/xla/hlo_function_importer.cc | 1 + tensorflow/compiler/mlir/xla/ir/hlo_ops.td | 2 ++ tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td | 11 +++++++++++ tensorflow/compiler/mlir/xla/tests/ops.mlir | 8 ++++++++ .../compiler/mlir/xla/tests/translate/rsqrt.hlotxt | 13 +++++++++++++ 5 files changed, 35 insertions(+) create mode 100644 tensorflow/compiler/mlir/xla/tests/translate/rsqrt.hlotxt diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index babee5e530b..8a69310ced9 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -368,6 +368,7 @@ StatusOr HloFunctionImporter::ImportInstruction( // If dimensions are non-default, the XLA builder implementes it as a // separate transpose. NoAttributeCase(kReshape, ReshapeOp); + NoAttributeCase(kRsqrt, RsqrtOp); NoAttributeCase(kSelect, SelectOp); NoAttributeCase(kSubtract, SubOp); NoAttributeCase(kTanh, TanhOp); diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index ecdede5b9fb..7775377c94b 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -125,6 +125,8 @@ def HLO_LogOp: HLO_UnaryElementwiseOp<"log", [NoSideEffect, SameOperandsAndResul def HLO_NegOp: HLO_UnaryElementwiseOp<"neg", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_NegOp; +def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RsqrtOp; + def HLO_SignOp: HLO_UnaryElementwiseOp<"sign", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_SignOp; def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh", diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td index 9fb3ecc2c68..28d6efd0aad 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -116,6 +116,17 @@ class BASE_HLO_NegOp { }]; } +class BASE_HLO_RsqrtOp { + string summary = "Reciprocal Square-root operator"; + + string description = [{ + Returns `1.0 / sqrt(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + class BASE_HLO_SignOp { string summary = "Sign operator"; diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index 316a402de5f..06c98fb39b0 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -338,6 +338,14 @@ func @log_invalid_result_type(%arg0: tensor<1xf32>) -> tensor<1xf32> { // ----- +func @rsqrt_invalid_result_type(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // expected-error@+1 {{'xla_hlo.rsqrt' op requires the same type for all operands and results}} + %0 = "xla_hlo.rsqrt"(%arg0) : (tensor<1xf32>) -> tensor<1xi32> + return %0: tensor<1xi32> +} + +// ----- + // CHECK-LABEL: func @reshape_same_shape func @reshape_same_shape(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = "xla_hlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/rsqrt.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/rsqrt.hlotxt new file mode 100644 index 00000000000..a7b9b73f239 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/rsqrt.hlotxt @@ -0,0 +1,13 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule foo + +// CHECK-LABEL: func @main( +// CHECK-SAME: [[ARG0:%.+]]: tensor<16xf32>) -> tensor<16xf32> { +ENTRY %foo (arg0.1: f32[16]) -> f32[16] { + %arg0.1 = f32[16] parameter(0) + + // CHECK-NEXT: [[P0:%.+]] = "xla_hlo.rsqrt"([[ARG0]]) {name = "rsqrt.2"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: return [[P0]] : tensor<16xf32> + ROOT %rsqrt.2 = f32[16] rsqrt(f32[16] %arg0.1) +}