Add dynamic_broadcast_in_dim -> broadcast_in_dim canonicalization.

Also, address comment from cl/302763113

PiperOrigin-RevId: 302969433
Change-Id: If1ecd554d89e3a381be998522f2ff97ff52ad9f4
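In short: xla_hlo.dynamic_broadcast_in_dim takes its output shape as a runtime operand, but when the result type is already fully static that operand carries no extra information. A minimal before/after sketch of the new canonicalization (the IR mirrors the new test case below; the %shape operand name is illustrative):

    // Before: the shape is passed at runtime even though the result type
    // tensor<5x4xf32> is fully static.
    %0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>

    // After: the redundant shape operand is dropped.
    %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>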
Sean Silva, 2020-03-25 14:24:54 -07:00 (committed by TensorFlower Gardener)
parent 35a3591b3e
commit 5d169b40e9
4 changed files with 32 additions and 1 deletion
All changed files are under tensorflow/compiler/mlir/xla.


@@ -598,6 +598,28 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) {
   return success();
 }
+
+// If a DynamicBroadcastInDimOp is not actually dynamic, use an ordinary
+// BroadcastInDimOp.
+class DynamicBroadcastInDimOpNotActuallyDynamic
+    : public OpRewritePattern<DynamicBroadcastInDimOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
+                                PatternRewriter& rewriter) const override {
+    auto type = op.getType().dyn_cast<RankedTensorType>();
+    if (!type || !type.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(op, "requires static shape");
+    }
+    rewriter.replaceOpWithNewOp<BroadcastInDimOp>(
+        op, op.getType(), op.operand(), op.broadcast_dimensions());
+    return success();
+  }
+};
+
+void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
+    OwningRewritePatternList& results, MLIRContext* context) {
+  results.insert<DynamicBroadcastInDimOpNotActuallyDynamic>(context);
+}
 
 //===----------------------------------------------------------------------===//
 // ClampOp
 //===----------------------------------------------------------------------===//
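The hasStaticShape() guard is what makes this rewrite safe: when the result type still has dynamic dimensions, the pattern reports "requires static shape" and leaves the op alone. An illustrative case the pattern deliberately skips (hypothetical IR, not part of this change):

    // The ? dimension can only be resolved from %shape at runtime, so
    // rewriting to the static broadcast_in_dim would lose information.
    %0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<2xi64>) -> tensor<?x4xf32>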


@@ -802,6 +802,7 @@ def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
   let results = (outs HLO_Tensor);
 
+  let hasCanonicalizer = 1;
   // Cannot be exported to legacy formats.
   let hasCustomHLOConverter = 1;
 }


@@ -31,6 +31,14 @@ func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1:
   return %1 : tensor<1x4xi32>
 }
 
+// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic
+func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> {
+  // CHECK: %[[RESULT:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
+  %0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>
+  // CHECK: return %[[RESULT]] : tensor<5x4xf32>
+  return %0 : tensor<5x4xf32>
+}
+
 // CHECK-LABEL: @complex_expand_fold
 func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
   %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)


@@ -941,7 +941,7 @@ class ConvertBroadcastToOp : public OpRewritePattern<TF::BroadcastToOp> {
     auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
     auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
     if (!input_type || !output_type) {
-      return failure();
+      return rewriter.notifyMatchFailure(op, "requires ranked shape");
     }
     auto rank_diff = output_type.getRank() - input_type.getRank();
     // The tf.BroadcastTo op performs "right-aligned" numpy-style broadcasting.
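This is the cl/302763113 follow-up: rewriter.notifyMatchFailure(...) is equivalent to a bare failure() as far as the pattern driver is concerned, but it records a human-readable reason for debugging. A hypothetical input that would now produce the "requires ranked shape" diagnostic (illustrative IR, not from this change):

    // The unranked input type makes input_type.dyn_cast<RankedTensorType>()
    // return null, so the pattern fails to match and logs the reason.
    %0 = "tf.BroadcastTo"(%arg0, %shape) : (tensor<*xf32>, tensor<2xi64>) -> tensor<2x4xf32>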