From 4e9394c47845b23981763f77a930857eea65dbb3 Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava <prakalps@google.com>
Date: Wed, 6 Nov 2019 18:18:23 -0800
Subject: [PATCH] Add xla_hlo xor op.

PiperOrigin-RevId: 278982559
Change-Id: I2c5d67da57fb2d016c126ee4cd9a504cf72ae9ec
---
 tensorflow/compiler/mlir/xla/hlo_function_importer.cc |  1 +
 tensorflow/compiler/mlir/xla/ir/hlo_ops.td            |  2 +-
 tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td       | 11 +++++++++++
 .../compiler/mlir/xla/tests/translate/ops.hlotxt      | 11 +++++++++++
 tensorflow/compiler/mlir/xla/tests/translate/xor.mlir | 11 +++++++++++
 5 files changed, 35 insertions(+), 1 deletion(-)
 create mode 100644 tensorflow/compiler/mlir/xla/tests/translate/xor.mlir

diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index a72e244e58e..29314d4e6a7 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -465,6 +465,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
       NoAttributeCase(kSubtract, SubOp);
       NoAttributeCase(kTanh, TanhOp);
       NoAttributeCase(kTuple, TupleOp);
+      NoAttributeCase(kXor, XorOp);
       // TODO(b/129422361) Copy needs special handling because it is not defined
       // in tensorflow/compiler/xla/client/xla_builder.h.
       // See operation semantics in
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 6fc09e52725..e0830da5e81 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -249,7 +249,6 @@ def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
 def HLO_SubOp : HLO_BinaryElementwiseOp<"sub",
       [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp;
 
-
 //===----------------------------------------------------------------------===//
 // XLA binary elementwise op definitions.
 //===----------------------------------------------------------------------===//
@@ -266,6 +265,7 @@ class HLO_BinaryLogicalElementwiseOp<string mnemonic> :
 
 def HLO_AndOp: HLO_BinaryLogicalElementwiseOp<"and">, BASE_HLO_AndOp;
 def HLO_OrOp: HLO_BinaryLogicalElementwiseOp<"or">, BASE_HLO_OrOp;
+def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp;
 
 //===----------------------------------------------------------------------===//
 // XLA control flow op definitions.
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
index 8e03cc83516..e34d024b1cf 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
@@ -343,6 +343,17 @@ class BASE_HLO_OrOp {
   }];
 }
 
+class BASE_HLO_XorOp {
+  string summary = "Logical xor";
+
+  string description = [{
+    Returns `lhs xor rhs` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
 class BASE_HLO_ReduceOp {
   string summary = "Reduce operator";
 
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/ops.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/ops.hlotxt
index a0881a5bbc1..08a878146fd 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/ops.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/ops.hlotxt
@@ -621,3 +621,14 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
   // CHECK-NEXT: return %0 : tensor<i64>
   ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond
 }
+
+// CHECK-LABEL:   func @test_xor
+// CHECK-SAME:    ([[VAL_0:%.*]]: tensor<4xi1>, [[VAL_1:%.*]]: tensor<4xi1>) -> tensor<4xi1>
+%test_xor (Arg_0.1: pred[4], Arg_1.2: pred[4]) -> pred[4] {
+  %Arg_0.1 = pred[4] parameter(0)
+  %Arg_1.2 = pred[4] parameter(1)
+
+  // CHECK:  [[VAL_2:%.*]] = xla_hlo.xor [[VAL_0]], [[VAL_1]]
+  // CHECK:  return [[VAL_2]] : tensor<4xi1>
+  ROOT %xor.3 = pred[4] xor(pred[4] %Arg_0.1, pred[4] %Arg_1.2)
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/xor.mlir b/tensorflow/compiler/mlir/xla/tests/translate/xor.mlir
new file mode 100644
index 00000000000..3ad79d633c7
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/translate/xor.mlir
@@ -0,0 +1,11 @@
+// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
+
+module {
+  func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
+    // CHECK:    [[VAL_1:%.*]] = pred[4] parameter(0)
+    // CHECK:    [[VAL_2:%.*]] = pred[4] parameter(1)
+    %0 = xla_hlo.xor %arg0, %arg1 : tensor<4xi1>
+    // CHECK:   ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]])
+    return %0 : tensor<4xi1>
+  }
+}