diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index a72e244e58e..29314d4e6a7 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -465,6 +465,7 @@ StatusOr HloFunctionImporter::ImportInstruction( NoAttributeCase(kSubtract, SubOp); NoAttributeCase(kTanh, TanhOp); NoAttributeCase(kTuple, TupleOp); + NoAttributeCase(kXor, XorOp); // TODO(b/129422361) Copy needs special handling because it is not defined // in tensorflow/compiler/xla/client/xla_builder.h. // See operation semantics in diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index 6fc09e52725..e0830da5e81 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -249,7 +249,6 @@ def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", def HLO_SubOp : HLO_BinaryElementwiseOp<"sub", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp; - //===----------------------------------------------------------------------===// // XLA binary elementwise op definitions. //===----------------------------------------------------------------------===// @@ -266,6 +265,7 @@ class HLO_BinaryLogicalElementwiseOp : def HLO_AndOp: HLO_BinaryLogicalElementwiseOp<"and">, BASE_HLO_AndOp; def HLO_OrOp: HLO_BinaryLogicalElementwiseOp<"or">, BASE_HLO_OrOp; +def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp; //===----------------------------------------------------------------------===// // XLA control flow op definitions. diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td index 8e03cc83516..e34d024b1cf 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -343,6 +343,17 @@ class BASE_HLO_OrOp { }]; } +class BASE_HLO_XorOp { + string summary = "Logical xor"; + + string description = [{ + Returns `lhs xor rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + class BASE_HLO_ReduceOp { string summary = "Reduce operator"; diff --git a/tensorflow/compiler/mlir/xla/tests/translate/ops.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/ops.hlotxt index a0881a5bbc1..08a878146fd 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/ops.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/ops.hlotxt @@ -621,3 +621,14 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { // CHECK-NEXT: return %0 : tensor ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond } + +// CHECK-LABEL: func @test_xor +// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xi1>, [[VAL_1:%.*]]: tensor<4xi1>) -> tensor<4xi1> +%test_xor (Arg_0.1: pred[4], Arg_1.2: pred[4]) -> pred[4] { + %Arg_0.1 = pred[4] parameter(0) + %Arg_1.2 = pred[4] parameter(1) + + // CHECK: [[VAL_2:%.*]] = xla_hlo.xor [[VAL_0]], [[VAL_1]] + // CHECK: return [[VAL_2]] : tensor<4xi1> + ROOT %xor.3 = pred[4] xor(pred[4] %Arg_0.1, pred[4] %Arg_1.2) +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/xor.mlir b/tensorflow/compiler/mlir/xla/tests/translate/xor.mlir new file mode 100644 index 00000000000..3ad79d633c7 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/xor.mlir @@ -0,0 +1,11 @@ +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s + +module { + func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { + // CHECK: [[VAL_1:%.*]] = pred[4] parameter(0) + // CHECK: [[VAL_2:%.*]] = pred[4] parameter(1) + %0 = xla_hlo.xor %arg0, %arg1 : tensor<4xi1> + // CHECK: ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]]) + return %0 : tensor<4xi1> + } +}