From bf1b3d7e70e39672ab35ae7e04a15fa082d4b8e1 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 8 Jun 2020 09:45:23 -0700
Subject: [PATCH] [TF/MLIR] Adds legalization rule for xla_hlo.dot_general.

An xla_hlo.dot_general op is converted to a tf.BatchMatMulV2 op. Because
BatchMatMul supports only a single contracting dimension on each operand,
we first insert transpose ops that group each operand's batch, out, and
contracting dimensions, and then reshape ops that flatten the contracting
and out dimensions.
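
For example (shapes taken from the test added below; the op sequence is
shown schematically, with illustrative value names), multiplying
lhs : tensor<3x2x6x5x1xf32> (batch dims {0}, contracting dims {1, 2}) by
rhs : tensor<3x2x4x6xf32> (batch dims {0}, contracting dims {1, 3})
lowers roughly to:

  %lhs_t = tf.Transpose(%lhs), perm [0, 3, 4, 1, 2]  // -> tensor<3x5x1x2x6xf32>
  %rhs_t = tf.Transpose(%rhs), perm [0, 1, 3, 2]     // -> tensor<3x2x6x4xf32>
  %lhs_r = tf.Reshape(%lhs_t)                        // -> tensor<3x5x12xf32> (5*1, 2*6)
  %rhs_r = tf.Reshape(%rhs_t)                        // -> tensor<3x12x4xf32>
  %mm    = tf.BatchMatMulV2(%lhs_r, %rhs_r)          // -> tensor<3x5x4xf32>
  %out   = tf.Reshape(%mm)                           // -> tensor<3x5x1x4xf32>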

PiperOrigin-RevId: 315293215
Change-Id: Iceb3738025e5c8a730340807b1d6d17c10d7ecc2
---
 .../mlir/tensorflow/tests/legalize_hlo.mlir   |  16 ++
 .../tensorflow/transforms/legalize_hlo.cc     | 206 ++++++++++++++++++
 .../transforms/legalize_hlo_patterns.td       |   7 +
 3 files changed, 229 insertions(+)

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
index 00e35460f20..2b4f88a3524 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
@@ -723,6 +723,11 @@ func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> {
   return %0 : tensor<3x8x8x16xf32>
 }
 
+func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
+  %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32>
+  return %0 : tensor<3x5x1x4xf32>
+}
+
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 
 // CHECK-LABEL: func @biasAdd_NHWC(
@@ -1596,3 +1601,14 @@
 // CHECK: [[VAL_402:%.*]] = "tf.BroadcastTo"([[VAL_400]], [[VAL_401]]) : (tensor<3x1x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32>
 // CHECK: return [[VAL_402]] : tensor<3x8x8x16xf32>
 // CHECK: }
+
+// CHECK-LABEL: func @convert_dot_general(
+// CHECK-SAME: [[VAL_396:%.*]]: tensor<3x2x6x5x1xf32>, [[VAL_397:%.*]]: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
+// CHECK: [[VAL_398:%.*]] = "tf.Transpose"([[VAL_396]], {{.*}}) : (tensor<3x2x6x5x1xf32>, tensor<5xi64>) -> tensor<3x5x1x2x6xf32>
+// CHECK: [[VAL_399:%.*]] = "tf.Transpose"([[VAL_397]], {{.*}}) : (tensor<3x2x4x6xf32>, tensor<4xi64>) -> tensor<3x2x6x4xf32>
+// CHECK: [[VAL_400:%.*]] = "tf.Reshape"([[VAL_398]], {{.*}}) : (tensor<3x5x1x2x6xf32>, tensor<3xi64>) -> tensor<3x5x12xf32>
+// CHECK: [[VAL_401:%.*]] = "tf.Reshape"([[VAL_399]], {{.*}}) : (tensor<3x2x6x4xf32>, tensor<3xi64>) -> tensor<3x12x4xf32>
+// CHECK: [[VAL_402:%.*]] = "tf.BatchMatMulV2"([[VAL_400]], [[VAL_401]]) {adj_x = false, adj_y = false} : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32>
+// CHECK: [[VAL_403:%.*]] = "tf.Reshape"([[VAL_402]], {{.*}}) : (tensor<3x5x4xf32>, tensor<4xi64>) -> tensor<3x5x1x4xf32>
+// CHECK: return [[VAL_403]] : tensor<3x5x1x4xf32>
+// CHECK: }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
index f6c00e8cb82..267819f6c9a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
@@ -15,10 +15,15 @@ limitations under the License.
 
 // This file implements logic for legalizing HLO to TensorFlow.
 
+#include <cstdint>
+#include <functional>
 #include <memory>
+#include <numeric>
 #include <vector>
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
@@ -40,6 +45,8 @@ namespace mlir {
 namespace TF {
 namespace {
 
+using xla_hlo::DotDimensionNumbers;
+
 class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
  public:
   using OpConversionPattern::OpConversionPattern;
@@ -75,6 +82,205 @@ class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
   };
 };
 
+// Appends all elements in `range` to `values`.
+template <typename ValueT, typename Range>
+void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range) {
+  values.insert(values.end(), range.begin(), range.end());
+}
+
+// Appends all elements in `range` and `ranges` to `values`.
+template <typename ValueT, typename Range, typename... RangeTs>
+void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range,
+            RangeTs &&... ranges) {
+  values.insert(values.end(), range.begin(), range.end());
+  Append(values, ranges...);
+}
+
+// Returns the number of elements in `range`.
+template <typename Range>
+size_t Size(Range &&range) {
+  return range.size();
+}
+
+// Returns the total number of elements in a variadic number of `ranges`.
+template <typename Range, typename... RangeTs>
+size_t Size(Range &&range, RangeTs &&... ranges) {
+  return range.size() + Size(std::forward<RangeTs>(ranges)...);
+}
+
+// Concatenates all elements in `ranges` and returns a small vector as a
+// result.
+template <typename ValueT, typename... RangeTs>
+llvm::SmallVector<ValueT, 4> Concat(RangeTs &&... ranges) {
+  llvm::SmallVector<ValueT, 4> results;
+  results.reserve(Size(std::forward<RangeTs>(ranges)...));
+  Append(results, std::forward<RangeTs>(ranges)...);
+  return results;
+}
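+
+// For example (an illustrative call; `batch_axes`, `out_axes`, and
+// `contracting_axes` stand for ArrayRefs of axis indices):
+//
+//   Concat<int64_t>(batch_axes, out_axes, contracting_axes)
+//
+// returns a single SmallVector with the three ranges appended in order.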
+
+// A struct to hold axes and sizes for a set of dimensions.
+struct DimensionSetVector {
+  llvm::ArrayRef<int64_t> AxesArray() const { return axes.getArrayRef(); }
+  llvm::ArrayRef<int64_t> SizesArray() const { return sizes; }
+
+  llvm::SmallSetVector<int64_t, 4> axes;
+  // `sizes` is deliberately a plain vector rather than a set: two dimensions
+  // may have the same size, and both entries must be kept.
+  llvm::SmallVector<int64_t, 4> sizes;
+};
+
+// A class to hold information about dimensions of dot_general operands.
+class DotDimensionsInfo {
+ public:
+  DotDimensionsInfo(ShapedType type, DenseIntElementsAttr batch_dimensions,
+                    DenseIntElementsAttr contracting_dimensions) {
+    const int rank = type.getRank();
+    for (const int64_t dim : batch_dimensions.getValues<int64_t>()) {
+      batch_dimensions_.axes.insert(dim);
+      batch_dimensions_.sizes.push_back(type.getDimSize(dim));
+    }
+
+    for (const int64_t dim : contracting_dimensions.getValues<int64_t>()) {
+      contracting_dimensions_.axes.insert(dim);
+      contracting_dimensions_.sizes.push_back(type.getDimSize(dim));
+    }
+
+    for (int64_t dim = 0; dim < rank; ++dim) {
+      if (contracting_dimensions_.axes.count(dim) > 0 ||
+          batch_dimensions_.axes.count(dim) > 0) {
+        continue;
+      }
+      out_dimensions_.axes.insert(dim);
+      out_dimensions_.sizes.push_back(type.getDimSize(dim));
+    }
+  }
+
+  const DimensionSetVector &batch_dimensions() const {
+    return batch_dimensions_;
+  }
+  const DimensionSetVector &contracting_dimensions() const {
+    return contracting_dimensions_;
+  }
+  // Out dimensions are any dimensions that are neither batch nor contracting
+  // dimensions, hence they are propagated to the output shape.
+  const DimensionSetVector &out_dimensions() const { return out_dimensions_; }
+
+  // Returns the total dimension size after flattening all contracting
+  // dimensions.
+  int64_t FlattenedContractingDimensionSize() const {
+    return std::accumulate(contracting_dimensions_.sizes.begin(),
+                           contracting_dimensions_.sizes.end(), int64_t{1},
+                           std::multiplies<int64_t>());
+  }
+
+  // Returns the total dimension size after flattening all out dimensions.
+  int64_t FlattenedOutDimensionSize() const {
+    return std::accumulate(out_dimensions_.sizes.begin(),
+                           out_dimensions_.sizes.end(), int64_t{1},
+                           std::multiplies<int64_t>());
+  }
+
+ private:
+  DimensionSetVector batch_dimensions_;
+  DimensionSetVector contracting_dimensions_;
+  // Out dimensions are any dimensions that are neither batch nor contracting
+  // dimensions, hence they are propagated to the output shape.
+  DimensionSetVector out_dimensions_;
+};
+
+// Converts xla_hlo.dot_general to tf.BatchMatMulV2. Transpose and reshape ops
+// will also be inserted to convert to a well-formed matrix multiply.
+Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
+  auto dot_general_op = cast<xla_hlo::DotGeneralOp>(old_op);
+  auto lhs_type = dot_general_op.lhs().getType().cast<ShapedType>();
+  auto rhs_type = dot_general_op.rhs().getType().cast<ShapedType>();
+  auto result_type = dot_general_op.getResult().getType().cast<ShapedType>();
+  DotDimensionNumbers dot_dimension_numbers =
+      dot_general_op.dot_dimension_numbers();
+  mlir::Location loc = dot_general_op.getLoc();
+  const int lhs_rank = lhs_type.getRank();
+  const int rhs_rank = rhs_type.getRank();
+
+  // Collects lhs and rhs dimensions information.
+  DotDimensionsInfo lhs_dot_dimensions_info(
+      lhs_type, dot_dimension_numbers.lhs_batching_dimensions(),
+      dot_dimension_numbers.lhs_contracting_dimensions());
+  DotDimensionsInfo rhs_dot_dimensions_info(
+      rhs_type, dot_dimension_numbers.rhs_batching_dimensions(),
+      dot_dimension_numbers.rhs_contracting_dimensions());
+
+  // Transposes lhs shape to be in the order of {batch_dimensions,
+  // out_dimensions, contracting_dimensions}.
+  llvm::SmallVector<int64_t, 4> lhs_permutation = Concat<int64_t>(
+      lhs_dot_dimensions_info.batch_dimensions().AxesArray(),
+      lhs_dot_dimensions_info.out_dimensions().AxesArray(),
+      lhs_dot_dimensions_info.contracting_dimensions().AxesArray());
+  llvm::SmallVector<int64_t, 4> lhs_transposed_shape = Concat<int64_t>(
+      lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
+      lhs_dot_dimensions_info.out_dimensions().SizesArray(),
+      lhs_dot_dimensions_info.contracting_dimensions().SizesArray());
+  auto lhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
+      loc,
+      RankedTensorType::get(lhs_transposed_shape, lhs_type.getElementType()),
+      dot_general_op.lhs(),
+      DenseIntElementsAttr::get(
+          RankedTensorType::get({lhs_rank}, rewriter.getI64Type()),
+          lhs_permutation));
+
+  // Transposes rhs shape to be in the order of {batch_dimensions,
+  // contracting_dimensions, out_dimensions}.
+  llvm::SmallVector<int64_t, 4> rhs_permutation = Concat<int64_t>(
+      rhs_dot_dimensions_info.batch_dimensions().AxesArray(),
+      rhs_dot_dimensions_info.contracting_dimensions().AxesArray(),
+      rhs_dot_dimensions_info.out_dimensions().AxesArray());
+  llvm::SmallVector<int64_t, 4> rhs_transposed_shape = Concat<int64_t>(
+      rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
+      rhs_dot_dimensions_info.contracting_dimensions().SizesArray(),
+      rhs_dot_dimensions_info.out_dimensions().SizesArray());
+  auto rhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
+      loc,
+      RankedTensorType::get(rhs_transposed_shape, rhs_type.getElementType()),
+      dot_general_op.rhs(),
+      DenseIntElementsAttr::get(
+          RankedTensorType::get({rhs_rank}, rewriter.getI64Type()),
+          rhs_permutation));
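+
+  // For the case in the new test: lhs is tensor<3x2x6x5x1xf32> with batch
+  // dimensions {0} and contracting dimensions {1, 2}, so lhs_permutation is
+  // [0, 3, 4, 1, 2] and lhs_transposed_shape is [3, 5, 1, 2, 6]; rhs is
+  // tensor<3x2x4x6xf32> with batch dimensions {0} and contracting dimensions
+  // {1, 3}, so rhs_permutation is [0, 1, 3, 2] and rhs_transposed_shape is
+  // [3, 2, 6, 4].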
+
+  // Reshapes lhs to flatten out_dimensions and contracting_dimensions.
+  llvm::SmallVector<int64_t, 4> lhs_flattened_shape = Concat<int64_t>(
+      lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
+      llvm::ArrayRef<int64_t>{
+          lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
+      llvm::ArrayRef<int64_t>{
+          lhs_dot_dimensions_info.FlattenedContractingDimensionSize()});
+  auto lhs_flattened = rewriter.create<xla_hlo::ReshapeOp>(
+      loc,
+      RankedTensorType::get(lhs_flattened_shape, lhs_type.getElementType()),
+      lhs_transposed.getResult());
+
+  // Reshapes rhs to flatten contracting_dimensions and out_dimensions.
+  llvm::SmallVector<int64_t, 4> rhs_flattened_shape = Concat<int64_t>(
+      rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
+      llvm::ArrayRef<int64_t>{
+          rhs_dot_dimensions_info.FlattenedContractingDimensionSize()},
+      llvm::ArrayRef<int64_t>{
+          rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
+  auto rhs_flattened = rewriter.create<xla_hlo::ReshapeOp>(
+      loc,
+      RankedTensorType::get(rhs_flattened_shape, rhs_type.getElementType()),
+      rhs_transposed.getResult());
+
+  // Creates a matmul op of `lhs_flattened` and `rhs_flattened`.
+  llvm::SmallVector<int64_t, 4> matmul_shape = Concat<int64_t>(
+      lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
+      llvm::ArrayRef<int64_t>{
+          lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
+      llvm::ArrayRef<int64_t>{
+          rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
+  auto matmul = rewriter.create<TF::BatchMatMulV2Op>(
+      loc, RankedTensorType::get(matmul_shape, result_type.getElementType()),
+      lhs_flattened.getResult(), rhs_flattened.getResult());
+
+  // Reshapes the matmul result back to the original result shape.
+  auto reshaped =
+      rewriter.create<xla_hlo::ReshapeOp>(loc, result_type, matmul.getResult());
+  return reshaped.getResult();
+}
+
 class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
  public:
   LegalizeHloToTf() = default;
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td
index df78aa97f01..3e910cd9512 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td
@@ -184,3 +184,10 @@ def ConvertDotOp : NativeCodeCall<"ConvertDotOp($_builder, "
 def : Pat<(HLO_DotOp:$old_value AnyStaticShapeTensor:$lhs,
            AnyStaticShapeTensor:$rhs, $precision_config),
           (ConvertDotOp $old_value)>;
+
+def ConvertDotGeneralOp : NativeCodeCall<"ConvertDotGeneralOp($_builder, "
+                                         "$0.getDefiningOp())">;
+def : Pat<(HLO_DotGeneralOp:$old_value AnyStaticShapeTensor:$lhs,
+           AnyStaticShapeTensor:$rhs, $dot_dimension_numbers,
+           $precision_config),
+          (ConvertDotGeneralOp $old_value)>;
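+
+// Note that, like the xla_hlo.dot pattern above, this pattern matches only
+// statically shaped operands (AnyStaticShapeTensor): a dot_general whose
+// operands have any dynamic dimension (e.g., a hypothetical
+// tensor<?x2x6xf32> lhs) is left unconverted by this pass.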