[XLA] Transform more cases of reshape(iota) to a mixed radix calculation with
iota, multiply and add. PiperOrigin-RevId: 350572139 Change-Id: I7acb61b17594328aedb9c13f8a7aaed2decf1344
This commit is contained in:
parent
714d3ed498
commit
b18eefe629
@ -3890,16 +3890,50 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// reshape(iota) -> iota.
|
// reshape(iota) -> iota or a mixed radix calculation like
|
||||||
|
// s32[2,3,4] reshape(s32[24] iota()) to
|
||||||
|
// add(
|
||||||
|
// add(s32[2,3,4] iota() iota_dimension=2,
|
||||||
|
// 4 * s32[2,3,4] iota() iota_dimension=1),
|
||||||
|
// 12 * s32[2,3,4] iota() iota_dimension=0).
|
||||||
if (operand->opcode() == HloOpcode::kIota) {
|
if (operand->opcode() == HloOpcode::kIota) {
|
||||||
auto* iota = Cast<HloIotaInstruction>(operand);
|
auto* iota = Cast<HloIotaInstruction>(operand);
|
||||||
auto opt_dims =
|
auto common_factors =
|
||||||
ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()});
|
CommonFactors(reshape->operand(0)->shape().dimensions(),
|
||||||
if (opt_dims.has_value()) {
|
reshape->shape().dimensions());
|
||||||
CHECK_EQ(opt_dims->size(), 1);
|
auto iota_dim = absl::c_find_if(
|
||||||
return ReplaceWithNewInstruction(
|
common_factors, [&](const std::pair<int64, int64>& dim_pair) {
|
||||||
reshape,
|
return dim_pair.first == iota->iota_dimension() &&
|
||||||
HloInstruction::CreateIota(reshape->shape(), opt_dims->front()));
|
reshape->shape().dimensions(dim_pair.second) > 1;
|
||||||
|
});
|
||||||
|
auto next_dim = absl::c_find_if(
|
||||||
|
common_factors, [&](const std::pair<int64, int64>& dim_pair) {
|
||||||
|
return dim_pair.first == iota->iota_dimension() + 1;
|
||||||
|
});
|
||||||
|
if (iota_dim != common_factors.end() && next_dim != common_factors.end()) {
|
||||||
|
int64 multiplier = 1;
|
||||||
|
HloInstruction* new_reshape = nullptr;
|
||||||
|
|
||||||
|
for (int64 dim = (iota_dim + 1)->second - 1; dim >= iota_dim->second;
|
||||||
|
--dim) {
|
||||||
|
HloInstruction* new_iota = computation_->AddInstruction(
|
||||||
|
HloInstruction::CreateIota(reshape->shape(), dim));
|
||||||
|
iota->SetupDerivedInstruction(new_iota);
|
||||||
|
if (new_reshape) {
|
||||||
|
new_reshape =
|
||||||
|
computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||||
|
reshape->shape(), HloOpcode::kAdd, new_reshape,
|
||||||
|
computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||||
|
reshape->shape(), HloOpcode::kMultiply, new_iota,
|
||||||
|
MakeScalarLike(reshape, multiplier)))));
|
||||||
|
reshape->SetupDerivedInstruction(new_reshape);
|
||||||
|
} else {
|
||||||
|
new_reshape = new_iota;
|
||||||
|
}
|
||||||
|
multiplier *= reshape->shape().dimensions(dim);
|
||||||
|
}
|
||||||
|
reshape->SetupDerivedInstruction(new_reshape);
|
||||||
|
return ReplaceInstruction(reshape, new_reshape);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3054,6 +3054,54 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) {
|
|||||||
ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
|
ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadix) {
|
||||||
|
auto m = CreateNewVerifiedModule();
|
||||||
|
HloComputation::Builder builder(TestName());
|
||||||
|
auto iota = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {21}), 0));
|
||||||
|
Shape result_shape = ShapeUtil::MakeShape(F32, {7, 3});
|
||||||
|
builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));
|
||||||
|
|
||||||
|
auto computation = m->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
|
EXPECT_THAT(computation->root_instruction(),
|
||||||
|
GmockMatch(m::Reshape(m::Iota())));
|
||||||
|
|
||||||
|
AlgebraicSimplifier simplifier(default_options_);
|
||||||
|
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||||
|
|
||||||
|
EXPECT_THAT(computation->root_instruction(),
|
||||||
|
GmockMatch(m::Add(
|
||||||
|
m::Iota(),
|
||||||
|
m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar())))));
|
||||||
|
EXPECT_TRUE(
|
||||||
|
ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
|
||||||
|
}
|
||||||
|
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadixExtraDims) {
|
||||||
|
auto m = CreateNewVerifiedModule();
|
||||||
|
HloComputation::Builder builder(TestName());
|
||||||
|
auto iota = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {42, 24, 15}), 1));
|
||||||
|
Shape result_shape = ShapeUtil::MakeShape(F32, {3, 14, 4, 3, 2, 5, 3});
|
||||||
|
builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));
|
||||||
|
|
||||||
|
auto computation = m->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
|
EXPECT_THAT(computation->root_instruction(),
|
||||||
|
GmockMatch(m::Reshape(m::Iota())));
|
||||||
|
|
||||||
|
AlgebraicSimplifier simplifier(default_options_);
|
||||||
|
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||||
|
|
||||||
|
EXPECT_THAT(
|
||||||
|
computation->root_instruction(),
|
||||||
|
GmockMatch(m::Add(
|
||||||
|
m::Add(m::Iota(),
|
||||||
|
m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar()))),
|
||||||
|
m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar())))));
|
||||||
|
EXPECT_TRUE(
|
||||||
|
ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
|
||||||
|
}
|
||||||
TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) {
|
TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) {
|
||||||
auto m = CreateNewVerifiedModule();
|
auto m = CreateNewVerifiedModule();
|
||||||
HloComputation::Builder builder(TestName());
|
HloComputation::Builder builder(TestName());
|
||||||
|
@ -2599,9 +2599,12 @@ xla_test(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":client_library_test_base",
|
":client_library_test_base",
|
||||||
|
":hlo_test_base",
|
||||||
":test_macros_header",
|
":test_macros_header",
|
||||||
":xla_internal_test_main",
|
":xla_internal_test_main",
|
||||||
|
"//tensorflow/compiler/xla:error_spec",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,13 +16,38 @@ limitations under the License.
|
|||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
|
#include "tensorflow/compiler/xla/error_spec.h"
|
||||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
XLA_TEST_F(HloTestBase, IotaReshapeR1) {
|
||||||
|
const string hlo_text = R"(
|
||||||
|
HloModule iota_reshape
|
||||||
|
ENTRY main {
|
||||||
|
i = s32[24] iota(), iota_dimension=0
|
||||||
|
ROOT r = s32[4,3,2] reshape(i)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
EXPECT_TRUE(RunAndCompare(hlo_text, absl::nullopt));
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(HloTestBase, IotaReshapeExtraDims) {
|
||||||
|
const string hlo_text = R"(
|
||||||
|
HloModule iota_reshape
|
||||||
|
ENTRY main {
|
||||||
|
i = s32[5,5,111,42] iota(), iota_dimension=0
|
||||||
|
ROOT r = s32[25,3,37,7,6] reshape(i)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
EXPECT_TRUE(RunAndCompare(hlo_text, absl::nullopt));
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<T> GetR1Expected(const int64 num_elements) {
|
std::vector<T> GetR1Expected(const int64 num_elements) {
|
||||||
std::vector<T> result(num_elements);
|
std::vector<T> result(num_elements);
|
||||||
|
Loading…
Reference in New Issue
Block a user