From 602e65243dfe2e3b9e98b4e655a6192aac1d5b12 Mon Sep 17 00:00:00 2001 From: Robert Suderman <suderman@google.com> Date: Tue, 26 Nov 2019 11:41:02 -0800 Subject: [PATCH] Lowering for complex operations to float equivalents Other dialects may not natively support complex numbers. Complex operations can be directly computed by decomposing into an equivalent set of floating point operations. This CL includes decompositions including: - add - subtract - multiply - divide - absolute PiperOrigin-RevId: 282608078 Change-Id: I7a77954676e3d0bc45208a6ad3bf98ab0a9aa03e --- tensorflow/compiler/mlir/BUILD | 2 +- tensorflow/compiler/mlir/xla/BUILD | 29 +- tensorflow/compiler/mlir/xla/hlo_utils.h | 9 +- tensorflow/compiler/mlir/xla/ir/hlo_ops.cc | 6 +- tensorflow/compiler/mlir/xla/ir/hlo_ops.td | 1 + tensorflow/compiler/mlir/xla/ir/hlo_utils.cc | 55 +++ tensorflow/compiler/mlir/xla/ir/hlo_utils.h | 54 +++ tensorflow/compiler/mlir/xla/ir/hlo_utils.td | 37 +++ .../mlir/xla/tests/lower-complex.mlir | 312 ++++++++++++++++++ .../mlir/xla/tests/lower-general-dot.mlir | 2 +- .../mlir/xla/transforms/canonicalize.td | 1 - .../mlir/xla/transforms/legalize_tf.cc | 54 +-- .../xla/transforms/legalize_tf_patterns.td | 12 - .../mlir/xla/transforms/lower_complex.cc | 78 +++++ .../xla/transforms/lower_complex_patterns.td | 121 +++++++ .../mlir/xla/transforms/lower_general_dot.cc | 4 +- .../compiler/mlir/xla/transforms/rewriters.h | 5 + 17 files changed, 704 insertions(+), 78 deletions(-) create mode 100644 tensorflow/compiler/mlir/xla/ir/hlo_utils.cc create mode 100644 tensorflow/compiler/mlir/xla/ir/hlo_utils.h create mode 100644 tensorflow/compiler/mlir/xla/ir/hlo_utils.td create mode 100644 tensorflow/compiler/mlir/xla/tests/lower-complex.mlir create mode 100644 tensorflow/compiler/mlir/xla/transforms/lower_complex.cc create mode 100644 tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 59a4f7e7e79..fdd3941e136 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -78,7 +78,7 @@ cc_library( "//tensorflow/compiler/mlir/xla:xla_legalize_control_flow", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard", - "//tensorflow/compiler/mlir/xla:xla_lower_general_dot", + "//tensorflow/compiler/mlir/xla:xla_lower", "@local_config_mlir//:AffineDialectRegistration", "@local_config_mlir//:QuantOps", "@local_config_mlir//:QuantOpsDialectRegistration", diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 56a8a8b87f0..3ed3fb6fc40 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -28,6 +28,7 @@ filegroup( srcs = [ "ir/hlo_ops.td", "ir/hlo_ops_base.td", + "ir/hlo_utils.td", "ir/lhlo_ops.td", "@local_config_mlir//:OpBaseTdFiles", ], @@ -43,6 +44,7 @@ gentbl( ], tblgen = "@local_config_mlir//:mlir-tblgen", td_file = "ir/hlo_ops.td", + td_includes = ["ir/hlo_utils.td"], td_srcs = [":hlo_ops_td_files"], ) @@ -232,9 +234,27 @@ cc_library( alwayslink = 1, ) +gentbl( + name = "xla_lower_complex_inc_gen", + tbl_outs = [ + ("-gen-rewriters", "transforms/generated_lower_complex.inc"), + ], + tblgen = "@local_config_mlir//:mlir-tblgen", + td_file = "transforms/lower_complex_patterns.td", + td_srcs = [ + ":hlo_ops_td_files", + "@llvm//:support", + "@local_config_mlir//:StdOpsTdFiles", + ], +) + cc_library( - name = "xla_lower_general_dot", - srcs = ["transforms/lower_general_dot.cc"], + name = "xla_lower", + srcs = [ + "transforms/generated_lower_complex.inc", + "transforms/lower_complex.cc", + "transforms/lower_general_dot.cc", + ], deps = [ ":hlo", ":xla_dialect_registration", @@ -243,6 +263,8 @@ cc_library( "@local_config_mlir//:IR", "@local_config_mlir//:Pass", "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + "@local_config_mlir//:Transforms", ], alwayslink = 1, ) @@ -253,9 +275,11 @@ cc_library( "ir/hlo_ops.cc", "ir/hlo_ops.cc.inc", "ir/hlo_ops.h.inc", + "ir/hlo_utils.cc", ], hdrs = [ "ir/hlo_ops.h", + "ir/hlo_utils.h", "transforms/passes.h", "transforms/rewriters.h", ], @@ -468,6 +492,7 @@ genrule( "@local_config_mlir//:include/mlir/IR/OpBase.td", ":ir/hlo_ops.td", ":ir/hlo_ops_base.td", + ":ir/hlo_utils.td", ], outs = ["operator_writers.inc"], cmd = ("$(location :operator_writer_gen) " + diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.h b/tensorflow/compiler/mlir/xla/hlo_utils.h index f706e87c2d3..b267b39ce5a 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.h +++ b/tensorflow/compiler/mlir/xla/hlo_utils.h @@ -33,8 +33,8 @@ mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector( const llvm::ArrayRef<int64> vector, mlir::Builder builder); template <typename TypeT> -StatusOr<TypeT> ConvertTensorShapeToType(const Shape& shape, - mlir::Builder builder) { +static StatusOr<TypeT> ConvertTensorShapeToType(const Shape& shape, + mlir::Builder builder) { auto dimensions = shape.dimensions(); llvm::SmallVector<int64_t, 4> array(dimensions.begin(), dimensions.end()); @@ -62,8 +62,8 @@ StatusOr<TypeT> ConvertTensorShapeToType(const Shape& shape, } template <typename TypeT> -StatusOr<mlir::Type> ConvertShapeToType(const Shape& shape, - mlir::Builder builder) { +static StatusOr<mlir::Type> ConvertShapeToType(const Shape& shape, + mlir::Builder builder) { if (shape.IsTuple()) { mlir::Type mlir_type; llvm::SmallVector<mlir::Type, 4> contents; @@ -77,7 +77,6 @@ StatusOr<mlir::Type> ConvertShapeToType(const Shape& shape, } return ConvertTensorShapeToType<TypeT>(shape, builder); } - } // namespace xla #endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_ diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index 1bc3d8cb7a5..b996b12119d 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -241,8 +241,8 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) { // If the operand is constant, we can do the conversion now. if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) { - return xla::ConvertElementsAttr(elementsAttr, - getElementTypeOrSelf(getResult())); + return ::xla::ConvertElementsAttr(elementsAttr, + getElementTypeOrSelf(getResult())); } return {}; @@ -717,7 +717,7 @@ static Type GetBroadcastType(Builder* builder, Type x, Type y, DenseIntElementsAttr broadcast_dimensions) { auto x_ranked = x.dyn_cast<RankedTensorType>(); auto y_ranked = y.dyn_cast<RankedTensorType>(); - if (!x || !y) { + if (!x_ranked || !y_ranked) { return UnrankedTensorType::get(element_type); } diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index e4acd54c9c9..7bd94589e7f 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -20,6 +20,7 @@ limitations under the License. include "mlir/IR/OpBase.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" +include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td" def HLO_Dialect : Dialect { let name = "xla_hlo"; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc new file mode 100644 index 00000000000..82b7032d542 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc @@ -0,0 +1,55 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h" + +#include <numeric> + +namespace mlir { +namespace xla { + +DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value *x, + Value *y) { + TensorType xType = x->getType().dyn_cast<RankedTensorType>(); + TensorType yType = y->getType().dyn_cast<RankedTensorType>(); + if (xType == yType || !xType || !yType) return {}; + + // If the shapes have the same rank, then there is nothing to do. + auto xRank = xType.getRank(), yRank = yType.getRank(); + if (xRank == yRank) return {}; + + // Otherwise if the ranks of the inputs don't match, TensorFlow automatically + // reshapes the smaller by padding with dimensions of size 1 as a prefix. In + // other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to + // have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast + // from lower to higher rank, but doesn't assume you want to pad as a prefix + // of the dimensions, and instead needs to be told which dimensions of the + // higher rank tensor to match to the lower rank tensor. + auto maxRank = std::max(xRank, yRank); + auto minRank = std::min(xRank, yRank); + + // Match the lower rank tensor along the larger-numbered dimensions of the + // higher rank tensor. + SmallVector<int64_t, 4> broadcastDimensions(minRank); + std::iota(broadcastDimensions.begin(), broadcastDimensions.end(), + maxRank - minRank); + + RankedTensorType type = + RankedTensorType::get({minRank}, b->getIntegerType(64)); + return DenseIntElementsAttr::get(type, broadcastDimensions); +} + +} // namespace xla +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h new file mode 100644 index 00000000000..d81abf6a0be --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ + +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/xla/convert_op_folder.h" + +namespace mlir { +namespace xla { + +// Computes the broadcast dimensions attr for an elementwise binary operator +// between two ranked tensors. +mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b, + mlir::Value* x, + mlir::Value* y); + +/// Get a constant splat for the given value type. +template <typename T> +static ElementsAttr getSplat(Builder* b, Value* val, T constant) { + auto valType = val->getType().cast<TensorType>(); + auto valElementType = getElementTypeOrSelf(val->getType()); + + // Handle integer elements. + Attribute elementAttr; + if (valElementType.isa<IntegerType>()) + elementAttr = b->getIntegerAttr(valElementType, constant); + else if (valElementType.isa<FloatType>()) + elementAttr = b->getFloatAttr(valElementType, constant); + else + llvm_unreachable("unhandled element type"); + + return DenseElementsAttr::get(valType, elementAttr); +} +} // namespace xla +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.td b/tensorflow/compiler/mlir/xla/ir/hlo_utils.td new file mode 100644 index 00000000000..bd1a448b80f --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.td @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the utils file for the HLO dialect. + +#ifndef HLO_UTILS +#define HLO_UTILS + +#ifndef OP_BASE +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +def NullArrayAttr : NativeCodeCall<"ArrayAttr()">; + +def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">; + +class ConstantSplat<string value> : NativeCodeCall< + "getSplat(&$_builder, $0, " # value # ")">; + +def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; + +def BinBroadcastDimensions : NativeCodeCall< + "getBroadcastDimensionsAttr(&$_builder, $0, $1)">; + +#endif // HLO_UTILS diff --git a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir new file mode 100644 index 00000000000..0c0ac91beb0 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir @@ -0,0 +1,312 @@ +// RUN: tf-opt %s -test-xla-lower-complex | FileCheck %s + +// CHECK-LABEL: @add +func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) + + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3 + %4 = "xla_hlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>) + %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) + + // CHECK: return [[VAL0]], [[VAL1]] + return %5, %6 : tensor<2xf32>, tensor<2xf32> +} + +// CHECK-LABEL: @add_broadcast +func @add_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) + + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.add"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} + %4 = "xla_hlo.add"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>) + %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) + + // CHECK: return [[VAL0]], [[VAL1]] + return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> +} + +// CHECK-LABEL: @add_unranked +func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>) + + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3 + %4 = "xla_hlo.add"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>) + %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>) + + // CHECK: return [[VAL0]], [[VAL1]] + return %5, %6 : tensor<*xf32>, tensor<*xf32> +} + +// CHECK-LABEL: @sub +func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) + + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.sub %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.sub %arg1, %arg3 + %4 = "xla_hlo.sub"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>) + %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) + + // CHECK: return [[VAL0]], [[VAL1]] + return %5, %6 : tensor<2xf32>, tensor<2xf32> +} + +// CHECK-LABEL: @sub_broadcast +func @sub_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) + + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.sub"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.sub"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} + %4 = "xla_hlo.sub"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>) + %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) + + // CHECK: return [[VAL0]], [[VAL1]] + return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> +} + +// CHECK-LABEL: @sub_unranked +func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>) + + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.sub %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.sub %arg1, %arg3 + %4 = "xla_hlo.sub"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>) + %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>) + + // CHECK: return [[VAL0]], [[VAL1]] + return %5, %6 : tensor<*xf32>, tensor<*xf32> +} + +// CHECK-LABEL: @mul +func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) + + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.mul %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg1, %arg3 + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.sub [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul %arg0, %arg3 + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg1, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]] + %4 = "xla_hlo.mul"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>) + %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) + + // CHECK: return %2, %5 : tensor<2xf32>, tensor<2xf32> + return %5, %6 : tensor<2xf32>, tensor<2xf32> +} + +// CHECK-LABEL: @mul_broadcast +func @mul_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) + + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.mul"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.mul"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.sub [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.mul"(%arg0, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL4:%.+]] = "xla_hlo.mul"(%arg1, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]] + %4 = "xla_hlo.mul"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>) + %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) + + // CHECK: return %2, %5 : tensor<1x2xf32>, tensor<1x2xf32> + return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> +} + +// CHECK-LABEL: @mul_unranked +func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>) + + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.mul %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg1, %arg3 + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.sub [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul %arg0, %arg3 + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg1, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]] + %4 = "xla_hlo.mul"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>) + %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>) + + // CHECK: return %2, %5 : tensor<*xf32>, tensor<*xf32> + return %5, %6 : tensor<*xf32>, tensor<*xf32> +} + +// CHECK-LABEL: @div +func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) + + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.neg"(%arg3) + + // Compute the numerator's real component: + // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg0, %arg2 + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.mul %arg1, [[VAL0]] + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.sub [[VAL1]], [[VAL2]] + + // Compute the real valued denominator as rhs * con(rhs): + // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg2, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.mul %arg3, [[VAL0]] + // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.sub [[VAL4]], [[VAL5]] + + // Compute the numerator's imaginary component: + // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag + // CHECK-DAG: [[VAL7:%.+]] = xla_hlo.mul %arg1, %arg2 + // CHECK-DAG: [[VAL8:%.+]] = xla_hlo.mul %arg0, [[VAL0]] + // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]] + + // Divide the numerator by the real valued denominator. + // CHECK-DAG: [[VAL10:%.+]] = xla_hlo.div [[VAL3]], [[VAL6]] + // CHECK-DAG: [[VAL11:%.+]] = xla_hlo.div [[VAL9]], [[VAL6]] + %4 = "xla_hlo.div"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>) + + %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) + + // CHECK: return [[VAL10]], [[VAL11]] + return %5, %6 : tensor<2xf32>, tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @div_broadcast +func @div_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) + + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.neg"(%arg3) + + // Compute the numerator's real component: + // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag + // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.mul"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.mul"(%arg1, [[VAL0]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.sub [[VAL1]], [[VAL2]] + + // Compute the real valued denominator as rhs * con(rhs): + // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg2, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.mul %arg3, [[VAL0]] + // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.sub [[VAL4]], [[VAL5]] + + // Compute the numerator's imaginary component: + // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag + // CHECK-DAG: [[VAL7:%.+]] = "xla_hlo.mul"(%arg1, %arg2) + // CHECK-DAG: [[VAL8:%.+]] = "xla_hlo.mul"(%arg0, [[VAL0]]) + // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]] + + // Divide the numerator by the real valued denominator. + // CHECK-DAG: [[VAL10:%.+]] = "xla_hlo.div"([[VAL3]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[VAL11:%.+]] = "xla_hlo.div"([[VAL9]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + %4 = "xla_hlo.div"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>) + + %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) + + // CHECK: return [[VAL10]], [[VAL11]] + return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> +} + +// ----- + +// CHECK-LABEL: @div_unranked +func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>) + + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.neg"(%arg3) + + // Compute the numerator's real component: + // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg0, %arg2 + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.mul %arg1, [[VAL0]] + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.sub [[VAL1]], [[VAL2]] + + // Compute the real valued denominator as rhs * con(rhs): + // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg2, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.mul %arg3, [[VAL0]] + // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.sub [[VAL4]], [[VAL5]] + + // Compute the numerator's imaginary component: + // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag + // CHECK-DAG: [[VAL7:%.+]] = xla_hlo.mul %arg1, %arg2 + // CHECK-DAG: [[VAL8:%.+]] = xla_hlo.mul %arg0, [[VAL0]] + // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]] + + // Divide the numerator by the real valued denominator. + // CHECK-DAG: [[VAL10:%.+]] = xla_hlo.div [[VAL3]], [[VAL6]] + // CHECK-DAG: [[VAL11:%.+]] = xla_hlo.div [[VAL9]], [[VAL6]] + %4 = "xla_hlo.div"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>) + + %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>) + + // CHECK: return [[VAL10]], [[VAL11]] + return %5, %6 : tensor<*xf32>, tensor<*xf32> +} + +// CHECK-LABEL: @abs +func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) { + %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) + + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.mul %arg0, %arg0 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg1, %arg1 + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.add [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.sqrt"([[VAL2]]) + %1 = "xla_hlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>) + %2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) + + // CHECK: return [[VAL3]] + return %2 : tensor<2xf32> +} + +// CHECK-LABEL: @exp +func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) + + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exp"(%arg0) + // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cos"(%arg1) + // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sin"(%arg1) + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul [[VAL0]], [[VAL2]] + %1 = "xla_hlo.exp"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>) + %2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) + %3 = "xla_hlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) + + // CHECK: return [[VAL3]], [[VAL4]] + return %2, %3 : tensor<2xf32>, tensor<2xf32> +} + +// CHECK-LABEL: @exp_unranked +func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>) + + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exp"(%arg0) + // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cos"(%arg1) + // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sin"(%arg1) + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul [[VAL0]], [[VAL2]] + %1 = "xla_hlo.exp"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>) + %2 = "xla_hlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>) + %3 = "xla_hlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>) + + // CHECK: return [[VAL3]], [[VAL4]] + return %2, %3 : tensor<*xf32>, tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/lower-general-dot.mlir b/tensorflow/compiler/mlir/xla/tests/lower-general-dot.mlir index bbfd5b5ba47..cde55b05c04 100644 --- a/tensorflow/compiler/mlir/xla/tests/lower-general-dot.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lower-general-dot.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -xla-lower-general-dot -split-input-file %s -o - | FileCheck %s +// RUN: tf-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s // CHECK-LABEL: @testDebatch1 func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> { diff --git a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td index a39e4961d62..bc44117910b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td @@ -18,7 +18,6 @@ limitations under the License. include "mlir/IR/OpBase.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" -def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">; //===----------------------------------------------------------------------===// // DynamicSlice op patterns. diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index b1d9eb18b35..474a68c3401 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/kernels/conv_grad_shape_utils.h" @@ -274,24 +275,6 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D( // Binary op utilities. //===----------------------------------------------------------------------===// -/// Get a constant splat for the given value type. -template <typename T> -static ElementsAttr getSplat(Builder &b, Value *val, T constant) { - auto valType = val->getType().cast<TensorType>(); - auto valElementType = getElementTypeOrSelf(val->getType()); - - // Handle integer elements. - Attribute elementAttr; - if (valElementType.isa<IntegerType>()) - elementAttr = b.getIntegerAttr(valElementType, constant); - else if (valElementType.isa<FloatType>()) - elementAttr = b.getFloatAttr(valElementType, constant); - else - llvm_unreachable("unhandled element type"); - - return DenseElementsAttr::get(valType, elementAttr); -} - // Returns whether the two values are guaranteed to be broadcastable to the // same shape, this broadcasts size 1 tensors up to any rank. Dynamic dimensions // must be broadcasted with a size 1 tensor or another dynamic dimension. @@ -322,37 +305,6 @@ static bool AreBroadcastCompatible(Value *x, Value *y) { return true; } -static DenseIntElementsAttr getBroadcastDimensionsAttr(Builder &b, Value *x, - Value *y) { - TensorType xType = x->getType().dyn_cast<RankedTensorType>(); - TensorType yType = y->getType().dyn_cast<RankedTensorType>(); - if (xType == yType || !xType || !yType) return {}; - - // If the shapes have the same rank, then there is nothing to do. - auto xRank = xType.getRank(), yRank = yType.getRank(); - if (xRank == yRank) return {}; - - // Otherwise if the ranks of the inputs don't match, TensorFlow automatically - // reshapes the smaller by padding with dimensions of size 1 as a prefix. In - // other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to - // have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast - // from lower to higher rank, but doesn't assume you want to pad as a prefix - // of the dimensions, and instead needs to be told which dimensions of the - // higher rank tensor to match to the lower rank tensor. - auto maxRank = std::max(xRank, yRank); - auto minRank = std::min(xRank, yRank); - - // Match the lower rank tensor along the larger-numbered dimensions of the - // higher rank tensor. - SmallVector<int64_t, 4> broadcastDimensions(minRank); - std::iota(broadcastDimensions.begin(), broadcastDimensions.end(), - maxRank - minRank); - - RankedTensorType type = - RankedTensorType::get({minRank}, b.getIntegerType(64)); - return DenseIntElementsAttr::get(type, broadcastDimensions); -} - // Return a new TensorType the same rank and dimensions as the input with an // updated element type. static Type ChangeTensorElementType(Builder *b, Type tensor_type, @@ -1017,10 +969,10 @@ class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> { rewriter.getI64IntegerAttr(0)); auto scaled = rewriter.create<xla_hlo::MulOp>( op.getLoc(), result_type, iota, op.delta(), - getBroadcastDimensionsAttr(rewriter, iota, op.delta())); + getBroadcastDimensionsAttr(&rewriter, iota, op.delta())); rewriter.replaceOpWithNewOp<xla_hlo::AddOp>( op, result_type, scaled, op.start(), - getBroadcastDimensionsAttr(rewriter, scaled, op.start())); + getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); return matchSuccess(); } }; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index ebbc2897c9a..4bf7ee16d0a 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -20,14 +20,6 @@ include "mlir/Dialect/StandardOps/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" -def NullArrayAttr : NativeCodeCall<"ArrayAttr()">; -def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; - -def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">; - -class ConstantSplat<string value> : NativeCodeCall< - "getSplat($_builder, $0, " # value # ")">; - //===----------------------------------------------------------------------===// // BatchNorm op patterns. //===----------------------------------------------------------------------===// @@ -65,10 +57,6 @@ def : Pat<(TF_BiasAddOp AnyStaticShapeTensor:$input, $bias, $data_format), // Binary op patterns. //===----------------------------------------------------------------------===// -// Get the broadcast dimensions attribute from the binary operands. -def BinBroadcastDimensions : NativeCodeCall< - "getBroadcastDimensionsAttr($_builder, $0, $1)">; - // Check that two values can be broadcasted together def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">, "types must be broadcastable">; diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_complex.cc b/tensorflow/compiler/mlir/xla/transforms/lower_complex.cc new file mode 100644 index 00000000000..e09350f4f74 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/lower_complex.cc @@ -0,0 +1,78 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Thsi file implements passes to convert complex operations to equivalent real +// value operations. This does not include removing complex values from function +// argument or return types. + +#include <cstddef> +#include <cstdint> +#include <iterator> +#include <numeric> + +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir +#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" + +using mlir::FunctionPass; +using mlir::OwningRewritePatternList; +using mlir::PassRegistration; + +namespace { +class LowerComplex : public FunctionPass<LowerComplex> { + public: + explicit LowerComplex() : FunctionPass<LowerComplex>() {} + + /// Performs the lowering to XLA dialect. + void runOnFunction() override; +}; +} // end anonymous namespace + +namespace mlir { +namespace xla { +namespace { + +#include "tensorflow/compiler/mlir/xla/transforms/generated_lower_complex.inc" + +} // end anonymous namespace + +void PopulateComplexLoweringPatterns(MLIRContext* context, + OwningRewritePatternList* patterns) { + populateWithGenerated(context, patterns); +} +} // end namespace xla +} // end namespace mlir + +// Lowers the complex operations that can be represented using other operations. +void LowerComplex::runOnFunction() { + // Add lowering patterns to the list. + OwningRewritePatternList patterns; + mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns); + + applyPatternsGreedily(getFunction(), patterns); +} + +static PassRegistration<LowerComplex> pass( + "test-xla-lower-complex", + "Lower complex operations into non-complex operations"); diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td b/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td new file mode 100644 index 00000000000..252a10fc412 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td @@ -0,0 +1,121 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the legalization pattern that converts complex operations into +// equivalent real value operations. + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/StandardOps/Ops.td" +include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" + +//===----------------------------------------------------------------------===// +// Binary op patterns. +//===----------------------------------------------------------------------===// + +// Add and subtraction are elementwise and can be distributed across the real +// and imaginary components. +foreach elementwiseOp = [HLO_AddOp, HLO_SubOp] in + def : Pat<(elementwiseOp HLO_ComplexTensor:$lhs, + HLO_ComplexTensor:$rhs, $broadcast_dimensions), + (HLO_ComplexOp + (elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs), + $broadcast_dimensions), + (elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs), + $broadcast_dimensions))>; + +// Complex multiplication results in a cross product multiplication between the +// real and imaginary components such that: +// result.real = lhs.real * rhs.real - lhs.imag * rhs.imag +// result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag +def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, + HLO_ComplexTensor:$rhs, $broadcast_dimensions), + (HLO_ComplexOp + (HLO_SubOp + (HLO_MulOp + (HLO_RealOp:$lhs_real $lhs), + (HLO_RealOp:$rhs_real $rhs), + $broadcast_dimensions), + (HLO_MulOp + (HLO_ImagOp:$lhs_imag $lhs), + (HLO_ImagOp:$rhs_imag $rhs), + $broadcast_dimensions), + (NullDenseIntElementsAttr)), + (HLO_AddOp + (HLO_MulOp $lhs_real, $rhs_imag, $broadcast_dimensions), + (HLO_MulOp $lhs_imag, $rhs_real, $broadcast_dimensions), + (NullDenseIntElementsAttr)))>; + +// Multiplication between a complex and real tensor can be distributed by +// applying the real multiplicant to both the real and complex component. +// +// Note that the sourcep pattern is not legal according to the HLO dialect but +// instead handle intermediates generated by other patterns. +def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions), + (HLO_ComplexOp + (HLO_MulOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions), + (HLO_MulOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>; + +def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions), + (HLO_ComplexOp + (HLO_MulOp $lhs, (HLO_RealOp $rhs), $broadcast_dimensions), + (HLO_MulOp $lhs, (HLO_ImagOp $rhs), $broadcast_dimensions))>; + + +// Division is performed by normalizing the denominator by multiplying by the +// conjugate of the rhs. +// numerator = lhs * conj(rhs) +// denominator = rhs * conj(rhs) +def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions), + (HLO_DivOp + (HLO_MulOp:$num $lhs, + (HLO_ComplexOp:$conj + (HLO_RealOp $rhs), + (HLO_NegOp (HLO_ImagOp $rhs))), + $broadcast_dimensions), + (HLO_RealOp:$den (HLO_MulOp $rhs, $conj, $broadcast_dimensions)), + (BinBroadcastDimensions $num, $den))>; + + +def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions), + (HLO_ComplexOp + (HLO_DivOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions), + (HLO_DivOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>; + + +// Absolute value is evaluated as: +// result = sqrt(val.real * val.real + val.imag * val.imag) +def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val), + (HLO_ComplexOp + (HLO_SqrtOp + (HLO_AddOp + (HLO_MulOp (HLO_RealOp:$real $val), $real, + (NullDenseIntElementsAttr)), + (HLO_MulOp (HLO_ImagOp:$imag $val), $imag, + (NullDenseIntElementsAttr)), + (NullDenseIntElementsAttr))), + (HLO_ConstOp (ConstantSplat<"0"> $real)))>; + +// Expononetial can be lowered to an exponential on the real component and a +// sum of sinusoids of the imageinary component, which equates to a normal +// exponential operator multiplied by Euler's formula. +// +// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * (Cos(b) + iSin(b)) +def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val), + (HLO_MulOp + (HLO_ExpOp (HLO_RealOp $val)), + (HLO_ComplexOp + (HLO_CosOp (HLO_ImagOp:$imag $val)), + (HLO_SinOp $imag)), + (NullDenseIntElementsAttr))>; diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc index 070d00cb718..515f818749e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc @@ -186,5 +186,5 @@ void mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns( } static PassRegistration<LegalizeGeneralDot> legalize_pass( - "xla-lower-general-dot", - "Lower a general dot to a non-batched dot when possible"); + "test-xla-lower-general-dot", + "Tests lowering general dot to a non-batched dot when possible"); diff --git a/tensorflow/compiler/mlir/xla/transforms/rewriters.h b/tensorflow/compiler/mlir/xla/transforms/rewriters.h index 46acf4a304e..e4a014f137f 100644 --- a/tensorflow/compiler/mlir/xla/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/xla/transforms/rewriters.h @@ -28,6 +28,11 @@ namespace xla_hlo { void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns, MLIRContext *ctx); +// Collection of rewrite patterns for lowering complex operations to equivalent +// float operations. +void PopulateComplexLoweringPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, MLIRContext *ctx);