[XLA] Transform more cases of reshape(iota) to a mixed radix calculation with
iota, multiply and add. PiperOrigin-RevId: 350572139 Change-Id: I7acb61b17594328aedb9c13f8a7aaed2decf1344
This commit is contained in:
parent
714d3ed498
commit
b18eefe629
@ -3890,16 +3890,50 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
|
||||
}
|
||||
}
|
||||
|
||||
// reshape(iota) -> iota.
|
||||
// reshape(iota) -> iota or a mixed radix calculation like
|
||||
// s32[2,3,4] reshape(s32[24] iota()) to
|
||||
// add(
|
||||
// add(s32[2,3,4] iota() iota_dimension=2,
|
||||
// 4 * s32[2,3,4] iota() iota_dimension=1),
|
||||
// 12 * s32[2,3,4] iota() iota_dimension=0).
|
||||
if (operand->opcode() == HloOpcode::kIota) {
|
||||
auto* iota = Cast<HloIotaInstruction>(operand);
|
||||
auto opt_dims =
|
||||
ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()});
|
||||
if (opt_dims.has_value()) {
|
||||
CHECK_EQ(opt_dims->size(), 1);
|
||||
return ReplaceWithNewInstruction(
|
||||
reshape,
|
||||
HloInstruction::CreateIota(reshape->shape(), opt_dims->front()));
|
||||
auto common_factors =
|
||||
CommonFactors(reshape->operand(0)->shape().dimensions(),
|
||||
reshape->shape().dimensions());
|
||||
auto iota_dim = absl::c_find_if(
|
||||
common_factors, [&](const std::pair<int64, int64>& dim_pair) {
|
||||
return dim_pair.first == iota->iota_dimension() &&
|
||||
reshape->shape().dimensions(dim_pair.second) > 1;
|
||||
});
|
||||
auto next_dim = absl::c_find_if(
|
||||
common_factors, [&](const std::pair<int64, int64>& dim_pair) {
|
||||
return dim_pair.first == iota->iota_dimension() + 1;
|
||||
});
|
||||
if (iota_dim != common_factors.end() && next_dim != common_factors.end()) {
|
||||
int64 multiplier = 1;
|
||||
HloInstruction* new_reshape = nullptr;
|
||||
|
||||
for (int64 dim = (iota_dim + 1)->second - 1; dim >= iota_dim->second;
|
||||
--dim) {
|
||||
HloInstruction* new_iota = computation_->AddInstruction(
|
||||
HloInstruction::CreateIota(reshape->shape(), dim));
|
||||
iota->SetupDerivedInstruction(new_iota);
|
||||
if (new_reshape) {
|
||||
new_reshape =
|
||||
computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||
reshape->shape(), HloOpcode::kAdd, new_reshape,
|
||||
computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||
reshape->shape(), HloOpcode::kMultiply, new_iota,
|
||||
MakeScalarLike(reshape, multiplier)))));
|
||||
reshape->SetupDerivedInstruction(new_reshape);
|
||||
} else {
|
||||
new_reshape = new_iota;
|
||||
}
|
||||
multiplier *= reshape->shape().dimensions(dim);
|
||||
}
|
||||
reshape->SetupDerivedInstruction(new_reshape);
|
||||
return ReplaceInstruction(reshape, new_reshape);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3054,6 +3054,54 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) {
|
||||
ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadix) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
HloComputation::Builder builder(TestName());
|
||||
auto iota = builder.AddInstruction(
|
||||
HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {21}), 0));
|
||||
Shape result_shape = ShapeUtil::MakeShape(F32, {7, 3});
|
||||
builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));
|
||||
|
||||
auto computation = m->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
GmockMatch(m::Reshape(m::Iota())));
|
||||
|
||||
AlgebraicSimplifier simplifier(default_options_);
|
||||
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
GmockMatch(m::Add(
|
||||
m::Iota(),
|
||||
m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar())))));
|
||||
EXPECT_TRUE(
|
||||
ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
|
||||
}
|
||||
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadixExtraDims) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
HloComputation::Builder builder(TestName());
|
||||
auto iota = builder.AddInstruction(
|
||||
HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {42, 24, 15}), 1));
|
||||
Shape result_shape = ShapeUtil::MakeShape(F32, {3, 14, 4, 3, 2, 5, 3});
|
||||
builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));
|
||||
|
||||
auto computation = m->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
GmockMatch(m::Reshape(m::Iota())));
|
||||
|
||||
AlgebraicSimplifier simplifier(default_options_);
|
||||
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||
|
||||
EXPECT_THAT(
|
||||
computation->root_instruction(),
|
||||
GmockMatch(m::Add(
|
||||
m::Add(m::Iota(),
|
||||
m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar()))),
|
||||
m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar())))));
|
||||
EXPECT_TRUE(
|
||||
ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
|
||||
}
|
||||
TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
HloComputation::Builder builder(TestName());
|
||||
|
@ -2599,9 +2599,12 @@ xla_test(
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
":hlo_test_base",
|
||||
":test_macros_header",
|
||||
":xla_internal_test_main",
|
||||
"//tensorflow/compiler/xla:error_spec",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -16,13 +16,38 @@ limitations under the License.
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/error_spec.h"
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
XLA_TEST_F(HloTestBase, IotaReshapeR1) {
|
||||
const string hlo_text = R"(
|
||||
HloModule iota_reshape
|
||||
ENTRY main {
|
||||
i = s32[24] iota(), iota_dimension=0
|
||||
ROOT r = s32[4,3,2] reshape(i)
|
||||
}
|
||||
)";
|
||||
EXPECT_TRUE(RunAndCompare(hlo_text, absl::nullopt));
|
||||
}
|
||||
|
||||
XLA_TEST_F(HloTestBase, IotaReshapeExtraDims) {
|
||||
const string hlo_text = R"(
|
||||
HloModule iota_reshape
|
||||
ENTRY main {
|
||||
i = s32[5,5,111,42] iota(), iota_dimension=0
|
||||
ROOT r = s32[25,3,37,7,6] reshape(i)
|
||||
}
|
||||
)";
|
||||
EXPECT_TRUE(RunAndCompare(hlo_text, absl::nullopt));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> GetR1Expected(const int64 num_elements) {
|
||||
std::vector<T> result(num_elements);
|
||||
|
Loading…
Reference in New Issue
Block a user