[XLA] Fix bug in Reduce(Dot(X)) simplification.

PiperOrigin-RevId: 321703045
Change-Id: Id43655ea10d36f06cba109e8ec11efb2dfb6d80b
This commit is contained in:
Blake Hechtman 2020-07-16 20:34:15 -07:00 committed by TensorFlower Gardener
parent 5555c36e93
commit 4f7f17e469
2 changed files with 4 additions and 4 deletions

View File

@ -4137,13 +4137,13 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
new_dnums.add_rhs_contracting_dimensions(
dnums.rhs_batch_dimensions(batch_dim));
new_dnums.add_lhs_contracting_dimensions(
dnums.rhs_batch_dimensions(batch_dim));
dnums.lhs_batch_dimensions(batch_dim));
++removed_dims;
} else {
new_dnums.add_rhs_batch_dimensions(
dnums.rhs_batch_dimensions(batch_dim));
new_dnums.add_lhs_batch_dimensions(
dnums.rhs_batch_dimensions(batch_dim));
dnums.lhs_batch_dimensions(batch_dim));
}
}
std::vector<int64> reduce_dims;

View File

@ -6145,10 +6145,10 @@ TEST_F(AlgebraicSimplifierTest, ReduceOfBatchDotToContractingDimension) {
}
test {
p0 = f32[32,8,5,6] parameter(0)
p1 = f32[32,8,6,7] parameter(1)
p1 = f32[8,32,6,7] parameter(1)
d = f32[32,8,5,7] dot(p0, p1),
lhs_batch_dims={0,1},
rhs_batch_dims={0,1},
rhs_batch_dims={1,0},
rhs_contracting_dims={2},
lhs_contracting_dims={3}
c = f32[] constant(0)