[TF/MLIR] Adds legalization rule for xla_hlo.dot_general.
An xla_hlo.dot_general op is converted to a tf.BatchMatMulV2 op. Because BatchMatMulV2 supports only a single contracting dimension on each operand, transpose ops are first inserted to group the batch, contracting, and out dimensions, and reshape ops then flatten the contracting and out dimensions.

PiperOrigin-RevId: 315293215
Change-Id: Iceb3738025e5c8a730340807b1d6d17c10d7ecc2
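For a concrete walkthrough (shapes taken from the test case in this change): given lhs tensor<3x2x6x5x1xf32> (batch dims {0}, contracting dims {1, 2}) and rhs tensor<3x2x4x6xf32> (batch dims {0}, contracting dims {1, 3}), the lowering emits:

  1. tf.Transpose lhs -> tensor<3x5x1x2x6xf32>  (batch, out, contracting)
  2. tf.Transpose rhs -> tensor<3x2x6x4xf32>    (batch, contracting, out)
  3. tf.Reshape lhs -> tensor<3x5x12xf32>       (out 5*1=5, contracting 2*6=12)
  4. tf.Reshape rhs -> tensor<3x12x4xf32>       (contracting 2*6=12, out 4)
  5. tf.BatchMatMulV2 -> tensor<3x5x4xf32>
  6. tf.Reshape -> tensor<3x5x1x4xf32>          (restore the out dimensions)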
@@ -723,6 +723,11 @@ func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x1
  return %0 : tensor<3x8x8x16xf32>
}

func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
  %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32>
  return %0 : tensor<3x5x1x4xf32>
}

// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py

// CHECK-LABEL: func @biasAdd_NHWC(
@@ -1596,3 +1601,14 @@ func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x1
// CHECK: [[VAL_402:%.*]] = "tf.BroadcastTo"([[VAL_400]], [[VAL_401]]) : (tensor<3x1x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32>
// CHECK: return [[VAL_402]] : tensor<3x8x8x16xf32>
// CHECK: }

// CHECK-LABEL: func @convert_dot_general(
// CHECK-SAME: [[VAL_396:%.*]]: tensor<3x2x6x5x1xf32>, [[VAL_397:%.*]]: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
// CHECK: [[VAL_398:%.*]] = "tf.Transpose"([[VAL_396]], {{.*}}) : (tensor<3x2x6x5x1xf32>, tensor<5xi64>) -> tensor<3x5x1x2x6xf32>
// CHECK: [[VAL_399:%.*]] = "tf.Transpose"([[VAL_397]], {{.*}}) : (tensor<3x2x4x6xf32>, tensor<4xi64>) -> tensor<3x2x6x4xf32>
// CHECK: [[VAL_400:%.*]] = "tf.Reshape"([[VAL_398]], {{.*}}) : (tensor<3x5x1x2x6xf32>, tensor<3xi64>) -> tensor<3x5x12xf32>
// CHECK: [[VAL_401:%.*]] = "tf.Reshape"([[VAL_399]], {{.*}}) : (tensor<3x2x6x4xf32>, tensor<3xi64>) -> tensor<3x12x4xf32>
// CHECK: [[VAL_402:%.*]] = "tf.BatchMatMulV2"([[VAL_400]], [[VAL_401]]) {adj_x = false, adj_y = false} : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32>
// CHECK: [[VAL_403:%.*]] = "tf.Reshape"([[VAL_402]], {{.*}}) : (tensor<3x5x4xf32>, tensor<4xi64>) -> tensor<3x5x1x4xf32>
// CHECK: return [[VAL_403]] : tensor<3x5x1x4xf32>
// CHECK: }
@@ -15,10 +15,15 @@ limitations under the License.

// This file implements logic for legalizing HLO to TensorFlow.

#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
@@ -40,6 +45,8 @@ namespace mlir {
namespace TF {
namespace {

using xla_hlo::DotDimensionNumbers;

class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;
@@ -75,6 +82,205 @@ class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
  };
};

// Appends all elements in `range` to `values`.
template <typename ValueT, typename Range>
void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range) {
  values.insert(values.end(), range.begin(), range.end());
}

// Appends all elements in `range` and `ranges` to `values`.
template <typename ValueT, typename Range, typename... RangeTs>
void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range,
            RangeTs &&... ranges) {
  values.insert(values.end(), range.begin(), range.end());
  Append(values, ranges...);
}

// Returns the number of elements in `range`.
template <typename Range>
size_t Size(Range &&range) {
  return range.size();
}

// Returns the total number of elements in a variadic number of `ranges`.
template <typename Range, typename... RangeTs>
size_t Size(Range &&range, RangeTs &&... ranges) {
  return range.size() + Size(std::forward<RangeTs>(ranges)...);
}

// Concats all elements in `ranges` and returns a small vector as a result.
template <typename ValueT, typename... RangeTs>
llvm::SmallVector<ValueT, 4> Concat(RangeTs &&... ranges) {
  llvm::SmallVector<ValueT, 4> results;
  results.reserve(Size(std::forward<RangeTs>(ranges)...));
  Append(results, std::forward<RangeTs>(ranges)...);
  return results;
}
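
// For example, Concat<int64_t>(llvm::ArrayRef<int64_t>{0},
//                              llvm::ArrayRef<int64_t>{3, 4},
//                              llvm::ArrayRef<int64_t>{1, 2})
// returns {0, 3, 4, 1, 2}. This is how the axis and size lists of the batch,
// out, and contracting dimensions are stitched together below.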

// A struct to hold axes and sizes for a set of dimensions.
struct DimensionSetVector {
  llvm::ArrayRef<int64_t> AxesArray() const { return axes.getArrayRef(); }
  llvm::ArrayRef<int64_t> SizesArray() const { return sizes; }

  llvm::SmallSetVector<int64_t, 4> axes;
  // Sizes are kept in a plain vector rather than a set: distinct axes may
  // have equal sizes, and deduplicating the sizes would drop dimensions.
  llvm::SmallVector<int64_t, 4> sizes;
};

// A class to hold information about the dimensions of dot_general operands.
class DotDimensionsInfo {
 public:
  DotDimensionsInfo(ShapedType type, DenseIntElementsAttr batch_dimensions,
                    DenseIntElementsAttr contracting_dimensions) {
    const int rank = type.getRank();
    for (const int64_t dim : batch_dimensions.getValues<int64_t>()) {
      batch_dimensions_.axes.insert(dim);
      batch_dimensions_.sizes.push_back(type.getDimSize(dim));
    }

    for (const int64_t dim : contracting_dimensions.getValues<int64_t>()) {
      contracting_dimensions_.axes.insert(dim);
      contracting_dimensions_.sizes.push_back(type.getDimSize(dim));
    }

    for (int dim = 0; dim < rank; ++dim) {
      if (contracting_dimensions_.axes.count(dim) > 0 ||
          batch_dimensions_.axes.count(dim) > 0) {
        continue;
      }
      out_dimensions_.axes.insert(dim);
      out_dimensions_.sizes.push_back(type.getDimSize(dim));
    }
  }

  const DimensionSetVector &batch_dimensions() const {
    return batch_dimensions_;
  }
  const DimensionSetVector &contracting_dimensions() const {
    return contracting_dimensions_;
  }
  // Out dimensions are any dimensions that are neither batch nor contracting
  // dimensions, and hence are propagated to the output shape.
  const DimensionSetVector &out_dimensions() const { return out_dimensions_; }

  // Returns the total dimension size after flattening all contracting
  // dimensions.
  int64_t FlattenedContractingDimensionSize() const {
    return std::accumulate(contracting_dimensions_.sizes.begin(),
                           contracting_dimensions_.sizes.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }

  // Returns the total dimension size after flattening all out dimensions.
  int64_t FlattenedOutDimensionSize() const {
    return std::accumulate(out_dimensions_.sizes.begin(),
                           out_dimensions_.sizes.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }
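
  // For example, given operand type tensor<3x2x6x5x1xf32> with batch
  // dimension {0} and contracting dimensions {1, 2} (the lhs of the test case
  // in this change), the out dimensions are {3, 4},
  // FlattenedContractingDimensionSize() is 2 * 6 = 12, and
  // FlattenedOutDimensionSize() is 5 * 1 = 5.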

 private:
  DimensionSetVector batch_dimensions_;
  DimensionSetVector contracting_dimensions_;
  // Out dimensions are any dimensions that are neither batch nor contracting
  // dimensions, and hence are propagated to the output shape.
  DimensionSetVector out_dimensions_;
};

// Converts xla_hlo.dot_general to tf.BatchMatMulV2. Transpose and Reshape ops
// are also inserted to turn the operands into a well-formed matrix multiply.
Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
  auto dot_general_op = cast<xla_hlo::DotGeneralOp>(old_op);
  auto lhs_type = dot_general_op.lhs().getType().cast<ShapedType>();
  auto rhs_type = dot_general_op.rhs().getType().cast<ShapedType>();
  auto result_type = dot_general_op.getResult().getType().cast<ShapedType>();
  DotDimensionNumbers dot_dimension_numbers =
      dot_general_op.dot_dimension_numbers();
  mlir::Location loc = dot_general_op.getLoc();
  const int lhs_rank = lhs_type.getRank();
  const int rhs_rank = rhs_type.getRank();

  // Collects lhs and rhs dimensions information.
  DotDimensionsInfo lhs_dot_dimensions_info(
      lhs_type, dot_dimension_numbers.lhs_batching_dimensions(),
      dot_dimension_numbers.lhs_contracting_dimensions());
  DotDimensionsInfo rhs_dot_dimensions_info(
      rhs_type, dot_dimension_numbers.rhs_batching_dimensions(),
      dot_dimension_numbers.rhs_contracting_dimensions());

  // Transposes lhs shape to be in the order of {batch_dimensions,
  // out_dimensions, contracting_dimensions}.
  llvm::SmallVector<int64_t, 4> lhs_permutation = Concat<int64_t>(
      lhs_dot_dimensions_info.batch_dimensions().AxesArray(),
      lhs_dot_dimensions_info.out_dimensions().AxesArray(),
      lhs_dot_dimensions_info.contracting_dimensions().AxesArray());
  llvm::SmallVector<int64_t, 4> lhs_transposed_shape = Concat<int64_t>(
      lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      lhs_dot_dimensions_info.out_dimensions().SizesArray(),
      lhs_dot_dimensions_info.contracting_dimensions().SizesArray());
  auto lhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
      loc,
      RankedTensorType::get(lhs_transposed_shape, lhs_type.getElementType()),
      dot_general_op.lhs(),
      DenseIntElementsAttr::get(
          RankedTensorType::get({lhs_rank}, rewriter.getI64Type()),
          lhs_permutation));
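
  // e.g. for lhs tensor<3x2x6x5x1xf32> with batch dims {0} and contracting
  // dims {1, 2} (the test case in this change), lhs_permutation is
  // {0, 3, 4, 1, 2} and lhs_transposed has type tensor<3x5x1x2x6xf32>.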

  // Transposes rhs shape to be in the order of {batch_dimensions,
  // contracting_dimensions, out_dimensions}.
  llvm::SmallVector<int64_t, 4> rhs_permutation = Concat<int64_t>(
      rhs_dot_dimensions_info.batch_dimensions().AxesArray(),
      rhs_dot_dimensions_info.contracting_dimensions().AxesArray(),
      rhs_dot_dimensions_info.out_dimensions().AxesArray());
  llvm::SmallVector<int64_t, 4> rhs_transposed_shape = Concat<int64_t>(
      rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      rhs_dot_dimensions_info.contracting_dimensions().SizesArray(),
      rhs_dot_dimensions_info.out_dimensions().SizesArray());
  auto rhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
      loc,
      RankedTensorType::get(rhs_transposed_shape, rhs_type.getElementType()),
      dot_general_op.rhs(),
      DenseIntElementsAttr::get(
          RankedTensorType::get({rhs_rank}, rewriter.getI64Type()),
          rhs_permutation));
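
  // e.g. for rhs tensor<3x2x4x6xf32> with batch dims {0} and contracting
  // dims {1, 3}, rhs_permutation is {0, 1, 3, 2} and rhs_transposed has type
  // tensor<3x2x6x4xf32>.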

  // Reshapes lhs to flatten out_dimensions and contracting_dimensions.
  llvm::SmallVector<int64_t, 4> lhs_flattened_shape = Concat<int64_t>(
      lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      llvm::ArrayRef<int64_t>{
          lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
      llvm::ArrayRef<int64_t>{
          lhs_dot_dimensions_info.FlattenedContractingDimensionSize()});
  auto lhs_flattened = rewriter.create<xla_hlo::ReshapeOp>(
      loc,
      RankedTensorType::get(lhs_flattened_shape, lhs_type.getElementType()),
      lhs_transposed.getResult());
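
  // For the test case in this change, lhs_flattened has type
  // tensor<3x5x12xf32>: the out dimensions 5 x 1 flatten to 5 and the
  // contracting dimensions 2 x 6 flatten to 12.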

  // Reshapes rhs to flatten contracting_dimensions and out_dimensions.
  llvm::SmallVector<int64_t, 4> rhs_flattened_shape = Concat<int64_t>(
      rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      llvm::ArrayRef<int64_t>{
          rhs_dot_dimensions_info.FlattenedContractingDimensionSize()},
      llvm::ArrayRef<int64_t>{
          rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
  auto rhs_flattened = rewriter.create<xla_hlo::ReshapeOp>(
      loc,
      RankedTensorType::get(rhs_flattened_shape, rhs_type.getElementType()),
      rhs_transposed.getResult());

  // Creates a matmul op of `lhs_flattened` and `rhs_flattened`.
  llvm::SmallVector<int64_t, 4> matmul_shape =
      Concat<int64_t>(lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
                      llvm::ArrayRef<int64_t>{
                          lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
                      llvm::ArrayRef<int64_t>{
                          rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
  auto matmul = rewriter.create<TF::BatchMatMulV2Op>(
      loc, RankedTensorType::get(matmul_shape, result_type.getElementType()),
      lhs_flattened.getResult(), rhs_flattened.getResult());
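
  // For the test case in this change, matmul has type tensor<3x5x4xf32>; the
  // reshape below restores the unflattened result type tensor<3x5x1x4xf32>.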
  auto reshaped = rewriter.create<xla_hlo::ReshapeOp>(
      loc, result_type, matmul.getResult());
  return reshaped.getResult();
}

class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
 public:
  LegalizeHloToTf() = default;
@@ -184,3 +184,10 @@ def ConvertDotOp : NativeCodeCall<"ConvertDotOp($_builder, "
def : Pat<(HLO_DotOp:$old_value AnyStaticShapeTensor:$lhs,
           AnyStaticShapeTensor:$rhs, $precision_config),
          (ConvertDotOp $old_value)>;

def ConvertDotGeneralOp : NativeCodeCall<"ConvertDotGeneralOp($_builder, "
                                         "$0.getDefiningOp())">;
def : Pat<(HLO_DotGeneralOp:$old_value AnyStaticShapeTensor:$lhs,
           AnyStaticShapeTensor:$rhs, $dot_dimension_numbers,
           $precision_config),
          (ConvertDotGeneralOp $old_value)>;
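
// Note: `HLO_DotGeneralOp:$old_value` binds the result of the matched op,
// which is why the NativeCodeCall above recovers the op itself through
// $0.getDefiningOp().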