[TF/MLIR] Adds legalization rule for xla_hlo.dot_general.

An xla_hlo.dot_general op is converted to a tf.BatchMatMulV2 op. However, we also need to insert transpose ops to put the batch, contracting, and out dimensions in the proper order, and then reshape ops to flatten the contracting and out dimensions, because BatchMatMul does not support multiple contracting (or out) dimensions.
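
For example, in the case exercised by the new test: with lhs tensor<3x2x6x5x1xf32>, rhs tensor<3x2x4x6xf32>, batch dimensions {0}/{0} and contracting dimensions {1, 2}/{1, 3}, the lhs is transposed to 3x5x1x2x6 (batch, out, contracting) and flattened to 3x5x12; the rhs is transposed to 3x2x6x4 (batch, contracting, out) and flattened to 3x12x4; tf.BatchMatMulV2 then produces 3x5x4, which is reshaped back to the expected 3x5x1x4.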

PiperOrigin-RevId: 315293215
Change-Id: Iceb3738025e5c8a730340807b1d6d17c10d7ecc2
Author: A. Unique TensorFlower
Date: 2020-06-08 09:45:23 -07:00
Committed by: TensorFlower Gardener
Parent: b8bd7b3483
Commit: bf1b3d7e70
3 changed files with 229 additions and 0 deletions


@@ -723,6 +723,11 @@ func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32>
return %0 : tensor<3x8x8x16xf32>
}
func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32>
return %0 : tensor<3x5x1x4xf32>
}
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK-LABEL: func @biasAdd_NHWC(
@@ -1596,3 +1601,14 @@ func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32>
// CHECK: [[VAL_402:%.*]] = "tf.BroadcastTo"([[VAL_400]], [[VAL_401]]) : (tensor<3x1x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32>
// CHECK: return [[VAL_402]] : tensor<3x8x8x16xf32>
// CHECK: }
// CHECK-LABEL: func @convert_dot_general(
// CHECK-SAME: [[VAL_396:%.*]]: tensor<3x2x6x5x1xf32>, [[VAL_397:%.*]]: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
// CHECK: [[VAL_398:%.*]] = "tf.Transpose"([[VAL_396]], {{.*}}) : (tensor<3x2x6x5x1xf32>, tensor<5xi64>) -> tensor<3x5x1x2x6xf32>
// CHECK: [[VAL_399:%.*]] = "tf.Transpose"([[VAL_397]], {{.*}}) : (tensor<3x2x4x6xf32>, tensor<4xi64>) -> tensor<3x2x6x4xf32>
// CHECK: [[VAL_400:%.*]] = "tf.Reshape"([[VAL_398]], {{.*}}) : (tensor<3x5x1x2x6xf32>, tensor<3xi64>) -> tensor<3x5x12xf32>
// CHECK: [[VAL_401:%.*]] = "tf.Reshape"([[VAL_399]], {{.*}}) : (tensor<3x2x6x4xf32>, tensor<3xi64>) -> tensor<3x12x4xf32>
// CHECK: [[VAL_402:%.*]] = "tf.BatchMatMulV2"([[VAL_400]], [[VAL_401]]) {adj_x = false, adj_y = false} : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32>
// CHECK: [[VAL_403:%.*]] = "tf.Reshape"([[VAL_402]], {{.*}}) : (tensor<3x5x4xf32>, tensor<4xi64>) -> tensor<3x5x1x4xf32>
// CHECK: return [[VAL_403]] : tensor<3x5x1x4xf32>
// CHECK: }


@@ -15,10 +15,15 @@ limitations under the License.
// This file implements logic for legalizing HLO to TensorFlow.
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
@@ -40,6 +45,8 @@ namespace mlir {
namespace TF {
namespace {
using xla_hlo::DotDimensionNumbers;
class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -75,6 +82,205 @@ class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
};
};
// Appends all elements in `range` to `values`.
template <typename ValueT, typename Range>
void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range) {
values.insert(values.end(), range.begin(), range.end());
}
// Appends all elements in `range` to `values`.
template <typename ValueT, typename Range, typename... RangeTs>
void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range,
RangeTs &&... ranges) {
values.insert(values.end(), range.begin(), range.end());
Append(values, ranges...);
}
// Returns the number of elements in `range`.
template <typename Range>
size_t Size(Range &&range) {
return range.size();
}
// Returns the total number of elements in a variadic number of `ranges`.
template <typename Range, typename... RangeTs>
size_t Size(Range &&range, RangeTs &&... ranges) {
return range.size() + Size(std::forward<RangeTs>(ranges)...);
}
// Concatenates all elements in `ranges` and returns the result in a small
// vector.
template <typename ValueT, typename... RangeTs>
llvm::SmallVector<ValueT, 4> Concat(RangeTs &&... ranges) {
llvm::SmallVector<ValueT, 4> results;
results.reserve(Size(std::forward<RangeTs>(ranges)...));
Append(results, std::forward<RangeTs>(ranges)...);
return results;
}
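// Illustrative usage (a sketch, not part of this change): Concat splices
// several dimension lists into one vector, e.g. building the lhs permutation
// for the test case above from batch axes {0}, out axes {3, 4}, and
// contracting axes {1, 2}:
//   llvm::SmallVector<int64_t, 4> permutation = Concat<int64_t>(
//       llvm::ArrayRef<int64_t>{0}, llvm::ArrayRef<int64_t>{3, 4},
//       llvm::ArrayRef<int64_t>{1, 2});
//   // permutation == {0, 3, 4, 1, 2}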
// A struct to hold axes and sizes for a set of dimensions.
struct DimensionSetVector {
llvm::ArrayRef<int64_t> AxesArray() const { return axes.getArrayRef(); }
llvm::ArrayRef<int64_t> SizesArray() const { return sizes; }
llvm::SmallSetVector<int64_t, 4> axes;
// Sizes may repeat (two dimensions can have the same extent), so they are
// kept in a plain vector rather than a set.
llvm::SmallVector<int64_t, 4> sizes;
};
// A struct to hold information about dimensions of dot_general operands.
class DotDimensionsInfo {
public:
DotDimensionsInfo(ShapedType type, DenseIntElementsAttr batch_dimensions,
DenseIntElementsAttr contracting_dimensions) {
const int rank = type.getRank();
for (const int dim : batch_dimensions.getValues<int64_t>()) {
batch_dimensions_.axes.insert(dim);
batch_dimensions_.sizes.push_back(type.getDimSize(dim));
}
for (const int dim : contracting_dimensions.getValues<int64_t>()) {
contracting_dimensions_.axes.insert(dim);
contracting_dimensions_.sizes.push_back(type.getDimSize(dim));
}
for (int dim = 0; dim < rank; ++dim) {
if (contracting_dimensions_.axes.count(dim) > 0 ||
batch_dimensions_.axes.count(dim) > 0) {
continue;
}
out_dimensions_.axes.insert(dim);
out_dimensions_.sizes.push_back(type.getDimSize(dim));
}
}
const DimensionSetVector &batch_dimensions() const {
return batch_dimensions_;
}
const DimensionSetVector &contracting_dimensions() const {
return contracting_dimensions_;
}
// Out dimensions are any dimensions that are neither batch nor contracting
// dimensions; they are propagated to the output shape.
const DimensionSetVector &out_dimensions() const { return out_dimensions_; }
// Returns the total dimension size after flattening all contracting
// dimensions.
int64_t FlattenedContractingDimensionSize() const {
return std::accumulate(contracting_dimensions_.sizes.begin(),
contracting_dimensions_.sizes.end(), int64_t{1},
std::multiplies<int64_t>());
}
// Returns the total dimension size after flattening all out dimensions.
int64_t FlattenedOutDimensionSize() const {
return std::accumulate(out_dimensions_.sizes.begin(),
out_dimensions_.sizes.end(), int64_t{1},
std::multiplies<int64_t>());
}
private:
DimensionSetVector batch_dimensions_;
DimensionSetVector contracting_dimensions_;
// Out dimensions are any dimensions that are neither batch nor contracting
// dimensions; they are propagated to the output shape.
DimensionSetVector out_dimensions_;
};
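// Worked example (from the test added in this change): for an lhs of type
// tensor<3x2x6x5x1xf32> with batch_dimensions = {0} and
// contracting_dimensions = {1, 2}, DotDimensionsInfo yields
//   batch:       axes {0},    sizes {3}
//   contracting: axes {1, 2}, sizes {2, 6}
//   out:         axes {3, 4}, sizes {5, 1}
// so FlattenedContractingDimensionSize() == 12 and
// FlattenedOutDimensionSize() == 5.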
// Converts xla_hlo.dot_general to tf.BatchMatMulV2. Transpose and Reshape ops
// are inserted first to turn the operands into a well-formed batched matrix
// multiply.
Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
auto dot_general_op = cast<xla_hlo::DotGeneralOp>(old_op);
auto lhs_type = dot_general_op.lhs().getType().cast<ShapedType>();
auto rhs_type = dot_general_op.rhs().getType().cast<ShapedType>();
auto result_type = dot_general_op.getResult().getType().cast<ShapedType>();
DotDimensionNumbers dot_dimension_numbers =
dot_general_op.dot_dimension_numbers();
mlir::Location loc = dot_general_op.getLoc();
const int lhs_rank = lhs_type.getRank();
const int rhs_rank = rhs_type.getRank();
// Collects lhs and rhs dimensions information.
DotDimensionsInfo lhs_dot_dimensions_info(
lhs_type, dot_dimension_numbers.lhs_batching_dimensions(),
dot_dimension_numbers.lhs_contracting_dimensions());
DotDimensionsInfo rhs_dot_dimensions_info(
rhs_type, dot_dimension_numbers.rhs_batching_dimensions(),
dot_dimension_numbers.rhs_contracting_dimensions());
// Transposes lhs to the dimension order {batch_dimensions, out_dimensions,
// contracting_dimensions}.
llvm::SmallVector<int64_t, 4> lhs_permutation = Concat<int64_t>(
lhs_dot_dimensions_info.batch_dimensions().AxesArray(),
lhs_dot_dimensions_info.out_dimensions().AxesArray(),
lhs_dot_dimensions_info.contracting_dimensions().AxesArray());
llvm::SmallVector<int64_t, 4> lhs_transposed_shape = Concat<int64_t>(
lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
lhs_dot_dimensions_info.out_dimensions().SizesArray(),
lhs_dot_dimensions_info.contracting_dimensions().SizesArray());
auto lhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
loc,
RankedTensorType::get(lhs_transposed_shape, lhs_type.getElementType()),
dot_general_op.lhs(),
DenseIntElementsAttr::get(
RankedTensorType::get({lhs_rank}, rewriter.getI64Type()),
lhs_permutation));
// Transposes rhs to the dimension order {batch_dimensions,
// contracting_dimensions, out_dimensions}.
llvm::SmallVector<int64_t, 4> rhs_permutation = Concat<int64_t>(
rhs_dot_dimensions_info.batch_dimensions().AxesArray(),
rhs_dot_dimensions_info.contracting_dimensions().AxesArray(),
rhs_dot_dimensions_info.out_dimensions().AxesArray());
llvm::SmallVector<int64_t, 4> rhs_transposed_shape = Concat<int64_t>(
rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
rhs_dot_dimensions_info.contracting_dimensions().SizesArray(),
rhs_dot_dimensions_info.out_dimensions().SizesArray());
auto rhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
loc,
RankedTensorType::get(rhs_transposed_shape, rhs_type.getElementType()),
dot_general_op.rhs(),
DenseIntElementsAttr::get(
RankedTensorType::get({rhs_rank}, rewriter.getI64Type()),
rhs_permutation));
// Reshapes lhs to flatten out_dimensions and contracting_dimensions.
llvm::SmallVector<int64_t, 4> lhs_flattened_shape = Concat<int64_t>(
lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
llvm::ArrayRef<int64_t>{
lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
llvm::ArrayRef<int64_t>{
lhs_dot_dimensions_info.FlattenedContractingDimensionSize()});
auto lhs_flattened = rewriter.create<xla_hlo::ReshapeOp>(
loc,
RankedTensorType::get(lhs_flattened_shape, lhs_type.getElementType()),
lhs_transposed.getResult());
// Reshapes rhs to flatten out_dimensions and contracting_dimensions.
llvm::SmallVector<int64_t, 4> rhs_flattened_shape = Concat<int64_t>(
rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
llvm::ArrayRef<int64_t>{
rhs_dot_dimensions_info.FlattenedContractingDimensionSize()},
llvm::ArrayRef<int64_t>{
rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
auto rhs_flattened = rewriter.create<xla_hlo::ReshapeOp>(
loc,
RankedTensorType::get(rhs_flattened_shape, rhs_type.getElementType()),
rhs_transposed.getResult());
// Creates a matmul op over `lhs_flattened` and `rhs_flattened`.
llvm::SmallVector<int64_t, 4> matmul_shape =
Concat<int64_t>(lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
llvm::ArrayRef<int64_t>{
lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
llvm::ArrayRef<int64_t>{
rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
auto matmul = rewriter.create<TF::BatchMatMulV2Op>(
loc, RankedTensorType::get(matmul_shape, result_type.getElementType()),
lhs_flattened.getResult(), rhs_flattened.getResult());
auto reshaped =
rewriter.create<xla_hlo::ReshapeOp>(loc, result_type, matmul.getResult());
return reshaped.getResult();
}
class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
public:
LegalizeHloToTf() = default;


@@ -184,3 +184,10 @@ def ConvertDotOp : NativeCodeCall<"ConvertDotOp($_builder, "
def : Pat<(HLO_DotOp:$old_value AnyStaticShapeTensor:$lhs,
AnyStaticShapeTensor:$rhs, $precision_config),
(ConvertDotOp $old_value)>;
def ConvertDotGeneralOp : NativeCodeCall<"ConvertDotGeneralOp($_builder, "
"$0.getDefiningOp())">;
def : Pat<(HLO_DotGeneralOp:$old_value AnyStaticShapeTensor:$lhs,
AnyStaticShapeTensor:$rhs, $dot_dimension_numbers,
$precision_config),
(ConvertDotGeneralOp $old_value)>;
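// Roughly (a sketch; the code tablegen actually emits differs in details),
// the pattern above expands to a C++ RewritePattern along these lines:
//
//   struct GeneratedDotGeneralPattern : public RewritePattern {
//     LogicalResult matchAndRewrite(Operation *op,
//                                   PatternRewriter &rewriter) const {
//       // Matches a static-shaped xla_hlo.dot_general and replaces it with
//       // the tree of TF ops built by ConvertDotGeneralOp.
//       rewriter.replaceOp(op, ConvertDotGeneralOp(rewriter, op));
//       return success();
//     }
//   };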