Properly handle the default value for data_format when legalizing tf to HLO.
We used the `data_formatAttr()` in code but that does not return a default if no attribute was specified. Use `data_format()` instead and pass StringRef instead of StringAttr around. PiperOrigin-RevId: 331783122 Change-Id: I42850399fe46f258263c0402bc9549a12144b63c
This commit is contained in:
parent
f810fc966a
commit
6f72965db7
@ -439,6 +439,17 @@ func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tens
|
||||
// Bias op legalizations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: func @biasAdd_default
|
||||
func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
|
||||
// CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
|
||||
// CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
|
||||
// CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
|
||||
// CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
|
||||
// CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
|
||||
%0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
|
||||
return %0 : tensor<1x32x10x32xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @biasAdd_NHWC
|
||||
func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
|
||||
// CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
|
||||
|
@ -119,9 +119,9 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
||||
static bool IsDefaultDataFormat(StringRef format) { return format == "NHWC"; }
|
||||
|
||||
/// Returns the feature dimension for the given format and input type.
|
||||
static size_t GetFeatureDimension(StringAttr format,
|
||||
static size_t GetFeatureDimension(StringRef format,
|
||||
RankedTensorType inputType) {
|
||||
return IsDefaultDataFormat(format.getValue()) ? inputType.getRank() - 1 : 1;
|
||||
return IsDefaultDataFormat(format) ? inputType.getRank() - 1 : 1;
|
||||
}
|
||||
|
||||
// Gets all integer values from the given attribute and push them to `values`.
|
||||
@ -731,7 +731,7 @@ static void CreateWhile32(Location loc, int num_iterations,
|
||||
// BatchNorm op utilities.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format,
|
||||
static IntegerAttr getFeatureDimensionAttr(Builder &b, StringRef format,
|
||||
Value input) {
|
||||
return b.getI64IntegerAttr(
|
||||
GetFeatureDimension(format, input.getType().cast<RankedTensorType>()));
|
||||
@ -1128,7 +1128,7 @@ class ConvertBiasAddOp : public OpRewritePattern<TF::BiasAddOp> {
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
auto feature_dim = GetFeatureDimension(
|
||||
op.data_formatAttr(), op.value().getType().cast<RankedTensorType>());
|
||||
op.data_format(), op.value().getType().cast<RankedTensorType>());
|
||||
auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(),
|
||||
feature_dim, rewriter);
|
||||
rewriter.replaceOpWithNewOp<AddOp>(op, op.value(), bias_broadcast);
|
||||
@ -1814,7 +1814,7 @@ class ConvertFusedBatchNormGradBase
|
||||
act = rewriter.create<ConvertOp>(loc, act, kernel_type);
|
||||
|
||||
auto feature_dim_attr =
|
||||
getFeatureDimensionAttr(rewriter, op.data_formatAttr(), act);
|
||||
getFeatureDimensionAttr(rewriter, op.data_format(), act);
|
||||
auto feature_dim = feature_dim_attr.getValue().getSExtValue();
|
||||
|
||||
// Gets the result values.
|
||||
@ -1908,7 +1908,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
|
||||
LogicalResult matchAndRewrite(FusedBatchNormOpT op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto feature_dim =
|
||||
getFeatureDimensionAttr(rewriter, op.data_formatAttr(), op.x());
|
||||
getFeatureDimensionAttr(rewriter, op.data_format(), op.x());
|
||||
|
||||
auto input_type_tensor = op.x().getType().template cast<TensorType>();
|
||||
auto input_element_type = input_type_tensor.getElementType();
|
||||
|
@ -31,7 +31,7 @@ def IEEEFloatTensor : TensorOf<[F16, F32, F64]>;
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def FeatureDimension : NativeCodeCall<
|
||||
"getFeatureDimensionAttr($_builder, $0, $1)">;
|
||||
"getFeatureDimensionAttr($_builder, $0.getValue(), $1)">;
|
||||
def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
|
||||
def TrueBoolAttr : AttrConstraint<CPred<"$_self.getValue()">>;
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user