Add tf.BiasAddV1 op and canonicalize to tf.BiasAdd.

The difference between the two ops is that BiasAdd has a data format attribute. We canonicalize from v1 by using the default data format for this attribute.

PiperOrigin-RevId: 314872733
Change-Id: I947a0a40a4aeca3262319a2b8b9dbf8b8b11328a
This commit is contained in:
Lucy Fox 2020-06-04 23:07:45 -07:00 committed by TensorFlower Gardener
parent 9221044560
commit 2d2500117e
4 changed files with 46 additions and 0 deletions

View File

@ -917,6 +917,30 @@ the feature dimension is the third-to-last.
}];
}
def TF_BiasAddV1Op : TF_Op<"BiasAddV1", [NoSideEffect]> {
let summary = "Adds `bias` to `value`.";
let description = [{
This is a deprecated version of BiasAdd and will be soon removed.
This is a special case of `tf.add` where `bias` is restricted to be 1-D.
Broadcasting is supported, so `value` may have any number of dimensions.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$value,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$bias
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
}
def TF_BitcastOp : TF_Op<"Bitcast", [NoSideEffect]> {
let summary = [{
Bitcasts a tensor from one type to another without copying data.

View File

@ -755,6 +755,15 @@ static LogicalResult Verify(BiasAddGradOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// BiasAddV1Op
//===----------------------------------------------------------------------===//
void BiasAddV1Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<BiasAddV1ToBiasAdd>(context);
}
//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//

View File

@ -34,6 +34,13 @@ func @testBatchMatMulV2ToMatMul(%arg0: tensor<4x3xf32>, %arg1: tensor<4x5xf32>)
// CHECK: return %0
}
// CHECK-LABEL: testBiasAddV1ToBiasAdd
func @testBiasAddV1ToBiasAdd(%arg0: tensor<*xf32>, %arg1: tensor<128xf32>) -> tensor<*xf32> {
// CHECK: "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
%0 = "tf.BiasAddV1"(%arg0, %arg1) : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
return %0: tensor<*xf32>
}
// CHECK-LABEL: func @testLeakyRelu
func @testLeakyRelu(%arg0 : tensor<16xf32>) -> (tensor<16xf32>) {
%2 = "tf.LeakyRelu"(%arg0) {alpha = 1.0 : f32} : (tensor<16xf32>) -> tensor<16xf32>

View File

@ -65,6 +65,12 @@ def BatchMatMulV2ToMatMul : Pat<(TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y),
(TF_MatMulOp $x, $y, $adj_x, $adj_y),
[(IsRank2Tensor $x), (IsRank2Tensor $y)]>;
//===----------------------------------------------------------------------===//
// BiasAddV1 op patterns.
//===----------------------------------------------------------------------===//
def BiasAddV1ToBiasAdd : Pat<(TF_BiasAddV1Op $arg0, $arg1),
(TF_BiasAddOp $arg0, $arg1, ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">)>;
//===----------------------------------------------------------------------===//
// Bitcast op patterns.