From 4f7f17e4694353d1f2c08056fe593c22d651a331 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Thu, 16 Jul 2020 20:34:15 -0700 Subject: [PATCH] [XLA] Fix bug in Reduce(Dot(X)) simplification. PiperOrigin-RevId: 321703045 Change-Id: Id43655ea10d36f06cba109e8ec11efb2dfb6d80b --- tensorflow/compiler/xla/service/algebraic_simplifier.cc | 4 ++-- tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 741edfc7c35..3e012fc41b8 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -4137,13 +4137,13 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { new_dnums.add_rhs_contracting_dimensions( dnums.rhs_batch_dimensions(batch_dim)); new_dnums.add_lhs_contracting_dimensions( - dnums.rhs_batch_dimensions(batch_dim)); + dnums.lhs_batch_dimensions(batch_dim)); ++removed_dims; } else { new_dnums.add_rhs_batch_dimensions( dnums.rhs_batch_dimensions(batch_dim)); new_dnums.add_lhs_batch_dimensions( - dnums.rhs_batch_dimensions(batch_dim)); + dnums.lhs_batch_dimensions(batch_dim)); } } std::vector reduce_dims; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 779d6c9cdc5..d2c32d79a91 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -6145,10 +6145,10 @@ TEST_F(AlgebraicSimplifierTest, ReduceOfBatchDotToContractingDimension) { } test { p0 = f32[32,8,5,6] parameter(0) - p1 = f32[32,8,6,7] parameter(1) + p1 = f32[8,32,6,7] parameter(1) d = f32[32,8,5,7] dot(p0, p1), lhs_batch_dims={0,1}, - rhs_batch_dims={0,1}, + rhs_batch_dims={1,0}, rhs_contracting_dims={2}, lhs_contracting_dims={3} c = f32[] constant(0)