Lower complex operations to float equivalents

Other dialects may not natively support complex numbers. Complex operations can
be computed directly by decomposing them into an equivalent set of floating
point operations. This CL adds decompositions (see the identities sketched
below) for:
- add
- subtract
- multiply
- divide
- absolute
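
For reference, the identities these decompositions implement, writing the lhs as
$a + bi$ and the rhs as $c + di$ (the change also lowers exp via Euler's formula,
as the patterns below show):

  $(a + bi) + (c + di) = (a + c) + (b + d)i$
  $(a + bi) - (c + di) = (a - c) + (b - d)i$
  $(a + bi)(c + di) = (ac - bd) + (ad + bc)i$
  $(a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)$
  $|a + bi| = \sqrt{a^2 + b^2}$
  $e^{a + bi} = e^a (\cos b + i \sin b)$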

PiperOrigin-RevId: 282608078
Change-Id: I7a77954676e3d0bc45208a6ad3bf98ab0a9aa03e
Authored by Robert Suderman on 2019-11-26 11:41:02 -08:00; committed by TensorFlower Gardener
parent 9432026e89
commit 602e65243d
17 changed files with 704 additions and 78 deletions


@@ -78,7 +78,7 @@ cc_library(
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
"//tensorflow/compiler/mlir/xla:xla_lower_general_dot",
"//tensorflow/compiler/mlir/xla:xla_lower",
"@local_config_mlir//:AffineDialectRegistration",
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:QuantOpsDialectRegistration",


@@ -28,6 +28,7 @@ filegroup(
srcs = [
"ir/hlo_ops.td",
"ir/hlo_ops_base.td",
"ir/hlo_utils.td",
"ir/lhlo_ops.td",
"@local_config_mlir//:OpBaseTdFiles",
],
@@ -43,6 +44,7 @@ gentbl(
],
tblgen = "@local_config_mlir//:mlir-tblgen",
td_file = "ir/hlo_ops.td",
td_includes = ["ir/hlo_utils.td"],
td_srcs = [":hlo_ops_td_files"],
)
@@ -232,9 +234,27 @@ cc_library(
alwayslink = 1,
)
gentbl(
name = "xla_lower_complex_inc_gen",
tbl_outs = [
("-gen-rewriters", "transforms/generated_lower_complex.inc"),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
td_file = "transforms/lower_complex_patterns.td",
td_srcs = [
":hlo_ops_td_files",
"@llvm//:support",
"@local_config_mlir//:StdOpsTdFiles",
],
)
cc_library(
name = "xla_lower_general_dot",
srcs = ["transforms/lower_general_dot.cc"],
name = "xla_lower",
srcs = [
"transforms/generated_lower_complex.inc",
"transforms/lower_complex.cc",
"transforms/lower_general_dot.cc",
],
deps = [
":hlo",
":xla_dialect_registration",
@@ -243,6 +263,8 @@ cc_library(
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@local_config_mlir//:Transforms",
],
alwayslink = 1,
)
@@ -253,9 +275,11 @@ cc_library(
"ir/hlo_ops.cc",
"ir/hlo_ops.cc.inc",
"ir/hlo_ops.h.inc",
"ir/hlo_utils.cc",
],
hdrs = [
"ir/hlo_ops.h",
"ir/hlo_utils.h",
"transforms/passes.h",
"transforms/rewriters.h",
],
@@ -468,6 +492,7 @@ genrule(
"@local_config_mlir//:include/mlir/IR/OpBase.td",
":ir/hlo_ops.td",
":ir/hlo_ops_base.td",
":ir/hlo_utils.td",
],
outs = ["operator_writers.inc"],
cmd = ("$(location :operator_writer_gen) " +


@@ -33,7 +33,7 @@ mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
const llvm::ArrayRef<int64> vector, mlir::Builder builder);
template <typename TypeT>
StatusOr<TypeT> ConvertTensorShapeToType(const Shape& shape,
static StatusOr<TypeT> ConvertTensorShapeToType(const Shape& shape,
mlir::Builder builder) {
auto dimensions = shape.dimensions();
llvm::SmallVector<int64_t, 4> array(dimensions.begin(), dimensions.end());
@@ -62,7 +62,7 @@ StatusOr<TypeT> ConvertTensorShapeToType(const Shape& shape,
}
template <typename TypeT>
StatusOr<mlir::Type> ConvertShapeToType(const Shape& shape,
static StatusOr<mlir::Type> ConvertShapeToType(const Shape& shape,
mlir::Builder builder) {
if (shape.IsTuple()) {
mlir::Type mlir_type;
@@ -77,7 +77,6 @@ StatusOr<mlir::Type> ConvertShapeToType(const Shape& shape,
}
return ConvertTensorShapeToType<TypeT>(shape, builder);
}
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_


@@ -241,7 +241,7 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
// If the operand is constant, we can do the conversion now.
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
return xla::ConvertElementsAttr(elementsAttr,
return ::xla::ConvertElementsAttr(elementsAttr,
getElementTypeOrSelf(getResult()));
}
@@ -717,7 +717,7 @@ static Type GetBroadcastType(Builder* builder, Type x, Type y,
DenseIntElementsAttr broadcast_dimensions) {
auto x_ranked = x.dyn_cast<RankedTensorType>();
auto y_ranked = y.dyn_cast<RankedTensorType>();
if (!x || !y) {
if (!x_ranked || !y_ranked) {
return UnrankedTensorType::get(element_type);
}


@@ -20,6 +20,7 @@ limitations under the License.
include "mlir/IR/OpBase.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td"
def HLO_Dialect : Dialect {
let name = "xla_hlo";


@@ -0,0 +1,55 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
#include <numeric>
namespace mlir {
namespace xla {
DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value *x,
Value *y) {
TensorType xType = x->getType().dyn_cast<RankedTensorType>();
TensorType yType = y->getType().dyn_cast<RankedTensorType>();
if (xType == yType || !xType || !yType) return {};
// If the shapes have the same rank, then there is nothing to do.
auto xRank = xType.getRank(), yRank = yType.getRank();
if (xRank == yRank) return {};
// Otherwise if the ranks of the inputs don't match, TensorFlow automatically
// reshapes the smaller by padding with dimensions of size 1 as a prefix. In
// other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to
// have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast
// from lower to higher rank, but doesn't assume you want to pad as a prefix
// of the dimensions, and instead needs to be told which dimensions of the
// higher rank tensor to match to the lower rank tensor.
auto maxRank = std::max(xRank, yRank);
auto minRank = std::min(xRank, yRank);
// Match the lower rank tensor along the larger-numbered dimensions of the
// higher rank tensor.
SmallVector<int64_t, 4> broadcastDimensions(minRank);
std::iota(broadcastDimensions.begin(), broadcastDimensions.end(),
maxRank - minRank);
RankedTensorType type =
RankedTensorType::get({minRank}, b->getIntegerType(64));
return DenseIntElementsAttr::get(type, broadcastDimensions);
}
} // namespace xla
} // namespace mlir
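
As a hypothetical illustration of the attribute computed above (the shapes are
invented for this example): broadcasting a rank-1 operand against a rank-3
operand pads the rank-1 side as a prefix, so its single dimension maps to
dimension 2 of the higher rank operand.

func @broadcast_example(%lhs : tensor<3x4x5xf32>, %rhs : tensor<5xf32>) -> tensor<3x4x5xf32> {
  // getBroadcastDimensionsAttr(&builder, lhs, rhs) yields dense<2> : tensor<1xi64>.
  %0 = "xla_hlo.add"(%lhs, %rhs) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<3x4x5xf32>, tensor<5xf32>) -> tensor<3x4x5xf32>
  return %0 : tensor<3x4x5xf32>
}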


@@ -0,0 +1,54 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
namespace mlir {
namespace xla {
// Computes the broadcast dimensions attr for an elementwise binary operator
// between two ranked tensors.
mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b,
mlir::Value* x,
mlir::Value* y);
/// Get a constant splat for the given value type.
template <typename T>
static ElementsAttr getSplat(Builder* b, Value* val, T constant) {
auto valType = val->getType().cast<TensorType>();
auto valElementType = getElementTypeOrSelf(val->getType());
// Handle integer elements.
Attribute elementAttr;
if (valElementType.isa<IntegerType>())
elementAttr = b->getIntegerAttr(valElementType, constant);
else if (valElementType.isa<FloatType>())
elementAttr = b->getFloatAttr(valElementType, constant);
else
llvm_unreachable("unhandled element type");
return DenseElementsAttr::get(valType, elementAttr);
}
} // namespace xla
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_


@@ -0,0 +1,37 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the utils file for the HLO dialect.
#ifndef HLO_UTILS
#define HLO_UTILS
#ifndef OP_BASE
include "mlir/IR/OpBase.td"
#endif // OP_BASE
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
class ConstantSplat<string value> : NativeCodeCall<
"getSplat(&$_builder, $0, " # value # ")">;
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
def BinBroadcastDimensions : NativeCodeCall<
"getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
#endif // HLO_UTILS


@@ -0,0 +1,312 @@
// RUN: tf-opt %s -test-xla-lower-complex | FileCheck %s
// CHECK-LABEL: @add
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3
%4 = "xla_hlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @add_broadcast
func @add_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.add"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
%4 = "xla_hlo.add"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
}
// CHECK-LABEL: @add_unranked
func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3
%4 = "xla_hlo.add"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}
// CHECK-LABEL: @sub
func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.sub %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.sub %arg1, %arg3
%4 = "xla_hlo.sub"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @sub_broadcast
func @sub_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.sub"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.sub"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
%4 = "xla_hlo.sub"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
}
// CHECK-LABEL: @sub_unranked
func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.sub %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.sub %arg1, %arg3
%4 = "xla_hlo.sub"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}
// CHECK-LABEL: @mul
func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.mul %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.sub [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
%4 = "xla_hlo.mul"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return %2, %5 : tensor<2xf32>, tensor<2xf32>
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @mul_broadcast
func @mul_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.mul"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.mul"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.sub [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.mul"(%arg0, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL4:%.+]] = "xla_hlo.mul"(%arg1, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
%4 = "xla_hlo.mul"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
// CHECK: return %2, %5 : tensor<1x2xf32>, tensor<1x2xf32>
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
}
// CHECK-LABEL: @mul_unranked
func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.mul %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.sub [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
%4 = "xla_hlo.mul"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return %2, %5 : tensor<*xf32>, tensor<*xf32>
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}
// CHECK-LABEL: @div
func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.neg"(%arg3)
// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real + lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.mul %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.sub [[VAL1]], [[VAL2]]
// Compute the real valued denominator as rhs * conj(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.mul %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.sub [[VAL4]], [[VAL5]]
// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = xla_hlo.mul %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = xla_hlo.mul %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = xla_hlo.div [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = xla_hlo.div [[VAL9]], [[VAL6]]
%4 = "xla_hlo.div"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
// -----
// CHECK-LABEL: @div_broadcast
func @div_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.neg"(%arg3)
// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real + lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.mul"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.mul"(%arg1, [[VAL0]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.sub [[VAL1]], [[VAL2]]
// Compute the real valued denominator as rhs * conj(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.mul %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.sub [[VAL4]], [[VAL5]]
// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = "xla_hlo.mul"(%arg1, %arg2)
// CHECK-DAG: [[VAL8:%.+]] = "xla_hlo.mul"(%arg0, [[VAL0]])
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = "xla_hlo.div"([[VAL3]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[VAL11:%.+]] = "xla_hlo.div"([[VAL9]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
%4 = "xla_hlo.div"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
}
// -----
// CHECK-LABEL: @div_unranked
func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.neg"(%arg3)
// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real + lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.mul %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.sub [[VAL1]], [[VAL2]]
// Compute the real valued denominator as rhs * conj(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.mul %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.sub [[VAL4]], [[VAL5]]
// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = xla_hlo.mul %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = xla_hlo.mul %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = xla_hlo.div [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = xla_hlo.div [[VAL9]], [[VAL6]]
%4 = "xla_hlo.div"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}
// CHECK-LABEL: @abs
func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.mul %arg0, %arg0
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg1, %arg1
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.add [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.sqrt"([[VAL2]])
%1 = "xla_hlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL3]]
return %2 : tensor<2xf32>
}
// CHECK-LABEL: @exp
func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exp"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cos"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sin"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul [[VAL0]], [[VAL2]]
%1 = "xla_hlo.exp"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%3 = "xla_hlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL3]], [[VAL4]]
return %2, %3 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @exp_unranked
func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exp"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cos"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sin"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul [[VAL0]], [[VAL2]]
%1 = "xla_hlo.exp"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%2 = "xla_hlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%3 = "xla_hlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL3]], [[VAL4]]
return %2, %3 : tensor<*xf32>, tensor<*xf32>
}


@@ -1,4 +1,4 @@
// RUN: tf-opt -xla-lower-general-dot -split-input-file %s -o - | FileCheck %s
// RUN: tf-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s
// CHECK-LABEL: @testDebatch1
func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {


@@ -18,7 +18,6 @@ limitations under the License.
include "mlir/IR/OpBase.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
//===----------------------------------------------------------------------===//
// DynamicSlice op patterns.


@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
@@ -274,24 +275,6 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
// Binary op utilities.
//===----------------------------------------------------------------------===//
/// Get a constant splat for the given value type.
template <typename T>
static ElementsAttr getSplat(Builder &b, Value *val, T constant) {
auto valType = val->getType().cast<TensorType>();
auto valElementType = getElementTypeOrSelf(val->getType());
// Handle integer elements.
Attribute elementAttr;
if (valElementType.isa<IntegerType>())
elementAttr = b.getIntegerAttr(valElementType, constant);
else if (valElementType.isa<FloatType>())
elementAttr = b.getFloatAttr(valElementType, constant);
else
llvm_unreachable("unhandled element type");
return DenseElementsAttr::get(valType, elementAttr);
}
// Returns whether the two values are guaranteed to be broadcastable to the
// same shape, this broadcasts size 1 tensors up to any rank. Dynamic dimensions
// must be broadcasted with a size 1 tensor or another dynamic dimension.
@@ -322,37 +305,6 @@ static bool AreBroadcastCompatible(Value *x, Value *y) {
return true;
}
static DenseIntElementsAttr getBroadcastDimensionsAttr(Builder &b, Value *x,
Value *y) {
TensorType xType = x->getType().dyn_cast<RankedTensorType>();
TensorType yType = y->getType().dyn_cast<RankedTensorType>();
if (xType == yType || !xType || !yType) return {};
// If the shapes have the same rank, then there is nothing to do.
auto xRank = xType.getRank(), yRank = yType.getRank();
if (xRank == yRank) return {};
// Otherwise if the ranks of the inputs don't match, TensorFlow automatically
// reshapes the smaller by padding with dimensions of size 1 as a prefix. In
// other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to
// have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast
// from lower to higher rank, but doesn't assume you want to pad as a prefix
// of the dimensions, and instead needs to be told which dimensions of the
// higher rank tensor to match to the lower rank tensor.
auto maxRank = std::max(xRank, yRank);
auto minRank = std::min(xRank, yRank);
// Match the lower rank tensor along the larger-numbered dimensions of the
// higher rank tensor.
SmallVector<int64_t, 4> broadcastDimensions(minRank);
std::iota(broadcastDimensions.begin(), broadcastDimensions.end(),
maxRank - minRank);
RankedTensorType type =
RankedTensorType::get({minRank}, b.getIntegerType(64));
return DenseIntElementsAttr::get(type, broadcastDimensions);
}
// Return a new TensorType the same rank and dimensions as the input with an
// updated element type.
static Type ChangeTensorElementType(Builder *b, Type tensor_type,
@@ -1017,10 +969,10 @@ class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {
rewriter.getI64IntegerAttr(0));
auto scaled = rewriter.create<xla_hlo::MulOp>(
op.getLoc(), result_type, iota, op.delta(),
getBroadcastDimensionsAttr(rewriter, iota, op.delta()));
getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(
op, result_type, scaled, op.start(),
getBroadcastDimensionsAttr(rewriter, scaled, op.start()));
getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
return matchSuccess();
}
};


@@ -20,14 +20,6 @@ include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
class ConstantSplat<string value> : NativeCodeCall<
"getSplat($_builder, $0, " # value # ")">;
//===----------------------------------------------------------------------===//
// BatchNorm op patterns.
//===----------------------------------------------------------------------===//
@@ -65,10 +57,6 @@ def : Pat<(TF_BiasAddOp AnyStaticShapeTensor:$input, $bias, $data_format),
// Binary op patterns.
//===----------------------------------------------------------------------===//
// Get the broadcast dimensions attribute from the binary operands.
def BinBroadcastDimensions : NativeCodeCall<
"getBroadcastDimensionsAttr($_builder, $0, $1)">;
// Check that two values can be broadcasted together
def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">,
"types must be broadcastable">;


@@ -0,0 +1,78 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements passes to convert complex operations to equivalent real
// value operations. This does not include removing complex values from function
// argument or return types.
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
using mlir::FunctionPass;
using mlir::OwningRewritePatternList;
using mlir::PassRegistration;
namespace {
class LowerComplex : public FunctionPass<LowerComplex> {
public:
explicit LowerComplex() : FunctionPass<LowerComplex>() {}
/// Performs the lowering of complex operations to equivalent real-valued operations.
void runOnFunction() override;
};
} // end anonymous namespace
namespace mlir {
namespace xla {
namespace {
#include "tensorflow/compiler/mlir/xla/transforms/generated_lower_complex.inc"
} // end anonymous namespace
void PopulateComplexLoweringPatterns(MLIRContext* context,
OwningRewritePatternList* patterns) {
populateWithGenerated(context, patterns);
}
} // end namespace xla
} // end namespace mlir
// Lowers the complex operations that can be represented using other operations.
void LowerComplex::runOnFunction() {
// Add lowering patterns to the list.
OwningRewritePatternList patterns;
mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns);
applyPatternsGreedily(getFunction(), patterns);
}
static PassRegistration<LowerComplex> pass(
"test-xla-lower-complex",
"Lower complex operations into non-complex operations");


@@ -0,0 +1,121 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the legalization pattern that converts complex operations into
// equivalent real value operations.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
//===----------------------------------------------------------------------===//
// Binary op patterns.
//===----------------------------------------------------------------------===//
// Addition and subtraction are elementwise and can be distributed across the real
// and imaginary components.
foreach elementwiseOp = [HLO_AddOp, HLO_SubOp] in
def : Pat<(elementwiseOp HLO_ComplexTensor:$lhs,
HLO_ComplexTensor:$rhs, $broadcast_dimensions),
(HLO_ComplexOp
(elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs),
$broadcast_dimensions),
(elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs),
$broadcast_dimensions))>;
// Complex multiplication results in a cross product multiplication between the
// real and imaginary components such that:
// result.real = lhs.real * rhs.real - lhs.imag * rhs.imag
// result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs,
HLO_ComplexTensor:$rhs, $broadcast_dimensions),
(HLO_ComplexOp
(HLO_SubOp
(HLO_MulOp
(HLO_RealOp:$lhs_real $lhs),
(HLO_RealOp:$rhs_real $rhs),
$broadcast_dimensions),
(HLO_MulOp
(HLO_ImagOp:$lhs_imag $lhs),
(HLO_ImagOp:$rhs_imag $rhs),
$broadcast_dimensions),
(NullDenseIntElementsAttr)),
(HLO_AddOp
(HLO_MulOp $lhs_real, $rhs_imag, $broadcast_dimensions),
(HLO_MulOp $lhs_imag, $rhs_real, $broadcast_dimensions),
(NullDenseIntElementsAttr)))>;
// Multiplication between a complex and a real tensor can be distributed by
// applying the real multiplicand to both the real and imaginary components.
//
// Note that these source patterns are not legal according to the HLO dialect;
// they instead handle intermediates generated by other patterns.
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions),
(HLO_ComplexOp
(HLO_MulOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions),
(HLO_MulOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>;
def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions),
(HLO_ComplexOp
(HLO_MulOp $lhs, (HLO_RealOp $rhs), $broadcast_dimensions),
(HLO_MulOp $lhs, (HLO_ImagOp $rhs), $broadcast_dimensions))>;
// Division is performed by normalizing the denominator by multiplying by the
// conjugate of the rhs.
// numerator = lhs * conj(rhs)
// denominator = rhs * conj(rhs)
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions),
(HLO_DivOp
(HLO_MulOp:$num $lhs,
(HLO_ComplexOp:$conj
(HLO_RealOp $rhs),
(HLO_NegOp (HLO_ImagOp $rhs))),
$broadcast_dimensions),
(HLO_RealOp:$den (HLO_MulOp $rhs, $conj, $broadcast_dimensions)),
(BinBroadcastDimensions $num, $den))>;
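
For reference, the conjugate normalization this pattern encodes, written out for
lhs $a + bi$ and rhs $c + di$:

  $\frac{a + bi}{c + di} = \frac{(a + bi)(c - di)}{(c + di)(c - di)} = \frac{(ac + bd) + (bc - ad)i}{c^2 + d^2}$

Since the denominator $rhs \cdot \overline{rhs}$ is purely real, the divide built
here is a complex-by-real divide, which the next pattern then distributes across
the real and imaginary components.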
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions),
(HLO_ComplexOp
(HLO_DivOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions),
(HLO_DivOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>;
// Absolute value is evaluated as:
// result = sqrt(val.real * val.real + val.imag * val.imag)
def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val),
(HLO_ComplexOp
(HLO_SqrtOp
(HLO_AddOp
(HLO_MulOp (HLO_RealOp:$real $val), $real,
(NullDenseIntElementsAttr)),
(HLO_MulOp (HLO_ImagOp:$imag $val), $imag,
(NullDenseIntElementsAttr)),
(NullDenseIntElementsAttr))),
(HLO_ConstOp (ConstantSplat<"0"> $real)))>;
// The exponential of a complex value can be lowered to an exponential of the
// real component multiplied by a complex sinusoid of the imaginary component,
// following Euler's formula.
//
// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * (Cos(b) + iSin(b))
def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val),
(HLO_MulOp
(HLO_ExpOp (HLO_RealOp $val)),
(HLO_ComplexOp
(HLO_CosOp (HLO_ImagOp:$imag $val)),
(HLO_SinOp $imag)),
(NullDenseIntElementsAttr))>;


@@ -186,5 +186,5 @@ void mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(
}
static PassRegistration<LegalizeGeneralDot> legalize_pass(
"xla-lower-general-dot",
"Lower a general dot to a non-batched dot when possible");
"test-xla-lower-general-dot",
"Tests lowering general dot to a non-batched dot when possible");


@@ -28,6 +28,11 @@ namespace xla_hlo {
void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
MLIRContext *ctx);
// Collection of rewrite patterns for lowering complex operations to equivalent
// float operations.
void PopulateComplexLoweringPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
MLIRContext *ctx);