[XLA] Fix bug in Reduce(Dot(X)) simplification.
PiperOrigin-RevId: 321703045 Change-Id: Id43655ea10d36f06cba109e8ec11efb2dfb6d80b
This commit is contained in:
parent
5555c36e93
commit
4f7f17e469
@ -4137,13 +4137,13 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
|
||||
new_dnums.add_rhs_contracting_dimensions(
|
||||
dnums.rhs_batch_dimensions(batch_dim));
|
||||
new_dnums.add_lhs_contracting_dimensions(
|
||||
dnums.rhs_batch_dimensions(batch_dim));
|
||||
dnums.lhs_batch_dimensions(batch_dim));
|
||||
++removed_dims;
|
||||
} else {
|
||||
new_dnums.add_rhs_batch_dimensions(
|
||||
dnums.rhs_batch_dimensions(batch_dim));
|
||||
new_dnums.add_lhs_batch_dimensions(
|
||||
dnums.rhs_batch_dimensions(batch_dim));
|
||||
dnums.lhs_batch_dimensions(batch_dim));
|
||||
}
|
||||
}
|
||||
std::vector<int64> reduce_dims;
|
||||
|
@ -6145,10 +6145,10 @@ TEST_F(AlgebraicSimplifierTest, ReduceOfBatchDotToContractingDimension) {
|
||||
}
|
||||
test {
|
||||
p0 = f32[32,8,5,6] parameter(0)
|
||||
p1 = f32[32,8,6,7] parameter(1)
|
||||
p1 = f32[8,32,6,7] parameter(1)
|
||||
d = f32[32,8,5,7] dot(p0, p1),
|
||||
lhs_batch_dims={0,1},
|
||||
rhs_batch_dims={0,1},
|
||||
rhs_batch_dims={1,0},
|
||||
rhs_contracting_dims={2},
|
||||
lhs_contracting_dims={3}
|
||||
c = f32[] constant(0)
|
||||
|
Loading…
Reference in New Issue
Block a user