[TF/MLIR] Adds legalization rule for xla_hlo.dot_general.
An xla_hlo.dot_general op is converted to a tf.BatchMatMulV2 op. Transpose ops are inserted first to put the batch, contracting, and out dimensions in the proper order, and the contracting and out dimensions are then flattened, because BatchMatMul does not support multiple contracting dimensions.

PiperOrigin-RevId: 315293215
Change-Id: Iceb3738025e5c8a730340807b1d6d17c10d7ecc2
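For intuition, here is a minimal standalone sketch of the dimension bookkeeping this lowering performs (plain C++, no MLIR dependencies; all names are illustrative, not from the patch), using the lhs of the test case added below: shape 3x2x6x5x1 with batch dimensions {0} and contracting dimensions {1, 2}.

```cpp
// Sketch: compute the transpose permutation {batch, out, contracting} and the
// flattened BatchMatMul operand shape for the lhs of the test case below.
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

int main() {
  const std::vector<int64_t> lhs_shape = {3, 2, 6, 5, 1};
  const std::set<int64_t> batch = {0}, contracting = {1, 2};

  // Out dimensions are whatever is neither batch nor contracting.
  std::vector<int64_t> out;
  for (int64_t d = 0; d < static_cast<int64_t>(lhs_shape.size()); ++d)
    if (!batch.count(d) && !contracting.count(d)) out.push_back(d);

  // lhs permutation {batch, out, contracting} -> {0, 3, 4, 1, 2}, so the
  // transposed shape is 3x5x1x2x6, matching the tf.Transpose CHECK line.
  std::vector<int64_t> perm(batch.begin(), batch.end());
  perm.insert(perm.end(), out.begin(), out.end());
  perm.insert(perm.end(), contracting.begin(), contracting.end());

  // Flatten out dims (5*1 = 5) and contracting dims (2*6 = 12) to get the
  // BatchMatMul operand shape 3x5x12.
  int64_t flat_out = 1, flat_contracting = 1;
  for (int64_t d : out) flat_out *= lhs_shape[d];
  for (int64_t d : contracting) flat_contracting *= lhs_shape[d];

  std::cout << "perm:";
  for (int64_t p : perm) std::cout << ' ' << p;
  std::cout << "\nflattened lhs: 3 x " << flat_out << " x " << flat_contracting
            << "\n";  // prints: perm: 0 3 4 1 2 / flattened lhs: 3 x 5 x 12
}
```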
@@ -723,6 +723,11 @@ func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x1
   return %0 : tensor<3x8x8x16xf32>
 }
 
+func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
+  %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32>
+  return %0 : tensor<3x5x1x4xf32>
+}
+
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 
 // CHECK-LABEL: func @biasAdd_NHWC(
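For reference, with the dimension numbers above (batch = {0} on both sides, contracting = {1, 2} on the lhs and {1, 3} on the rhs), this dot_general is the einsum `abcde,abfc->adef`: the shared `a` (3) is the batch dimension, `b` (2) and `c` (6) are contracted, and the remaining lhs dimensions `d` (5) and `e` (1) precede the rhs dimension `f` (4) in the 3x5x1x4 result.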
@@ -1596,3 +1601,14 @@ func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x1
 // CHECK: [[VAL_402:%.*]] = "tf.BroadcastTo"([[VAL_400]], [[VAL_401]]) : (tensor<3x1x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32>
 // CHECK: return [[VAL_402]] : tensor<3x8x8x16xf32>
 // CHECK: }
+
+// CHECK-LABEL: func @convert_dot_general(
+// CHECK-SAME: [[VAL_396:%.*]]: tensor<3x2x6x5x1xf32>, [[VAL_397:%.*]]: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
+// CHECK: [[VAL_398:%.*]] = "tf.Transpose"([[VAL_396]], {{.*}}) : (tensor<3x2x6x5x1xf32>, tensor<5xi64>) -> tensor<3x5x1x2x6xf32>
+// CHECK: [[VAL_399:%.*]] = "tf.Transpose"([[VAL_397]], {{.*}}) : (tensor<3x2x4x6xf32>, tensor<4xi64>) -> tensor<3x2x6x4xf32>
+// CHECK: [[VAL_400:%.*]] = "tf.Reshape"([[VAL_398]], {{.*}}) : (tensor<3x5x1x2x6xf32>, tensor<3xi64>) -> tensor<3x5x12xf32>
+// CHECK: [[VAL_401:%.*]] = "tf.Reshape"([[VAL_399]], {{.*}}) : (tensor<3x2x6x4xf32>, tensor<3xi64>) -> tensor<3x12x4xf32>
+// CHECK: [[VAL_402:%.*]] = "tf.BatchMatMulV2"([[VAL_400]], [[VAL_401]]) {adj_x = false, adj_y = false} : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32>
+// CHECK: [[VAL_403:%.*]] = "tf.Reshape"([[VAL_402]], {{.*}}) : (tensor<3x5x4xf32>, tensor<4xi64>) -> tensor<3x5x1x4xf32>
+// CHECK: return [[VAL_403]] : tensor<3x5x1x4xf32>
+// CHECK: }
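The shape arithmetic in these checks works out line by line: the lhs transpose reorders 3x2x6x5x1 into {batch 3, out 5x1, contracting 2x6}, and flattening gives 3 x (5*1) x (2*6) = 3x5x12. The rhs is reordered into {batch 3, contracting 2x6, out 4} and flattened to 3x12x4. BatchMatMulV2 on (3x5x12, 3x12x4) yields 3x5x4, which the final reshape expands back to the unflattened out dimensions, 3x5x1x4.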
@@ -15,10 +15,15 @@ limitations under the License.
 
 // This file implements logic for legalizing HLO to TensorFlow.
 
+#include <cstdint>
+#include <functional>
 #include <memory>
+#include <numeric>
 #include <vector>
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
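The new headers back the helpers added below: `<cstdint>` for int64_t, `<numeric>` for std::accumulate, and `<functional>` for std::multiplies, while the LLVM ADT headers provide the llvm::ArrayRef and llvm::SmallSetVector types used by DimensionSetVector.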
@@ -40,6 +45,8 @@ namespace mlir {
 namespace TF {
 namespace {
 
+using xla_hlo::DotDimensionNumbers;
+
 class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
  public:
   using OpConversionPattern::OpConversionPattern;
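DotDimensionNumbers is the xla_hlo attribute struct that carries the four dimension lists of a dot_general (lhs/rhs batching and contracting dimensions); the using-declaration lets ConvertDotGeneralOp below refer to it unqualified.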
@@ -75,6 +82,205 @@ class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
   }
 };
 
+// Appends all elements in `range` to `values`.
+template <typename ValueT, typename Range>
+void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range) {
+  values.insert(values.end(), range.begin(), range.end());
+}
+
+// Appends all elements in `range` and then in `ranges` to `values`.
+template <typename ValueT, typename Range, typename... RangeTs>
+void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range,
+            RangeTs &&... ranges) {
+  values.insert(values.end(), range.begin(), range.end());
+  Append(values, ranges...);
+}
+
+// Returns the number of elements in `range`.
+template <typename Range>
+size_t Size(Range &&range) {
+  return range.size();
+}
+
+// Returns the total number of elements in a variadic number of `ranges`.
+template <typename Range, typename... RangeTs>
+size_t Size(Range &&range, RangeTs &&... ranges) {
+  return range.size() + Size(std::forward<RangeTs>(ranges)...);
+}
+
+// Concats all elements in `ranges` and returns a small vector as a result.
+template <typename ValueT, typename... RangeTs>
+llvm::SmallVector<ValueT, 4> Concat(RangeTs &&... ranges) {
+  llvm::SmallVector<ValueT, 4> results;
+  results.reserve(Size(std::forward<RangeTs>(ranges)...));
+  Append(results, std::forward<RangeTs>(ranges)...);
+  return results;
+}
+
+// A struct to hold axes and sizes for a set of dimensions.
+struct DimensionSetVector {
+  llvm::ArrayRef<int64_t> AxesArray() const { return axes.getArrayRef(); }
+  llvm::ArrayRef<int64_t> SizesArray() const { return sizes.getArrayRef(); }
+
+  llvm::SmallSetVector<int64_t, 4> axes;
+  llvm::SmallSetVector<int64_t, 4> sizes;
+};
+
+// A struct to hold information about dimensions of dot_general operands.
+class DotDimensionsInfo {
+ public:
+  DotDimensionsInfo(ShapedType type, DenseIntElementsAttr batch_dimensions,
+                    DenseIntElementsAttr contracting_dimensions) {
+    const int rank = type.getRank();
+    for (const int dim : batch_dimensions.getValues<int64_t>()) {
+      batch_dimensions_.axes.insert(dim);
+      batch_dimensions_.sizes.insert(type.getDimSize(dim));
+    }
+
+    for (const int dim : contracting_dimensions.getValues<int64_t>()) {
+      contracting_dimensions_.axes.insert(dim);
+      contracting_dimensions_.sizes.insert(type.getDimSize(dim));
+    }
+
+    for (int dim = 0; dim < rank; ++dim) {
+      if (contracting_dimensions_.axes.count(dim) > 0 ||
+          batch_dimensions_.axes.count(dim) > 0) {
+        continue;
+      }
+      out_dimensions_.axes.insert(dim);
+      out_dimensions_.sizes.insert(type.getDimSize(dim));
+    }
+  }
+
+  const DimensionSetVector &batch_dimensions() const {
+    return batch_dimensions_;
+  }
+  const DimensionSetVector &contracting_dimensions() const {
+    return contracting_dimensions_;
+  }
+  // Out dimensions are any dimensions that are neither batch nor contracting
+  // dimensions, hence will be propagated to the output shape.
+  const DimensionSetVector &out_dimensions() const { return out_dimensions_; }
+
+  // Returns the total dimension size after flattening all contracting
+  // dimensions.
+  int FlattenedContractingDimensionSize() const {
+    return std::accumulate(contracting_dimensions_.sizes.begin(),
+                           contracting_dimensions_.sizes.end(), 1,
+                           std::multiplies<int64_t>());
+  }
+
+  // Returns the total dimension size after flattening all out dimensions.
+  int FlattenedOutDimensionSize() const {
+    return std::accumulate(out_dimensions_.sizes.begin(),
+                           out_dimensions_.sizes.end(), 1,
+                           std::multiplies<int64_t>());
+  }
+
+ private:
+  DimensionSetVector batch_dimensions_;
+  DimensionSetVector contracting_dimensions_;
+  // Out dimensions are any dimensions that are neither batch nor contracting
+  // dimensions, hence will be propagated to the output shape.
+  DimensionSetVector out_dimensions_;
+};
+
+// Converts xla_hlo.dot_general to tf.BatchMatMulV2. Reshape and Transpose ops
+// will also be inserted to convert to a well-formed matrix multiply.
+Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
+  auto dot_general_op = cast<xla_hlo::DotGeneralOp>(old_op);
+  auto lhs_type = dot_general_op.lhs().getType().cast<ShapedType>();
+  auto rhs_type = dot_general_op.rhs().getType().cast<ShapedType>();
+  auto result_type = dot_general_op.getResult().getType().cast<ShapedType>();
+  DotDimensionNumbers dot_dimension_numbers =
+      dot_general_op.dot_dimension_numbers();
+  mlir::Location loc = dot_general_op.getLoc();
+  const int lhs_rank = lhs_type.getRank();
+  const int rhs_rank = rhs_type.getRank();
+
+  // Collects lhs and rhs dimensions information.
+  DotDimensionsInfo lhs_dot_dimensions_info(
+      lhs_type, dot_dimension_numbers.lhs_batching_dimensions(),
+      dot_dimension_numbers.lhs_contracting_dimensions());
+  DotDimensionsInfo rhs_dot_dimensions_info(
+      rhs_type, dot_dimension_numbers.rhs_batching_dimensions(),
+      dot_dimension_numbers.rhs_contracting_dimensions());
+
+  // Transposes lhs shape to be in the order of {batch_dimensions,
+  // out_dimensions, contracting_dimensions}.
+  llvm::SmallVector<int64_t, 4> lhs_permutation = Concat<int64_t>(
+      lhs_dot_dimensions_info.batch_dimensions().AxesArray(),
+      lhs_dot_dimensions_info.out_dimensions().AxesArray(),
+      lhs_dot_dimensions_info.contracting_dimensions().AxesArray());
+  llvm::SmallVector<int64_t, 4> lhs_transposed_shape = Concat<int64_t>(
+      lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
+      lhs_dot_dimensions_info.out_dimensions().SizesArray(),
+      lhs_dot_dimensions_info.contracting_dimensions().SizesArray());
+  auto lhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
+      loc,
+      RankedTensorType::get(lhs_transposed_shape, lhs_type.getElementType()),
+      dot_general_op.lhs(),
+      DenseIntElementsAttr::get(
+          RankedTensorType::get({lhs_rank}, rewriter.getI64Type()),
+          lhs_permutation));
+
+  // Transposes rhs shape to be in the order of {batch_dimensions,
+  // contracting_dimensions, out_dimensions}.
+  llvm::SmallVector<int64_t, 4> rhs_permutation = Concat<int64_t>(
+      rhs_dot_dimensions_info.batch_dimensions().AxesArray(),
+      rhs_dot_dimensions_info.contracting_dimensions().AxesArray(),
+      rhs_dot_dimensions_info.out_dimensions().AxesArray());
+  llvm::SmallVector<int64_t, 4> rhs_transposed_shape = Concat<int64_t>(
+      rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
+      rhs_dot_dimensions_info.contracting_dimensions().SizesArray(),
+      rhs_dot_dimensions_info.out_dimensions().SizesArray());
+  auto rhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
+      loc,
+      RankedTensorType::get(rhs_transposed_shape, rhs_type.getElementType()),
+      dot_general_op.rhs(),
+      DenseIntElementsAttr::get(
+          RankedTensorType::get({rhs_rank}, rewriter.getI64Type()),
+          rhs_permutation));
+
+  // Reshapes lhs to flatten out_dimensions and contracting_dimensions.
+  llvm::SmallVector<int64_t, 4> lhs_flattened_shape = Concat<int64_t>(
+      lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
+      llvm::ArrayRef<int64_t>{
+          lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
+      llvm::ArrayRef<int64_t>{
+          lhs_dot_dimensions_info.FlattenedContractingDimensionSize()});
+  auto lhs_flattened = rewriter.create<xla_hlo::ReshapeOp>(
+      loc,
+      RankedTensorType::get(lhs_flattened_shape, lhs_type.getElementType()),
+      lhs_transposed.getResult());
+
+  // Reshapes rhs to flatten contracting_dimensions and out_dimensions.
+  llvm::SmallVector<int64_t, 4> rhs_flattened_shape = Concat<int64_t>(
+      rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
+      llvm::ArrayRef<int64_t>{
+          rhs_dot_dimensions_info.FlattenedContractingDimensionSize()},
+      llvm::ArrayRef<int64_t>{
+          rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
+  auto rhs_flattened = rewriter.create<xla_hlo::ReshapeOp>(
+      loc,
+      RankedTensorType::get(rhs_flattened_shape, rhs_type.getElementType()),
+      rhs_transposed.getResult());
+
+  // Creates matmul op of `lhs_flattened` and `rhs_flattened`.
+  llvm::SmallVector<int64_t, 4> matmul_shape =
+      Concat<int64_t>(lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
+                      llvm::ArrayRef<int64_t>{
+                          lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
+                      llvm::ArrayRef<int64_t>{
+                          rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
+  auto matmul = rewriter.create<TF::BatchMatMulV2Op>(
+      loc, RankedTensorType::get(matmul_shape, result_type.getElementType()),
+      lhs_flattened.getResult(), rhs_flattened.getResult());
+  auto reshaped = rewriter.create<xla_hlo::ReshapeOp>(loc, result_type,
+                                                      matmul.getResult());
+  return reshaped.getResult();
+}
+
 
 class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
  public:
   LegalizeHloToTf() = default;
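The Append/Size/Concat helpers above are plain variadic-template recursion. Here is a self-contained sketch of the same pattern, using std::vector in place of llvm::SmallVector so it compiles without LLVM; the values mirror how lhs_permutation is built for the test case.

```cpp
// Sketch of the variadic Append/Concat pattern from the patch (std::vector
// stands in for llvm::SmallVector; not the patch's actual code).
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Base case: append one range.
template <typename ValueT, typename Range>
void Append(std::vector<ValueT> &values, Range &&range) {
  values.insert(values.end(), range.begin(), range.end());
}

// Recursive case: append the first range, then recurse on the rest.
template <typename ValueT, typename Range, typename... RangeTs>
void Append(std::vector<ValueT> &values, Range &&range, RangeTs &&... ranges) {
  values.insert(values.end(), range.begin(), range.end());
  Append(values, std::forward<RangeTs>(ranges)...);
}

// Concatenates any number of ranges into one vector.
template <typename ValueT, typename... RangeTs>
std::vector<ValueT> Concat(RangeTs &&... ranges) {
  std::vector<ValueT> results;
  Append(results, std::forward<RangeTs>(ranges)...);
  return results;
}

int main() {
  // Mirrors lhs_permutation for the test case:
  // batch {0} + out {3, 4} + contracting {1, 2} -> {0, 3, 4, 1, 2}.
  auto perm = Concat<int64_t>(std::vector<int64_t>{0},
                              std::vector<int64_t>{3, 4},
                              std::vector<int64_t>{1, 2});
  for (auto p : perm) std::printf("%lld ", static_cast<long long>(p));
}
```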
@@ -184,3 +184,10 @@ def ConvertDotOp : NativeCodeCall<"ConvertDotOp($_builder, "
 def : Pat<(HLO_DotOp:$old_value AnyStaticShapeTensor:$lhs,
            AnyStaticShapeTensor:$rhs, $precision_config),
           (ConvertDotOp $old_value)>;
+
+def ConvertDotGeneralOp : NativeCodeCall<"ConvertDotGeneralOp($_builder, "
+                                         "$0.getDefiningOp())">;
+def : Pat<(HLO_DotGeneralOp:$old_value AnyStaticShapeTensor:$lhs,
+           AnyStaticShapeTensor:$rhs, $dot_dimension_numbers,
+           $precision_config),
+          (ConvertDotGeneralOp $old_value)>;
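The new pattern mirrors the existing dot lowering directly above it: an HLO_DotGeneralOp whose operands are statically shaped tensors is matched and handed via NativeCodeCall to the ConvertDotGeneralOp C++ function added in legalize_hlo.cc.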