Let UnfuseBatchNorm handle dynamic shapes.
Also add an empty dimensions vector when broadcasting a scalar value. This is needed to legalize the broadcast further down, and it follows the pre-existing convention for how broadcasts of scalars are represented.

PiperOrigin-RevId: 299297553
Change-Id: I51a64f1d7b4ce3349d3e02e743faf66e29aaead1
parent 0dbec67fe7
commit 53ed9a3339
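To illustrate the scalar-broadcast convention mentioned above: the unfused epsilon constant is now broadcast with an explicit, empty broadcast_dimensions attribute. A minimal sketch of the emitted IR in the static-shape case (the epsilon value and the 256-element result shape are placeholders; the dynamic case uses xla_hlo.dynamic_broadcast_in_dim with a computed shape tensor instead):

  // Scalar epsilon materialized as a constant, then broadcast to the
  // feature shape; a scalar contributes no dimensions, hence the empty
  // broadcast_dimensions attribute.
  %eps = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
  %eps_bcast = "xla_hlo.broadcast_in_dim"(%eps) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
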
@@ -354,7 +354,9 @@ cc_library(
     ],
     deps = [
         ":hlo",
+        "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Transforms",
     ],
 )

@@ -11,7 +11,7 @@ func @batchNormInference_2D_inner_features(
     %mean: tensor<256xf32>, %variance: tensor<256xf32>)
     -> (tensor<4x256xf32>) {
   // CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.001000e-05> : tensor<f32>
-  // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) : (tensor<f32>) -> tensor<256xf32>
+  // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
   // CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
   // CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
   // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>

@@ -92,3 +92,46 @@ func @batchNormInference_f16_overflow(
       tensor<256xf16>) -> tensor<4x256xf16>
   return %0 : tensor<4x256xf16>
 }
+
+// -----
+// CHECK-LABEL: @batchNormInference_dynamic_shape
+// Validate that dynamic shapes are handled properly.
+// CHECK-SAME: %[[X:[^:[:space:]]+]]
+// CHECK-SAME: %[[SCALE:[^:[:space:]]+]]
+// CHECK-SAME: %[[OFFSET:[^:[:space:]]+]]
+// CHECK-SAME: %[[MEAN:[^:[:space:]]+]]
+// CHECK-SAME: %[[VARIANCE:[^:[:space:]]+]]
+func @batchNormInference_dynamic_shape(
+    %x: tensor<?x?x?x?xf32>, %scale: tensor<?xf32>, %offset: tensor<?xf32>,
+    %mean: tensor<?xf32>, %variance: tensor<?xf32>)
+    -> tensor<?x?x?x?xf32> {
+  // CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
+  // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], 0 : tensor<?xf32>
+  // CHECK-DAG: %[[INDEX_CAST:.+]] = index_cast %[[DIM]] : index to i32
+  // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[INDEX_CAST]]) : (i32) -> tensor<1xi32>
+  // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
+  // CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
+  // CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], 0 : tensor<?x?x?x?xf32>
+  // CHECK-DAG: %[[INPUT_INDEX_CAST_0:.+]] = index_cast %[[INPUT_DIM_0]] : index to i32
+  // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], 1 : tensor<?x?x?x?xf32>
+  // CHECK-DAG: %[[INPUT_INDEX_CAST_1:.+]] = index_cast %[[INPUT_DIM_1]] : index to i32
+  // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], 2 : tensor<?x?x?x?xf32>
+  // CHECK-DAG: %[[INPUT_INDEX_CAST_2:.+]] = index_cast %[[INPUT_DIM_2]] : index to i32
+  // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], 3 : tensor<?x?x?x?xf32>
+  // CHECK-DAG: %[[INPUT_INDEX_CAST_3:.+]] = index_cast %[[INPUT_DIM_3]] : index to i32
+  // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[INPUT_INDEX_CAST_0]], %[[INPUT_INDEX_CAST_1]], %[[INPUT_INDEX_CAST_2]], %[[INPUT_INDEX_CAST_3]]) : (i32, i32, i32, i32) -> tensor<4xi32>
+  // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
+  // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
+  // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
+  // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
+  // CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.sub %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32>
+  // CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.mul %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32>
+  // CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.div %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32>
+  // CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<?x?x?x?xf32>
+  %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
+      {epsilon = 0.001 : f32, feature_index = 1 : i64} :
+      (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
+      tensor<?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/MLIRContext.h"  // TF:llvm-project

@@ -28,20 +30,47 @@ namespace xla_hlo {
 
 namespace {
 
-// Broadcasts the 1D value tensor to rank.
-Value broadcastToFeatureDim(Location loc, Type result_type, Value value_1d,
+// Broadcasts the 1D value tensor 'value_1d' to the shape of 'result_type'. If
+// 'shape_value' is initialized, creates a dynamic broadcast, otherwise creates
+// a static broadcast.
+Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
+                            Value value_1d, Value shape_value,
                             int64_t feature_dim,
-                            ConversionPatternRewriter& rewriter) {
+                            ConversionPatternRewriter& rewriter) {  // NOLINT
   Builder b(rewriter.getContext());
   auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
   auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
+  if (shape_value) {
+    return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
+        loc, result_type, value_1d, shape_value, dims);
+  }
+  assert(result_type.hasStaticShape());
   return rewriter.create<xla_hlo::BroadcastInDimOp>(loc, result_type, value_1d,
                                                     dims);
 }
 
+// Calculate the shape value of operand, assuming it is a dynamic shape with
+// static rank.
+Value CalculateShapeValue(Location loc, Value operand,
+                          ConversionPatternRewriter& rewriter) {  // NOLINT
+  RankedTensorType result_type = operand.getType().dyn_cast<RankedTensorType>();
+  llvm::SmallVector<Value, 4> shape_values;
+  int64_t rank = result_type.getRank();
+  shape_values.reserve(rank);
+  for (int64_t i = 0; i < rank; ++i) {
+    auto index_value = rewriter.create<mlir::DimOp>(loc, operand, i);
+    shape_values.push_back(rewriter.create<mlir::IndexCastOp>(
+        loc, index_value, rewriter.getIntegerType(32)));
+  }
+  Type shape_element_type = shape_values.front().getType();
+  return rewriter.create<ScalarsToDimensionTensorOp>(
+      loc, RankedTensorType::get({rank}, shape_element_type), shape_values);
+}
+
 Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
-                         FloatType fp_type, Type broadcast_to_type,
-                         ConversionPatternRewriter& rewriter) {
+                         FloatType fp_type, Value variance,
+                         RankedTensorType broadcast_to_type,
+                         ConversionPatternRewriter& rewriter) {  // NOLINT
   Builder b(rewriter.getContext());
   if (epsilon_attr.getType() != fp_type) {
     // Need to convert.

@@ -66,9 +95,16 @@ Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
       DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
   Value epsilon =
       rewriter.create<xla_hlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
-  epsilon = rewriter.create<xla_hlo::BroadcastInDimOp>(
-      op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/nullptr);
-  return epsilon;
+  auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
+  auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{});
+  if (broadcast_to_type.hasStaticShape()) {
+    return rewriter.create<xla_hlo::BroadcastInDimOp>(
+        op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
+  }
+  Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter);
+  return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
+      op->getLoc(), broadcast_to_type, epsilon, shape_value,
+      /*broadcast_dims=*/dims);
 }
 
 class UnfuseBatchNormInferencePattern

@@ -84,9 +120,10 @@ class UnfuseBatchNormInferencePattern
     // Enforce type invariants.
     // Note that we deduce the actual element type from the variance,
     // which should not be subject to quantization at a higher level.
-    auto input_type = operands.operand().getType();
-    auto variance_type = operands.variance().getType().dyn_cast<ShapedType>();
-    if (!variance_type) {
+    auto input_type = operands.operand().getType().dyn_cast<RankedTensorType>();
+    auto variance_type =
+        operands.variance().getType().dyn_cast<RankedTensorType>();
+    if (!input_type || !variance_type) {
       return matchFailure();
     }
     auto fp_type = variance_type.getElementType().dyn_cast<FloatType>();

@@ -97,8 +134,9 @@ class UnfuseBatchNormInferencePattern
 
     // Add epsilon to the variance and sqrt to get stddev:
     // stddev = sqrt(variance + epsilon)
-    auto epsilon = MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(),
-                                      fp_type, variance_type, rewriter);
+    auto epsilon =
+        MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), fp_type,
+                           operands.variance(), variance_type, rewriter);
     if (!epsilon) {
       return matchFailure();
     }

@@ -108,14 +146,22 @@ class UnfuseBatchNormInferencePattern
     stddev = rewriter.create<xla_hlo::SqrtOp>(bn_op.getLoc(), stddev);
 
     // Broadcast all terms.
-    auto broadcast_scale = broadcastToFeatureDim(
-        bn_op.getLoc(), input_type, operands.scale(), feature_dim, rewriter);
-    auto broadcast_offset = broadcastToFeatureDim(
-        bn_op.getLoc(), input_type, operands.offset(), feature_dim, rewriter);
-    auto broadcast_mean = broadcastToFeatureDim(
-        bn_op.getLoc(), input_type, operands.mean(), feature_dim, rewriter);
-    auto broadcast_stddev = broadcastToFeatureDim(
-        bn_op.getLoc(), input_type, stddev, feature_dim, rewriter);
+    Value shape_value;
+    if (!input_type.hasStaticShape()) {
+      shape_value =
+          CalculateShapeValue(bn_op.getLoc(), operands.operand(), rewriter);
+    }
+    auto broadcast_scale =
+        BroadcastToFeatureDim(bn_op.getLoc(), input_type, operands.scale(),
+                              shape_value, feature_dim, rewriter);
+    auto broadcast_offset =
+        BroadcastToFeatureDim(bn_op.getLoc(), input_type, operands.offset(),
+                              shape_value, feature_dim, rewriter);
+    auto broadcast_mean =
+        BroadcastToFeatureDim(bn_op.getLoc(), input_type, operands.mean(),
+                              shape_value, feature_dim, rewriter);
+    auto broadcast_stddev = BroadcastToFeatureDim(
+        bn_op.getLoc(), input_type, stddev, shape_value, feature_dim, rewriter);
 
     // Compute:
     // scale * (input - mean) / stddev + offset

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/MLIRContext.h"  // TF:llvm-project
 #include "mlir/IR/Operation.h"  // TF:llvm-project
 #include "mlir/IR/PatternMatch.h"  // TF:llvm-project

@@ -33,6 +34,7 @@ struct TestUnfuseBatchNormPass : public FunctionPass<TestUnfuseBatchNormPass> {
 
     // Consider the xla_hlo dialect legal for tests.
    conversionTarget.addLegalDialect<XlaHloDialect>();
+    conversionTarget.addLegalDialect<StandardOpsDialect>();
     conversionTarget.addIllegalOp<xla_hlo::BatchNormInferenceOp>();
 
     PopulateUnfuseBatchNormPatterns(&getContext(), &conversionPatterns);