[XLA] Transform more cases of reshape(iota) to a mixed radix calculation with
iota, multiply and add. PiperOrigin-RevId: 350572139 Change-Id: I7acb61b17594328aedb9c13f8a7aaed2decf1344
This commit is contained in:
		
							parent
							
								
									714d3ed498
								
							
						
					
					
						commit
						b18eefe629
					
				@ -3890,16 +3890,50 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // reshape(iota) -> iota.
 | 
					  // reshape(iota) -> iota or a mixed radix calculation like
 | 
				
			||||||
 | 
					  // s32[2,3,4] reshape(s32[24] iota()) to
 | 
				
			||||||
 | 
					  // add(
 | 
				
			||||||
 | 
					  //    add(s32[2,3,4] iota() iota_dimension=2,
 | 
				
			||||||
 | 
					  //        4 * s32[2,3,4] iota() iota_dimension=1),
 | 
				
			||||||
 | 
					  //    12 * s32[2,3,4] iota() iota_dimension=0).
 | 
				
			||||||
  if (operand->opcode() == HloOpcode::kIota) {
 | 
					  if (operand->opcode() == HloOpcode::kIota) {
 | 
				
			||||||
    auto* iota = Cast<HloIotaInstruction>(operand);
 | 
					    auto* iota = Cast<HloIotaInstruction>(operand);
 | 
				
			||||||
    auto opt_dims =
 | 
					    auto common_factors =
 | 
				
			||||||
        ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()});
 | 
					        CommonFactors(reshape->operand(0)->shape().dimensions(),
 | 
				
			||||||
    if (opt_dims.has_value()) {
 | 
					                      reshape->shape().dimensions());
 | 
				
			||||||
      CHECK_EQ(opt_dims->size(), 1);
 | 
					    auto iota_dim = absl::c_find_if(
 | 
				
			||||||
      return ReplaceWithNewInstruction(
 | 
					        common_factors, [&](const std::pair<int64, int64>& dim_pair) {
 | 
				
			||||||
          reshape,
 | 
					          return dim_pair.first == iota->iota_dimension() &&
 | 
				
			||||||
          HloInstruction::CreateIota(reshape->shape(), opt_dims->front()));
 | 
					                 reshape->shape().dimensions(dim_pair.second) > 1;
 | 
				
			||||||
 | 
					        });
 | 
				
			||||||
 | 
					    auto next_dim = absl::c_find_if(
 | 
				
			||||||
 | 
					        common_factors, [&](const std::pair<int64, int64>& dim_pair) {
 | 
				
			||||||
 | 
					          return dim_pair.first == iota->iota_dimension() + 1;
 | 
				
			||||||
 | 
					        });
 | 
				
			||||||
 | 
					    if (iota_dim != common_factors.end() && next_dim != common_factors.end()) {
 | 
				
			||||||
 | 
					      int64 multiplier = 1;
 | 
				
			||||||
 | 
					      HloInstruction* new_reshape = nullptr;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      for (int64 dim = (iota_dim + 1)->second - 1; dim >= iota_dim->second;
 | 
				
			||||||
 | 
					           --dim) {
 | 
				
			||||||
 | 
					        HloInstruction* new_iota = computation_->AddInstruction(
 | 
				
			||||||
 | 
					            HloInstruction::CreateIota(reshape->shape(), dim));
 | 
				
			||||||
 | 
					        iota->SetupDerivedInstruction(new_iota);
 | 
				
			||||||
 | 
					        if (new_reshape) {
 | 
				
			||||||
 | 
					          new_reshape =
 | 
				
			||||||
 | 
					              computation_->AddInstruction(HloInstruction::CreateBinary(
 | 
				
			||||||
 | 
					                  reshape->shape(), HloOpcode::kAdd, new_reshape,
 | 
				
			||||||
 | 
					                  computation_->AddInstruction(HloInstruction::CreateBinary(
 | 
				
			||||||
 | 
					                      reshape->shape(), HloOpcode::kMultiply, new_iota,
 | 
				
			||||||
 | 
					                      MakeScalarLike(reshape, multiplier)))));
 | 
				
			||||||
 | 
					          reshape->SetupDerivedInstruction(new_reshape);
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					          new_reshape = new_iota;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        multiplier *= reshape->shape().dimensions(dim);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      reshape->SetupDerivedInstruction(new_reshape);
 | 
				
			||||||
 | 
					      return ReplaceInstruction(reshape, new_reshape);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -3054,6 +3054,54 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) {
 | 
				
			|||||||
      ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
 | 
					      ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadix) {
 | 
				
			||||||
 | 
					  auto m = CreateNewVerifiedModule();
 | 
				
			||||||
 | 
					  HloComputation::Builder builder(TestName());
 | 
				
			||||||
 | 
					  auto iota = builder.AddInstruction(
 | 
				
			||||||
 | 
					      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {21}), 0));
 | 
				
			||||||
 | 
					  Shape result_shape = ShapeUtil::MakeShape(F32, {7, 3});
 | 
				
			||||||
 | 
					  builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto computation = m->AddEntryComputation(builder.Build());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_THAT(computation->root_instruction(),
 | 
				
			||||||
 | 
					              GmockMatch(m::Reshape(m::Iota())));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  AlgebraicSimplifier simplifier(default_options_);
 | 
				
			||||||
 | 
					  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_THAT(computation->root_instruction(),
 | 
				
			||||||
 | 
					              GmockMatch(m::Add(
 | 
				
			||||||
 | 
					                  m::Iota(),
 | 
				
			||||||
 | 
					                  m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar())))));
 | 
				
			||||||
 | 
					  EXPECT_TRUE(
 | 
				
			||||||
 | 
					      ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadixExtraDims) {
 | 
				
			||||||
 | 
					  auto m = CreateNewVerifiedModule();
 | 
				
			||||||
 | 
					  HloComputation::Builder builder(TestName());
 | 
				
			||||||
 | 
					  auto iota = builder.AddInstruction(
 | 
				
			||||||
 | 
					      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {42, 24, 15}), 1));
 | 
				
			||||||
 | 
					  Shape result_shape = ShapeUtil::MakeShape(F32, {3, 14, 4, 3, 2, 5, 3});
 | 
				
			||||||
 | 
					  builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto computation = m->AddEntryComputation(builder.Build());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_THAT(computation->root_instruction(),
 | 
				
			||||||
 | 
					              GmockMatch(m::Reshape(m::Iota())));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  AlgebraicSimplifier simplifier(default_options_);
 | 
				
			||||||
 | 
					  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_THAT(
 | 
				
			||||||
 | 
					      computation->root_instruction(),
 | 
				
			||||||
 | 
					      GmockMatch(m::Add(
 | 
				
			||||||
 | 
					          m::Add(m::Iota(),
 | 
				
			||||||
 | 
					                 m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar()))),
 | 
				
			||||||
 | 
					          m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar())))));
 | 
				
			||||||
 | 
					  EXPECT_TRUE(
 | 
				
			||||||
 | 
					      ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) {
 | 
					TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) {
 | 
				
			||||||
  auto m = CreateNewVerifiedModule();
 | 
					  auto m = CreateNewVerifiedModule();
 | 
				
			||||||
  HloComputation::Builder builder(TestName());
 | 
					  HloComputation::Builder builder(TestName());
 | 
				
			||||||
 | 
				
			|||||||
@ -2599,9 +2599,12 @@ xla_test(
 | 
				
			|||||||
    ],
 | 
					    ],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
        ":client_library_test_base",
 | 
					        ":client_library_test_base",
 | 
				
			||||||
 | 
					        ":hlo_test_base",
 | 
				
			||||||
        ":test_macros_header",
 | 
					        ":test_macros_header",
 | 
				
			||||||
        ":xla_internal_test_main",
 | 
					        ":xla_internal_test_main",
 | 
				
			||||||
 | 
					        "//tensorflow/compiler/xla:error_spec",
 | 
				
			||||||
        "//tensorflow/core:lib",
 | 
					        "//tensorflow/core:lib",
 | 
				
			||||||
 | 
					        "@com_google_absl//absl/types:optional",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -16,13 +16,38 @@ limitations under the License.
 | 
				
			|||||||
