Add dynamic_broadcast_in_dim -> broadcast_in_dim canonicalization.
Also, address comment from cl/302763113 PiperOrigin-RevId: 302969433 Change-Id: If1ecd554d89e3a381be998522f2ff97ff52ad9f4
This commit is contained in:
parent
35a3591b3e
commit
5d169b40e9
tensorflow/compiler/mlir/xla
@ -598,6 +598,28 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
// If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary
|
||||
// BroadcastInDimOp.
|
||||
class DynamicBroadcastInDimOpNotActuallyDynamic
|
||||
: public OpRewritePattern<DynamicBroadcastInDimOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
auto type = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!type || !type.hasStaticShape()) {
|
||||
return rewriter.notifyMatchFailure(op, "requires static shape");
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<BroadcastInDimOp>(
|
||||
op, op.getType(), op.operand(), op.broadcast_dimensions());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<DynamicBroadcastInDimOpNotActuallyDynamic>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ClampOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -802,6 +802,7 @@ def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
|
||||
|
||||
let results = (outs HLO_Tensor);
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
// Cannot be exported to legacy formats.
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
@ -31,6 +31,14 @@ func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1:
|
||||
return %1 : tensor<1x4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic
|
||||
func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> {
|
||||
// CHECK: %[[RESULT:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
|
||||
%0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>
|
||||
// CHECK: return %[[RESULT]] : tensor<5x4xf32>
|
||||
return %0 : tensor<5x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @complex_expand_fold
|
||||
func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
|
||||
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)
|
||||
|
@ -941,7 +941,7 @@ class ConvertBroadcastToOp : public OpRewritePattern<TF::BroadcastToOp> {
|
||||
auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
|
||||
auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
|
||||
if (!input_type || !output_type) {
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(op, "requires ranked shape");
|
||||
}
|
||||
auto rank_diff = output_type.getRank() - input_type.getRank();
|
||||
// The tf.BroadcastTo op performs "right-aligned" numpy-style broadcasting.
|
||||
|
Loading…
Reference in New Issue
Block a user