[XLA] Transform more cases of reshape(iota) to a mixed radix calculation with

iota, multiply and add.

PiperOrigin-RevId: 350572139
Change-Id: I7acb61b17594328aedb9c13f8a7aaed2decf1344
This commit is contained in:
Blake Hechtman 2021-01-07 08:55:58 -08:00 committed by TensorFlower Gardener
parent 714d3ed498
commit b18eefe629
4 changed files with 118 additions and 8 deletions

View File

@ -3890,16 +3890,50 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
} }
} }
// reshape(iota) -> iota. // reshape(iota) -> iota or a mixed radix calculation like
// s32[2,3,4] reshape(s32[24] iota()) to
// add(
// add(s32[2,3,4] iota() iota_dimension=2,
// 4 * s32[2,3,4] iota() iota_dimension=1),
// 12 * s32[2,3,4] iota() iota_dimension=0).
if (operand->opcode() == HloOpcode::kIota) { if (operand->opcode() == HloOpcode::kIota) {
auto* iota = Cast<HloIotaInstruction>(operand); auto* iota = Cast<HloIotaInstruction>(operand);
auto opt_dims = auto common_factors =
ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()}); CommonFactors(reshape->operand(0)->shape().dimensions(),
if (opt_dims.has_value()) { reshape->shape().dimensions());
CHECK_EQ(opt_dims->size(), 1); auto iota_dim = absl::c_find_if(
return ReplaceWithNewInstruction( common_factors, [&](const std::pair<int64, int64>& dim_pair) {
reshape, return dim_pair.first == iota->iota_dimension() &&
HloInstruction::CreateIota(reshape->shape(), opt_dims->front())); reshape->shape().dimensions(dim_pair.second) > 1;
});
auto next_dim = absl::c_find_if(
common_factors, [&](const std::pair<int64, int64>& dim_pair) {
return dim_pair.first == iota->iota_dimension() + 1;
});
if (iota_dim != common_factors.end() && next_dim != common_factors.end()) {
int64 multiplier = 1;
HloInstruction* new_reshape = nullptr;
for (int64 dim = (iota_dim + 1)->second - 1; dim >= iota_dim->second;
--dim) {
HloInstruction* new_iota = computation_->AddInstruction(
HloInstruction::CreateIota(reshape->shape(), dim));
iota->SetupDerivedInstruction(new_iota);
if (new_reshape) {
new_reshape =
computation_->AddInstruction(HloInstruction::CreateBinary(
reshape->shape(), HloOpcode::kAdd, new_reshape,
computation_->AddInstruction(HloInstruction::CreateBinary(
reshape->shape(), HloOpcode::kMultiply, new_iota,
MakeScalarLike(reshape, multiplier)))));
reshape->SetupDerivedInstruction(new_reshape);
} else {
new_reshape = new_iota;
}
multiplier *= reshape->shape().dimensions(dim);
}
reshape->SetupDerivedInstruction(new_reshape);
return ReplaceInstruction(reshape, new_reshape);
} }
} }

View File

@ -3054,6 +3054,54 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) {
ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
} }
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadix) {
auto m = CreateNewVerifiedModule();
HloComputation::Builder builder(TestName());
auto iota = builder.AddInstruction(
HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {21}), 0));
Shape result_shape = ShapeUtil::MakeShape(F32, {7, 3});
builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));
auto computation = m->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
GmockMatch(m::Reshape(m::Iota())));
AlgebraicSimplifier simplifier(default_options_);
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
GmockMatch(m::Add(
m::Iota(),
m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar())))));
EXPECT_TRUE(
ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
}
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadixExtraDims) {
auto m = CreateNewVerifiedModule();
HloComputation::Builder builder(TestName());
auto iota = builder.AddInstruction(
HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {42, 24, 15}), 1));
Shape result_shape = ShapeUtil::MakeShape(F32, {3, 14, 4, 3, 2, 5, 3});
builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));
auto computation = m->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
GmockMatch(m::Reshape(m::Iota())));
AlgebraicSimplifier simplifier(default_options_);
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
EXPECT_THAT(
computation->root_instruction(),
GmockMatch(m::Add(
m::Add(m::Iota(),
m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar()))),
m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar())))));
EXPECT_TRUE(
ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
}
TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) {
auto m = CreateNewVerifiedModule(); auto m = CreateNewVerifiedModule();
HloComputation::Builder builder(TestName()); HloComputation::Builder builder(TestName());

View File

@ -2599,9 +2599,12 @@ xla_test(
], ],
deps = [ deps = [
":client_library_test_base", ":client_library_test_base",
":hlo_test_base",
":test_macros_header", ":test_macros_header",
":xla_internal_test_main", ":xla_internal_test_main",
"//tensorflow/compiler/xla:error_spec",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@com_google_absl//absl/types:optional",
], ],
) )

View File

@ -16,13 +16,38 @@ limitations under the License.
#include <numeric> #include <numeric>
#include <vector> #include <vector>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/error_spec.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
namespace xla { namespace xla {
namespace { namespace {
XLA_TEST_F(HloTestBase, IotaReshapeR1) {
const string hlo_text = R"(
HloModule iota_reshape
ENTRY main {
i = s32[24] iota(), iota_dimension=0
ROOT r = s32[4,3,2] reshape(i)
}
)";
EXPECT_TRUE(RunAndCompare(hlo_text, absl::nullopt));
}
XLA_TEST_F(HloTestBase, IotaReshapeExtraDims) {
const string hlo_text = R"(
HloModule iota_reshape
ENTRY main {
i = s32[5,5,111,42] iota(), iota_dimension=0
ROOT r = s32[25,3,37,7,6] reshape(i)
}
)";
EXPECT_TRUE(RunAndCompare(hlo_text, absl::nullopt));
}
template <typename T> template <typename T>
std::vector<T> GetR1Expected(const int64 num_elements) { std::vector<T> GetR1Expected(const int64 num_elements) {
std::vector<T> result(num_elements); std::vector<T> result(num_elements);