Properly handle the default value for data_format when legalizing tf to HLO.

We used the `data_formatAttr()` in code but that does not return a default if no
attribute was specified. Use `data_format()` instead and pass StringRef instead
of StringAttr around.

PiperOrigin-RevId: 331783122
Change-Id: I42850399fe46f258263c0402bc9549a12144b63c
This commit is contained in:
Stephan Herhut 2020-09-15 09:15:27 -07:00 committed by TensorFlower Gardener
parent f810fc966a
commit 6f72965db7
3 changed files with 18 additions and 7 deletions

View File

@ -439,6 +439,17 @@ func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tens
// Bias op legalizations.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @biasAdd_default
func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
// CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
// CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
// CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
// CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
// CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
%0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
return %0 : tensor<1x32x10x32xi32>
}
// CHECK-LABEL: func @biasAdd_NHWC
func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
// CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0

View File

@ -119,9 +119,9 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
static bool IsDefaultDataFormat(StringRef format) { return format == "NHWC"; }
/// Returns the feature dimension for the given format and input type.
static size_t GetFeatureDimension(StringAttr format,
static size_t GetFeatureDimension(StringRef format,
RankedTensorType inputType) {
return IsDefaultDataFormat(format.getValue()) ? inputType.getRank() - 1 : 1;
return IsDefaultDataFormat(format) ? inputType.getRank() - 1 : 1;
}
// Gets all integer values from the given attribute and push them to `values`.
@ -731,7 +731,7 @@ static void CreateWhile32(Location loc, int num_iterations,
// BatchNorm op utilities.
//===----------------------------------------------------------------------===//
static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format,
static IntegerAttr getFeatureDimensionAttr(Builder &b, StringRef format,
Value input) {
return b.getI64IntegerAttr(
GetFeatureDimension(format, input.getType().cast<RankedTensorType>()));
@ -1128,7 +1128,7 @@ class ConvertBiasAddOp : public OpRewritePattern<TF::BiasAddOp> {
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto feature_dim = GetFeatureDimension(
op.data_formatAttr(), op.value().getType().cast<RankedTensorType>());
op.data_format(), op.value().getType().cast<RankedTensorType>());
auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(),
feature_dim, rewriter);
rewriter.replaceOpWithNewOp<AddOp>(op, op.value(), bias_broadcast);
@ -1814,7 +1814,7 @@ class ConvertFusedBatchNormGradBase
act = rewriter.create<ConvertOp>(loc, act, kernel_type);
auto feature_dim_attr =
getFeatureDimensionAttr(rewriter, op.data_formatAttr(), act);
getFeatureDimensionAttr(rewriter, op.data_format(), act);
auto feature_dim = feature_dim_attr.getValue().getSExtValue();
// Gets the result values.
@ -1908,7 +1908,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
LogicalResult matchAndRewrite(FusedBatchNormOpT op,
PatternRewriter &rewriter) const override {
auto feature_dim =
getFeatureDimensionAttr(rewriter, op.data_formatAttr(), op.x());
getFeatureDimensionAttr(rewriter, op.data_format(), op.x());
auto input_type_tensor = op.x().getType().template cast<TensorType>();
auto input_element_type = input_type_tensor.getElementType();

View File

@ -31,7 +31,7 @@ def IEEEFloatTensor : TensorOf<[F16, F32, F64]>;
//===----------------------------------------------------------------------===//
def FeatureDimension : NativeCodeCall<
"getFeatureDimensionAttr($_builder, $0, $1)">;
"getFeatureDimensionAttr($_builder, $0.getValue(), $1)">;
def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
def TrueBoolAttr : AttrConstraint<CPred<"$_self.getValue()">>;