#include <numeric>
 | 
					#include <numeric>
 | 
				
			||||||
#include <vector>
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "absl/types/optional.h"
 | 
				
			||||||
 | 
					#include "tensorflow/compiler/xla/error_spec.h"
 | 
				
			||||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
 | 
					#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
 | 
				
			||||||
 | 
					#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 | 
				
			||||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
 | 
					#include "tensorflow/compiler/xla/tests/test_macros.h"
 | 
				
			||||||
#include "tensorflow/core/lib/core/errors.h"
 | 
					#include "tensorflow/core/lib/core/errors.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace xla {
 | 
					namespace xla {
 | 
				
			||||||
namespace {
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					XLA_TEST_F(HloTestBase, IotaReshapeR1) {
 | 
				
			||||||
 | 
					  const string hlo_text = R"(
 | 
				
			||||||
 | 
					  HloModule iota_reshape
 | 
				
			||||||
 | 
					  ENTRY main {
 | 
				
			||||||
 | 
					    i = s32[24] iota(), iota_dimension=0
 | 
				
			||||||
 | 
					    ROOT r = s32[4,3,2] reshape(i)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					)";
 | 
				
			||||||
 | 
					  EXPECT_TRUE(RunAndCompare(hlo_text, absl::nullopt));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					XLA_TEST_F(HloTestBase, IotaReshapeExtraDims) {
 | 
				
			||||||
 | 
					  const string hlo_text = R"(
 | 
				
			||||||
 | 
					  HloModule iota_reshape
 | 
				
			||||||
 | 
					  ENTRY main {
 | 
				
			||||||
 | 
					    i = s32[5,5,111,42] iota(), iota_dimension=0
 | 
				
			||||||
 | 
					    ROOT r = s32[25,3,37,7,6] reshape(i)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					)";
 | 
				
			||||||
 | 
					  EXPECT_TRUE(RunAndCompare(hlo_text, absl::nullopt));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename T>
 | 
					template <typename T>
 | 
				
			||||||
std::vector<T> GetR1Expected(const int64 num_elements) {
 | 
					std::vector<T> GetR1Expected(const int64 num_elements) {
 | 
				
			||||||
  std::vector<T> result(num_elements);
 | 
					  std::vector<T> result(num_elements);
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user