Lowering for complex operations to float equivalents
Other dialects may not natively support complex numbers. Complex operations can be directly computed by decomposing into an equivalent set of floating point operations. This CL includes decompositions including: - add - subtract - multiply - divide - absolute PiperOrigin-RevId: 282608078 Change-Id: I7a77954676e3d0bc45208a6ad3bf98ab0a9aa03e
This commit is contained in:
parent
9432026e89
commit
602e65243d
@ -78,7 +78,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
|
||||
"//tensorflow/compiler/mlir/xla:xla_lower_general_dot",
|
||||
"//tensorflow/compiler/mlir/xla:xla_lower",
|
||||
"@local_config_mlir//:AffineDialectRegistration",
|
||||
"@local_config_mlir//:QuantOps",
|
||||
"@local_config_mlir//:QuantOpsDialectRegistration",
|
||||
|
@ -28,6 +28,7 @@ filegroup(
|
||||
srcs = [
|
||||
"ir/hlo_ops.td",
|
||||
"ir/hlo_ops_base.td",
|
||||
"ir/hlo_utils.td",
|
||||
"ir/lhlo_ops.td",
|
||||
"@local_config_mlir//:OpBaseTdFiles",
|
||||
],
|
||||
@ -43,6 +44,7 @@ gentbl(
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
td_file = "ir/hlo_ops.td",
|
||||
td_includes = ["ir/hlo_utils.td"],
|
||||
td_srcs = [":hlo_ops_td_files"],
|
||||
)
|
||||
|
||||
@ -232,9 +234,27 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "xla_lower_complex_inc_gen",
|
||||
tbl_outs = [
|
||||
("-gen-rewriters", "transforms/generated_lower_complex.inc"),
|
||||
],
|
||||
tblgen = "@local_config_mlir//:mlir-tblgen",
|
||||
td_file = "transforms/lower_complex_patterns.td",
|
||||
td_srcs = [
|
||||
":hlo_ops_td_files",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:StdOpsTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_lower_general_dot",
|
||||
srcs = ["transforms/lower_general_dot.cc"],
|
||||
name = "xla_lower",
|
||||
srcs = [
|
||||
"transforms/generated_lower_complex.inc",
|
||||
"transforms/lower_complex.cc",
|
||||
"transforms/lower_general_dot.cc",
|
||||
],
|
||||
deps = [
|
||||
":hlo",
|
||||
":xla_dialect_registration",
|
||||
@ -243,6 +263,8 @@ cc_library(
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@local_config_mlir//:Transforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -253,9 +275,11 @@ cc_library(
|
||||
"ir/hlo_ops.cc",
|
||||
"ir/hlo_ops.cc.inc",
|
||||
"ir/hlo_ops.h.inc",
|
||||
"ir/hlo_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"ir/hlo_ops.h",
|
||||
"ir/hlo_utils.h",
|
||||
"transforms/passes.h",
|
||||
"transforms/rewriters.h",
|
||||
],
|
||||
@ -468,6 +492,7 @@ genrule(
|
||||
"@local_config_mlir//:include/mlir/IR/OpBase.td",
|
||||
":ir/hlo_ops.td",
|
||||
":ir/hlo_ops_base.td",
|
||||
":ir/hlo_utils.td",
|
||||
],
|
||||
outs = ["operator_writers.inc"],
|
||||
cmd = ("$(location :operator_writer_gen) " +
|
||||
|
@ -33,7 +33,7 @@ mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
|
||||
const llvm::ArrayRef<int64> vector, mlir::Builder builder);
|
||||
|
||||
template <typename TypeT>
|
||||
StatusOr<TypeT> ConvertTensorShapeToType(const Shape& shape,
|
||||
static StatusOr<TypeT> ConvertTensorShapeToType(const Shape& shape,
|
||||
mlir::Builder builder) {
|
||||
auto dimensions = shape.dimensions();
|
||||
llvm::SmallVector<int64_t, 4> array(dimensions.begin(), dimensions.end());
|
||||
@ -62,7 +62,7 @@ StatusOr<TypeT> ConvertTensorShapeToType(const Shape& shape,
|
||||
}
|
||||
|
||||
template <typename TypeT>
|
||||
StatusOr<mlir::Type> ConvertShapeToType(const Shape& shape,
|
||||
static StatusOr<mlir::Type> ConvertShapeToType(const Shape& shape,
|
||||
mlir::Builder builder) {
|
||||
if (shape.IsTuple()) {
|
||||
mlir::Type mlir_type;
|
||||
@ -77,7 +77,6 @@ StatusOr<mlir::Type> ConvertShapeToType(const Shape& shape,
|
||||
}
|
||||
return ConvertTensorShapeToType<TypeT>(shape, builder);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_
|
||||
|
@ -241,7 +241,7 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
||||
// If the operand is constant, we can do the conversion now.
|
||||
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
|
||||
return xla::ConvertElementsAttr(elementsAttr,
|
||||
return ::xla::ConvertElementsAttr(elementsAttr,
|
||||
getElementTypeOrSelf(getResult()));
|
||||
}
|
||||
|
||||
@ -717,7 +717,7 @@ static Type GetBroadcastType(Builder* builder, Type x, Type y,
|
||||
DenseIntElementsAttr broadcast_dimensions) {
|
||||
auto x_ranked = x.dyn_cast<RankedTensorType>();
|
||||
auto y_ranked = y.dyn_cast<RankedTensorType>();
|
||||
if (!x || !y) {
|
||||
if (!x_ranked || !y_ranked) {
|
||||
return UnrankedTensorType::get(element_type);
|
||||
}
|
||||
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td"
|
||||
|
||||
def HLO_Dialect : Dialect {
|
||||
let name = "xla_hlo";
|
||||
|
55
tensorflow/compiler/mlir/xla/ir/hlo_utils.cc
Normal file
55
tensorflow/compiler/mlir/xla/ir/hlo_utils.cc
Normal file
@ -0,0 +1,55 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
namespace mlir {
|
||||
namespace xla {
|
||||
|
||||
DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value *x,
|
||||
Value *y) {
|
||||
TensorType xType = x->getType().dyn_cast<RankedTensorType>();
|
||||
TensorType yType = y->getType().dyn_cast<RankedTensorType>();
|
||||
if (xType == yType || !xType || !yType) return {};
|
||||
|
||||
// If the shapes have the same rank, then there is nothing to do.
|
||||
auto xRank = xType.getRank(), yRank = yType.getRank();
|
||||
if (xRank == yRank) return {};
|
||||
|
||||
// Otherwise if the ranks of the inputs don't match, TensorFlow automatically
|
||||
// reshapes the smaller by padding with dimensions of size 1 as a prefix. In
|
||||
// other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to
|
||||
// have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast
|
||||
// from lower to higher rank, but doesn't assume you want to pad as a prefix
|
||||
// of the dimensions, and instead needs to be told which dimensions of the
|
||||
// higher rank tensor to match to the lower rank tensor.
|
||||
auto maxRank = std::max(xRank, yRank);
|
||||
auto minRank = std::min(xRank, yRank);
|
||||
|
||||
// Match the lower rank tensor along the larger-numbered dimensions of the
|
||||
// higher rank tensor.
|
||||
SmallVector<int64_t, 4> broadcastDimensions(minRank);
|
||||
std::iota(broadcastDimensions.begin(), broadcastDimensions.end(),
|
||||
maxRank - minRank);
|
||||
|
||||
RankedTensorType type =
|
||||
RankedTensorType::get({minRank}, b->getIntegerType(64));
|
||||
return DenseIntElementsAttr::get(type, broadcastDimensions);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
} // namespace mlir
|
54
tensorflow/compiler/mlir/xla/ir/hlo_utils.h
Normal file
54
tensorflow/compiler/mlir/xla/ir/hlo_utils.h
Normal file
@ -0,0 +1,54 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
|
||||
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla {
|
||||
|
||||
// Computes the broadcast dimensions attr for an elementwise binary operator
|
||||
// between two ranked tensors.
|
||||
mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b,
|
||||
mlir::Value* x,
|
||||
mlir::Value* y);
|
||||
|
||||
/// Get a constant splat for the given value type.
|
||||
template <typename T>
|
||||
static ElementsAttr getSplat(Builder* b, Value* val, T constant) {
|
||||
auto valType = val->getType().cast<TensorType>();
|
||||
auto valElementType = getElementTypeOrSelf(val->getType());
|
||||
|
||||
// Handle integer elements.
|
||||
Attribute elementAttr;
|
||||
if (valElementType.isa<IntegerType>())
|
||||
elementAttr = b->getIntegerAttr(valElementType, constant);
|
||||
else if (valElementType.isa<FloatType>())
|
||||
elementAttr = b->getFloatAttr(valElementType, constant);
|
||||
else
|
||||
llvm_unreachable("unhandled element type");
|
||||
|
||||
return DenseElementsAttr::get(valType, elementAttr);
|
||||
}
|
||||
} // namespace xla
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
|
37
tensorflow/compiler/mlir/xla/ir/hlo_utils.td
Normal file
37
tensorflow/compiler/mlir/xla/ir/hlo_utils.td
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This is the utils file for the HLO dialect.
|
||||
|
||||
#ifndef HLO_UTILS
|
||||
#define HLO_UTILS
|
||||
|
||||
#ifndef OP_BASE
|
||||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
|
||||
|
||||
def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
|
||||
|
||||
class ConstantSplat<string value> : NativeCodeCall<
|
||||
"getSplat(&$_builder, $0, " # value # ")">;
|
||||
|
||||
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
|
||||
|
||||
def BinBroadcastDimensions : NativeCodeCall<
|
||||
"getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
|
||||
|
||||
#endif // HLO_UTILS
|
312
tensorflow/compiler/mlir/xla/tests/lower-complex.mlir
Normal file
312
tensorflow/compiler/mlir/xla/tests/lower-complex.mlir
Normal file
@ -0,0 +1,312 @@
|
||||
// RUN: tf-opt %s -test-xla-lower-complex | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @add
|
||||
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3
|
||||
%4 = "xla_hlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
|
||||
// CHECK: return [[VAL0]], [[VAL1]]
|
||||
return %5, %6 : tensor<2xf32>, tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_broadcast
|
||||
func @add_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.add"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
%4 = "xla_hlo.add"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
|
||||
|
||||
// CHECK: return [[VAL0]], [[VAL1]]
|
||||
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_unranked
|
||||
func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3
|
||||
%4 = "xla_hlo.add"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
|
||||
// CHECK: return [[VAL0]], [[VAL1]]
|
||||
return %5, %6 : tensor<*xf32>, tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @sub
|
||||
func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.sub %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.sub %arg1, %arg3
|
||||
%4 = "xla_hlo.sub"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
|
||||
// CHECK: return [[VAL0]], [[VAL1]]
|
||||
return %5, %6 : tensor<2xf32>, tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @sub_broadcast
|
||||
func @sub_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.sub"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.sub"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
%4 = "xla_hlo.sub"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
|
||||
|
||||
// CHECK: return [[VAL0]], [[VAL1]]
|
||||
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @sub_unranked
|
||||
func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.sub %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.sub %arg1, %arg3
|
||||
%4 = "xla_hlo.sub"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
|
||||
// CHECK: return [[VAL0]], [[VAL1]]
|
||||
return %5, %6 : tensor<*xf32>, tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @mul
|
||||
func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.mul %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg1, %arg3
|
||||
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.sub [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul %arg0, %arg3
|
||||
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg1, %arg2
|
||||
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
|
||||
%4 = "xla_hlo.mul"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
|
||||
// CHECK: return %2, %5 : tensor<2xf32>, tensor<2xf32>
|
||||
return %5, %6 : tensor<2xf32>, tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @mul_broadcast
|
||||
func @mul_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.mul"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.mul"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.sub [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.mul"(%arg0, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[VAL4:%.+]] = "xla_hlo.mul"(%arg1, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
|
||||
%4 = "xla_hlo.mul"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
|
||||
|
||||
// CHECK: return %2, %5 : tensor<1x2xf32>, tensor<1x2xf32>
|
||||
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @mul_unranked
|
||||
func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.mul %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg1, %arg3
|
||||
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.sub [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul %arg0, %arg3
|
||||
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg1, %arg2
|
||||
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
|
||||
%4 = "xla_hlo.mul"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
|
||||
// CHECK: return %2, %5 : tensor<*xf32>, tensor<*xf32>
|
||||
return %5, %6 : tensor<*xf32>, tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @div
|
||||
func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.neg"(%arg3)
|
||||
|
||||
// Compute the numerator's real component:
|
||||
// numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.mul %arg1, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.sub [[VAL1]], [[VAL2]]
|
||||
|
||||
// Compute the real valued denominator as rhs * con(rhs):
|
||||
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
|
||||
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg2, %arg2
|
||||
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.mul %arg3, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.sub [[VAL4]], [[VAL5]]
|
||||
|
||||
// Compute the numerator's imaginary component:
|
||||
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
|
||||
// CHECK-DAG: [[VAL7:%.+]] = xla_hlo.mul %arg1, %arg2
|
||||
// CHECK-DAG: [[VAL8:%.+]] = xla_hlo.mul %arg0, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
|
||||
|
||||
// Divide the numerator by the real valued denominator.
|
||||
// CHECK-DAG: [[VAL10:%.+]] = xla_hlo.div [[VAL3]], [[VAL6]]
|
||||
// CHECK-DAG: [[VAL11:%.+]] = xla_hlo.div [[VAL9]], [[VAL6]]
|
||||
%4 = "xla_hlo.div"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
|
||||
// CHECK: return [[VAL10]], [[VAL11]]
|
||||
return %5, %6 : tensor<2xf32>, tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @div_broadcast
|
||||
func @div_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.neg"(%arg3)
|
||||
|
||||
// Compute the numerator's real component:
|
||||
// numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag
|
||||
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.mul"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.mul"(%arg1, [[VAL0]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.sub [[VAL1]], [[VAL2]]
|
||||
|
||||
// Compute the real valued denominator as rhs * con(rhs):
|
||||
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
|
||||
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg2, %arg2
|
||||
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.mul %arg3, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.sub [[VAL4]], [[VAL5]]
|
||||
|
||||
// Compute the numerator's imaginary component:
|
||||
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
|
||||
// CHECK-DAG: [[VAL7:%.+]] = "xla_hlo.mul"(%arg1, %arg2)
|
||||
// CHECK-DAG: [[VAL8:%.+]] = "xla_hlo.mul"(%arg0, [[VAL0]])
|
||||
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
|
||||
|
||||
// Divide the numerator by the real valued denominator.
|
||||
// CHECK-DAG: [[VAL10:%.+]] = "xla_hlo.div"([[VAL3]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[VAL11:%.+]] = "xla_hlo.div"([[VAL9]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
%4 = "xla_hlo.div"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>)
|
||||
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
|
||||
|
||||
// CHECK: return [[VAL10]], [[VAL11]]
|
||||
return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @div_unranked
|
||||
func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.neg"(%arg3)
|
||||
|
||||
// Compute the numerator's real component:
|
||||
// numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg0, %arg2
|
||||
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.mul %arg1, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.sub [[VAL1]], [[VAL2]]
|
||||
|
||||
// Compute the real valued denominator as rhs * con(rhs):
|
||||
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
|
||||
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul %arg2, %arg2
|
||||
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.mul %arg3, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.sub [[VAL4]], [[VAL5]]
|
||||
|
||||
// Compute the numerator's imaginary component:
|
||||
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
|
||||
// CHECK-DAG: [[VAL7:%.+]] = xla_hlo.mul %arg1, %arg2
|
||||
// CHECK-DAG: [[VAL8:%.+]] = xla_hlo.mul %arg0, [[VAL0]]
|
||||
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
|
||||
|
||||
// Divide the numerator by the real valued denominator.
|
||||
// CHECK-DAG: [[VAL10:%.+]] = xla_hlo.div [[VAL3]], [[VAL6]]
|
||||
// CHECK-DAG: [[VAL11:%.+]] = xla_hlo.div [[VAL9]], [[VAL6]]
|
||||
%4 = "xla_hlo.div"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
|
||||
|
||||
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
|
||||
// CHECK: return [[VAL10]], [[VAL11]]
|
||||
return %5, %6 : tensor<*xf32>, tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @abs
|
||||
func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
|
||||
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.mul %arg0, %arg0
|
||||
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.mul %arg1, %arg1
|
||||
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.add [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.sqrt"([[VAL2]])
|
||||
%1 = "xla_hlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
|
||||
// CHECK: return [[VAL3]]
|
||||
return %2 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @exp
|
||||
func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
|
||||
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exp"(%arg0)
|
||||
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cos"(%arg1)
|
||||
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sin"(%arg1)
|
||||
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul [[VAL0]], [[VAL2]]
|
||||
%1 = "xla_hlo.exp"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%3 = "xla_hlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
|
||||
// CHECK: return [[VAL3]], [[VAL4]]
|
||||
return %2, %3 : tensor<2xf32>, tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @exp_unranked
|
||||
func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exp"(%arg0)
|
||||
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cos"(%arg1)
|
||||
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sin"(%arg1)
|
||||
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.mul [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.mul [[VAL0]], [[VAL2]]
|
||||
%1 = "xla_hlo.exp"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
|
||||
%2 = "xla_hlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
%3 = "xla_hlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
|
||||
|
||||
// CHECK: return [[VAL3]], [[VAL4]]
|
||||
return %2, %3 : tensor<*xf32>, tensor<*xf32>
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt -xla-lower-general-dot -split-input-file %s -o - | FileCheck %s
|
||||
// RUN: tf-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @testDebatch1
|
||||
func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
|
||||
|
||||
def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DynamicSlice op patterns.
|
||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
|
||||
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
|
||||
@ -274,24 +275,6 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
|
||||
// Binary op utilities.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Get a constant splat for the given value type.
|
||||
template <typename T>
|
||||
static ElementsAttr getSplat(Builder &b, Value *val, T constant) {
|
||||
auto valType = val->getType().cast<TensorType>();
|
||||
auto valElementType = getElementTypeOrSelf(val->getType());
|
||||
|
||||
// Handle integer elements.
|
||||
Attribute elementAttr;
|
||||
if (valElementType.isa<IntegerType>())
|
||||
elementAttr = b.getIntegerAttr(valElementType, constant);
|
||||
else if (valElementType.isa<FloatType>())
|
||||
elementAttr = b.getFloatAttr(valElementType, constant);
|
||||
else
|
||||
llvm_unreachable("unhandled element type");
|
||||
|
||||
return DenseElementsAttr::get(valType, elementAttr);
|
||||
}
|
||||
|
||||
// Returns whether the two values are guaranteed to be broadcastable to the
|
||||
// same shape, this broadcasts size 1 tensors up to any rank. Dynamic dimensions
|
||||
// must be broadcasted with a size 1 tensor or another dynamic dimension.
|
||||
@ -322,37 +305,6 @@ static bool AreBroadcastCompatible(Value *x, Value *y) {
|
||||
return true;
|
||||
}
|
||||
|
||||
static DenseIntElementsAttr getBroadcastDimensionsAttr(Builder &b, Value *x,
|
||||
Value *y) {
|
||||
TensorType xType = x->getType().dyn_cast<RankedTensorType>();
|
||||
TensorType yType = y->getType().dyn_cast<RankedTensorType>();
|
||||
if (xType == yType || !xType || !yType) return {};
|
||||
|
||||
// If the shapes have the same rank, then there is nothing to do.
|
||||
auto xRank = xType.getRank(), yRank = yType.getRank();
|
||||
if (xRank == yRank) return {};
|
||||
|
||||
// Otherwise if the ranks of the inputs don't match, TensorFlow automatically
|
||||
// reshapes the smaller by padding with dimensions of size 1 as a prefix. In
|
||||
// other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to
|
||||
// have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast
|
||||
// from lower to higher rank, but doesn't assume you want to pad as a prefix
|
||||
// of the dimensions, and instead needs to be told which dimensions of the
|
||||
// higher rank tensor to match to the lower rank tensor.
|
||||
auto maxRank = std::max(xRank, yRank);
|
||||
auto minRank = std::min(xRank, yRank);
|
||||
|
||||
// Match the lower rank tensor along the larger-numbered dimensions of the
|
||||
// higher rank tensor.
|
||||
SmallVector<int64_t, 4> broadcastDimensions(minRank);
|
||||
std::iota(broadcastDimensions.begin(), broadcastDimensions.end(),
|
||||
maxRank - minRank);
|
||||
|
||||
RankedTensorType type =
|
||||
RankedTensorType::get({minRank}, b.getIntegerType(64));
|
||||
return DenseIntElementsAttr::get(type, broadcastDimensions);
|
||||
}
|
||||
|
||||
// Return a new TensorType the same rank and dimensions as the input with an
|
||||
// updated element type.
|
||||
static Type ChangeTensorElementType(Builder *b, Type tensor_type,
|
||||
@ -1017,10 +969,10 @@ class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {
|
||||
rewriter.getI64IntegerAttr(0));
|
||||
auto scaled = rewriter.create<xla_hlo::MulOp>(
|
||||
op.getLoc(), result_type, iota, op.delta(),
|
||||
getBroadcastDimensionsAttr(rewriter, iota, op.delta()));
|
||||
getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
|
||||
rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(
|
||||
op, result_type, scaled, op.start(),
|
||||
getBroadcastDimensionsAttr(rewriter, scaled, op.start()));
|
||||
getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
@ -20,14 +20,6 @@ include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
|
||||
|
||||
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
|
||||
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
|
||||
|
||||
def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
|
||||
|
||||
class ConstantSplat<string value> : NativeCodeCall<
|
||||
"getSplat($_builder, $0, " # value # ")">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BatchNorm op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -65,10 +57,6 @@ def : Pat<(TF_BiasAddOp AnyStaticShapeTensor:$input, $bias, $data_format),
|
||||
// Binary op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Get the broadcast dimensions attribute from the binary operands.
|
||||
def BinBroadcastDimensions : NativeCodeCall<
|
||||
"getBroadcastDimensionsAttr($_builder, $0, $1)">;
|
||||
|
||||
// Check that two values can be broadcasted together
|
||||
def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">,
|
||||
"types must be broadcastable">;
|
||||
|
78
tensorflow/compiler/mlir/xla/transforms/lower_complex.cc
Normal file
78
tensorflow/compiler/mlir/xla/transforms/lower_complex.cc
Normal file
@ -0,0 +1,78 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Thsi file implements passes to convert complex operations to equivalent real
|
||||
// value operations. This does not include removing complex values from function
|
||||
// argument or return types.
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||
|
||||
using mlir::FunctionPass;
|
||||
using mlir::OwningRewritePatternList;
|
||||
using mlir::PassRegistration;
|
||||
|
||||
namespace {
|
||||
class LowerComplex : public FunctionPass<LowerComplex> {
|
||||
public:
|
||||
explicit LowerComplex() : FunctionPass<LowerComplex>() {}
|
||||
|
||||
/// Performs the lowering to XLA dialect.
|
||||
void runOnFunction() override;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/generated_lower_complex.inc"
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
void PopulateComplexLoweringPatterns(MLIRContext* context,
|
||||
OwningRewritePatternList* patterns) {
|
||||
populateWithGenerated(context, patterns);
|
||||
}
|
||||
} // end namespace xla
|
||||
} // end namespace mlir
|
||||
|
||||
// Lowers the complex operations that can be represented using other operations.
|
||||
void LowerComplex::runOnFunction() {
|
||||
// Add lowering patterns to the list.
|
||||
OwningRewritePatternList patterns;
|
||||
mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns);
|
||||
|
||||
applyPatternsGreedily(getFunction(), patterns);
|
||||
}
|
||||
|
||||
static PassRegistration<LowerComplex> pass(
|
||||
"test-xla-lower-complex",
|
||||
"Lower complex operations into non-complex operations");
|
@ -0,0 +1,121 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This is the legalization pattern that converts complex operations into
|
||||
// equivalent real value operations.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Binary op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Add and subtraction are elementwise and can be distributed across the real
|
||||
// and imaginary components.
|
||||
foreach elementwiseOp = [HLO_AddOp, HLO_SubOp] in
|
||||
def : Pat<(elementwiseOp HLO_ComplexTensor:$lhs,
|
||||
HLO_ComplexTensor:$rhs, $broadcast_dimensions),
|
||||
(HLO_ComplexOp
|
||||
(elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs),
|
||||
$broadcast_dimensions),
|
||||
(elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs),
|
||||
$broadcast_dimensions))>;
|
||||
|
||||
// Complex multiplication results in a cross product multiplication between the
|
||||
// real and imaginary components such that:
|
||||
// result.real = lhs.real * rhs.real - lhs.imag * rhs.imag
|
||||
// result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag
|
||||
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs,
|
||||
HLO_ComplexTensor:$rhs, $broadcast_dimensions),
|
||||
(HLO_ComplexOp
|
||||
(HLO_SubOp
|
||||
(HLO_MulOp
|
||||
(HLO_RealOp:$lhs_real $lhs),
|
||||
(HLO_RealOp:$rhs_real $rhs),
|
||||
$broadcast_dimensions),
|
||||
(HLO_MulOp
|
||||
(HLO_ImagOp:$lhs_imag $lhs),
|
||||
(HLO_ImagOp:$rhs_imag $rhs),
|
||||
$broadcast_dimensions),
|
||||
(NullDenseIntElementsAttr)),
|
||||
(HLO_AddOp
|
||||
(HLO_MulOp $lhs_real, $rhs_imag, $broadcast_dimensions),
|
||||
(HLO_MulOp $lhs_imag, $rhs_real, $broadcast_dimensions),
|
||||
(NullDenseIntElementsAttr)))>;
|
||||
|
||||
// Multiplication between a complex and real tensor can be distributed by
|
||||
// applying the real multiplicant to both the real and complex component.
|
||||
//
|
||||
// Note that the sourcep pattern is not legal according to the HLO dialect but
|
||||
// instead handle intermediates generated by other patterns.
|
||||
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions),
|
||||
(HLO_ComplexOp
|
||||
(HLO_MulOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions),
|
||||
(HLO_MulOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>;
|
||||
|
||||
def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions),
|
||||
(HLO_ComplexOp
|
||||
(HLO_MulOp $lhs, (HLO_RealOp $rhs), $broadcast_dimensions),
|
||||
(HLO_MulOp $lhs, (HLO_ImagOp $rhs), $broadcast_dimensions))>;
|
||||
|
||||
|
||||
// Division is performed by normalizing the denominator by multiplying by the
|
||||
// conjugate of the rhs.
|
||||
// numerator = lhs * conj(rhs)
|
||||
// denominator = rhs * conj(rhs)
|
||||
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions),
|
||||
(HLO_DivOp
|
||||
(HLO_MulOp:$num $lhs,
|
||||
(HLO_ComplexOp:$conj
|
||||
(HLO_RealOp $rhs),
|
||||
(HLO_NegOp (HLO_ImagOp $rhs))),
|
||||
$broadcast_dimensions),
|
||||
(HLO_RealOp:$den (HLO_MulOp $rhs, $conj, $broadcast_dimensions)),
|
||||
(BinBroadcastDimensions $num, $den))>;
|
||||
|
||||
|
||||
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions),
|
||||
(HLO_ComplexOp
|
||||
(HLO_DivOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions),
|
||||
(HLO_DivOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>;
|
||||
|
||||
|
||||
// Absolute value is evaluated as:
|
||||
// result = sqrt(val.real * val.real + val.imag * val.imag)
|
||||
def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val),
|
||||
(HLO_ComplexOp
|
||||
(HLO_SqrtOp
|
||||
(HLO_AddOp
|
||||
(HLO_MulOp (HLO_RealOp:$real $val), $real,
|
||||
(NullDenseIntElementsAttr)),
|
||||
(HLO_MulOp (HLO_ImagOp:$imag $val), $imag,
|
||||
(NullDenseIntElementsAttr)),
|
||||
(NullDenseIntElementsAttr))),
|
||||
(HLO_ConstOp (ConstantSplat<"0"> $real)))>;
|
||||
|
||||
// Expononetial can be lowered to an exponential on the real component and a
|
||||
// sum of sinusoids of the imageinary component, which equates to a normal
|
||||
// exponential operator multiplied by Euler's formula.
|
||||
//
|
||||
// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * (Cos(b) + iSin(b))
|
||||
def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val),
|
||||
(HLO_MulOp
|
||||
(HLO_ExpOp (HLO_RealOp $val)),
|
||||
(HLO_ComplexOp
|
||||
(HLO_CosOp (HLO_ImagOp:$imag $val)),
|
||||
(HLO_SinOp $imag)),
|
||||
(NullDenseIntElementsAttr))>;
|
@ -186,5 +186,5 @@ void mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(
|
||||
}
|
||||
|
||||
static PassRegistration<LegalizeGeneralDot> legalize_pass(
|
||||
"xla-lower-general-dot",
|
||||
"Lower a general dot to a non-batched dot when possible");
|
||||
"test-xla-lower-general-dot",
|
||||
"Tests lowering general dot to a non-batched dot when possible");
|
||||
|
@ -28,6 +28,11 @@ namespace xla_hlo {
|
||||
void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
|
||||
MLIRContext *ctx);
|
||||
|
||||
// Collection of rewrite patterns for lowering complex operations to equivalent
|
||||
// float operations.
|
||||
void PopulateComplexLoweringPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
|
||||
MLIRContext *ctx);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user