[XLA] Fix bug in Reduce(Dot(X)) simplification.

PiperOrigin-RevId: 321703045
Change-Id: Id43655ea10d36f06cba109e8ec11efb2dfb6d80b
This commit is contained in:
Blake Hechtman 2020-07-16 20:34:15 -07:00 committed by TensorFlower Gardener
parent 5555c36e93
commit 4f7f17e469
2 changed files with 4 additions and 4 deletions

View File

@ -4137,13 +4137,13 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
new_dnums.add_rhs_contracting_dimensions( new_dnums.add_rhs_contracting_dimensions(
dnums.rhs_batch_dimensions(batch_dim)); dnums.rhs_batch_dimensions(batch_dim));
new_dnums.add_lhs_contracting_dimensions( new_dnums.add_lhs_contracting_dimensions(
dnums.rhs_batch_dimensions(batch_dim)); dnums.lhs_batch_dimensions(batch_dim));
++removed_dims; ++removed_dims;
} else { } else {
new_dnums.add_rhs_batch_dimensions( new_dnums.add_rhs_batch_dimensions(
dnums.rhs_batch_dimensions(batch_dim)); dnums.rhs_batch_dimensions(batch_dim));
new_dnums.add_lhs_batch_dimensions( new_dnums.add_lhs_batch_dimensions(
dnums.rhs_batch_dimensions(batch_dim)); dnums.lhs_batch_dimensions(batch_dim));
} }
} }
std::vector<int64> reduce_dims; std::vector<int64> reduce_dims;

View File

@ -6145,10 +6145,10 @@ TEST_F(AlgebraicSimplifierTest, ReduceOfBatchDotToContractingDimension) {
} }
test { test {
p0 = f32[32,8,5,6] parameter(0) p0 = f32[32,8,5,6] parameter(0)
p1 = f32[32,8,6,7] parameter(1) p1 = f32[8,32,6,7] parameter(1)
d = f32[32,8,5,7] dot(p0, p1), d = f32[32,8,5,7] dot(p0, p1),
lhs_batch_dims={0,1}, lhs_batch_dims={0,1},
rhs_batch_dims={0,1}, rhs_batch_dims={1,0},
rhs_contracting_dims={2}, rhs_contracting_dims={2},
lhs_contracting_dims={3} lhs_contracting_dims={3}
c = f32[] constant(0) c = f32[] constant(0)