Fix bug in Canonicalizer folder function for ArithmeticOp.

PiperOrigin-RevId: 312224624
Change-Id: Icd6b5ed25fedfa4b4f99be0d09fc5746010aad2a
This commit is contained in:
Chuan He 2020-05-18 23:15:56 -07:00 committed by TensorFlower Gardener
parent 97aed8f72e
commit 3c6dadd17f
2 changed files with 17 additions and 0 deletions

View File

@ -497,6 +497,12 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
return arithmetic_op.x();
}
auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
// TODO(chhe): we could fold and add an identity to force the broadcast.
if (result_op_type != rhs_type) {
return {};
}
bool is_symmetric =
(std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
if (auto attr = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {

View File

@ -431,3 +431,14 @@ func @DontRemoveTrivialAdd2(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %[[CONST]]) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
// CHECK: return %[[RESULT]] : tensor<?x?xf32>
}
// Test no fold because of the broadcast.
func @DontRemoveTrivialMul(%arg0: tensor<1x6x8x1xf32>) -> tensor<1x6x8x1xf32> {
%0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.Mul"(%arg0, %0) : (tensor<1x6x8x1xf32>, tensor<f32>) -> tensor<1x6x8x1xf32>
return %1 : tensor<1x6x8x1xf32>
// CHECK-LABEL: DontRemoveTrivialMul
// CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[RESULT:.*]] = "tf.Mul"(%arg0, %[[CONST]]) : (tensor<1x6x8x1xf32>, tensor<f32>) -> tensor<1x6x8x1xf32>
// CHECK: return %[[RESULT]] : tensor<1x6x8x1xf32>
}