[XLA] Fix bug in Reduce(Dot(X)) simplification.
PiperOrigin-RevId: 321703045 Change-Id: Id43655ea10d36f06cba109e8ec11efb2dfb6d80b
This commit is contained in:
parent
5555c36e93
commit
4f7f17e469
@ -4137,13 +4137,13 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
|
|||||||
new_dnums.add_rhs_contracting_dimensions(
|
new_dnums.add_rhs_contracting_dimensions(
|
||||||
dnums.rhs_batch_dimensions(batch_dim));
|
dnums.rhs_batch_dimensions(batch_dim));
|
||||||
new_dnums.add_lhs_contracting_dimensions(
|
new_dnums.add_lhs_contracting_dimensions(
|
||||||
dnums.rhs_batch_dimensions(batch_dim));
|
dnums.lhs_batch_dimensions(batch_dim));
|
||||||
++removed_dims;
|
++removed_dims;
|
||||||
} else {
|
} else {
|
||||||
new_dnums.add_rhs_batch_dimensions(
|
new_dnums.add_rhs_batch_dimensions(
|
||||||
dnums.rhs_batch_dimensions(batch_dim));
|
dnums.rhs_batch_dimensions(batch_dim));
|
||||||
new_dnums.add_lhs_batch_dimensions(
|
new_dnums.add_lhs_batch_dimensions(
|
||||||
dnums.rhs_batch_dimensions(batch_dim));
|
dnums.lhs_batch_dimensions(batch_dim));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::vector<int64> reduce_dims;
|
std::vector<int64> reduce_dims;
|
||||||
|
@ -6145,10 +6145,10 @@ TEST_F(AlgebraicSimplifierTest, ReduceOfBatchDotToContractingDimension) {
|
|||||||
}
|
}
|
||||||
test {
|
test {
|
||||||
p0 = f32[32,8,5,6] parameter(0)
|
p0 = f32[32,8,5,6] parameter(0)
|
||||||
p1 = f32[32,8,6,7] parameter(1)
|
p1 = f32[8,32,6,7] parameter(1)
|
||||||
d = f32[32,8,5,7] dot(p0, p1),
|
d = f32[32,8,5,7] dot(p0, p1),
|
||||||
lhs_batch_dims={0,1},
|
lhs_batch_dims={0,1},
|
||||||
rhs_batch_dims={0,1},
|
rhs_batch_dims={1,0},
|
||||||
rhs_contracting_dims={2},
|
rhs_contracting_dims={2},
|
||||||
lhs_contracting_dims={3}
|
lhs_contracting_dims={3}
|
||||||
c = f32[] constant(0)
|
c = f32[] constant(0)
|
||||||
|
Loading…
Reference in New Issue
Block a user