Fix bug in Canonicalizer folder function for ArithmeticOp.
PiperOrigin-RevId: 312224624 Change-Id: Icd6b5ed25fedfa4b4f99be0d09fc5746010aad2a
This commit is contained in:
parent
97aed8f72e
commit
3c6dadd17f
@ -497,6 +497,12 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
|
||||
return arithmetic_op.x();
|
||||
}
|
||||
|
||||
auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
|
||||
// TODO(chhe): we could fold and add an identity to force the broadcast.
|
||||
if (result_op_type != rhs_type) {
|
||||
return {};
|
||||
}
|
||||
|
||||
bool is_symmetric =
|
||||
(std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
|
||||
if (auto attr = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
|
@ -431,3 +431,14 @@ func @DontRemoveTrivialAdd2(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %[[CONST]]) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: return %[[RESULT]] : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// Test no fold because of the broadcast.
|
||||
func @DontRemoveTrivialMul(%arg0: tensor<1x6x8x1xf32>) -> tensor<1x6x8x1xf32> {
|
||||
%0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
%1 = "tf.Mul"(%arg0, %0) : (tensor<1x6x8x1xf32>, tensor<f32>) -> tensor<1x6x8x1xf32>
|
||||
return %1 : tensor<1x6x8x1xf32>
|
||||
// CHECK-LABEL: DontRemoveTrivialMul
|
||||
// CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[RESULT:.*]] = "tf.Mul"(%arg0, %[[CONST]]) : (tensor<1x6x8x1xf32>, tensor<f32>) -> tensor<1x6x8x1xf32>
|
||||
// CHECK: return %[[RESULT]] : tensor<1x6x8x1xf32>
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user