Populate tiling info in Layout data.
PiperOrigin-RevId: 240596416
This commit is contained in:
parent
c510b79d5c
commit
f366225ec6
@ -96,8 +96,13 @@ string Layout::ToString() const {
|
||||
}
|
||||
|
||||
bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) {
|
||||
if (lhs.format() != rhs.format() ||
|
||||
lhs.minor_to_major() != rhs.minor_to_major() ||
|
||||
if (lhs.format() != rhs.format()) {
|
||||
return false;
|
||||
}
|
||||
if (lhs.format() == DENSE && lhs.minor_to_major() != rhs.minor_to_major()) {
|
||||
return false;
|
||||
}
|
||||
if (lhs.format() == SPARSE &&
|
||||
lhs.max_sparse_elements() != rhs.max_sparse_elements()) {
|
||||
return false;
|
||||
}
|
||||
|
@ -127,6 +127,12 @@ class Layout {
|
||||
return *this;
|
||||
}
|
||||
|
||||
Equal& MinorToMajorOnly() {
|
||||
ignore_tiles_ = true;
|
||||
ignore_element_size_ = true;
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
bool ignore_tiles_ = false;
|
||||
bool ignore_element_size_ = false;
|
||||
|
@ -250,12 +250,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
// Runs the visitor on a computation.
|
||||
static bool Run(HloComputation* computation,
|
||||
const AlgebraicSimplifierOptions& options);
|
||||
const AlgebraicSimplifierOptions& options,
|
||||
AlgebraicSimplifier* simplifier);
|
||||
|
||||
private:
|
||||
explicit AlgebraicSimplifierVisitor(HloComputation* computation,
|
||||
const AlgebraicSimplifierOptions& options)
|
||||
: computation_(computation), options_(options) {}
|
||||
const AlgebraicSimplifierOptions& options,
|
||||
AlgebraicSimplifier* simplifier)
|
||||
: computation_(computation), options_(options), simplifier_(simplifier) {}
|
||||
|
||||
// Transforms Dots where at least one input is a vector or has a degenerate
|
||||
// dimension and converts it into a multiply and reduce. This should enable
|
||||
@ -274,10 +276,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
||||
if (hlo->shape().rank() == 1) {
|
||||
return hlo;
|
||||
}
|
||||
return computation_->AddInstruction(HloInstruction::CreateReshape(
|
||||
auto hlo_instruction =
|
||||
computation_->AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(hlo->shape().element_type(),
|
||||
{ShapeUtil::ElementsIn(hlo->shape())}),
|
||||
hlo));
|
||||
simplifier_->UpdateLayout(hlo_instruction->mutable_shape());
|
||||
return hlo_instruction;
|
||||
}
|
||||
|
||||
// Converts to primitive type if the input hlo is not that type, otherwise
|
||||
@ -287,8 +292,11 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
||||
if (hlo->shape().element_type() == element_type) {
|
||||
return hlo;
|
||||
}
|
||||
return computation_->AddInstruction(HloInstruction::CreateConvert(
|
||||
ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo));
|
||||
Shape changed_shape =
|
||||
ShapeUtil::ChangeElementType(hlo->shape(), element_type);
|
||||
simplifier_->UpdateLayout(&changed_shape);
|
||||
return computation_->AddInstruction(
|
||||
HloInstruction::CreateConvert(changed_shape, hlo));
|
||||
}
|
||||
|
||||
// Transposes a dot operand such that the batch dimensions are the msot major,
|
||||
@ -312,13 +320,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
// Helper method to perform and add reduction on a list of dimensions.
|
||||
HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64> dims) {
|
||||
HloInstruction* zero =
|
||||
computation_->AddInstruction(HloInstruction::CreateConstant(
|
||||
HloInstruction* zero = computation_->AddInstruction(
|
||||
simplifier_->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
|
||||
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
|
||||
Shape shape = ShapeUtil::FilterDimensions(
|
||||
[&](int64 dim) { return !absl::c_linear_search(dims, dim); },
|
||||
hlo->shape());
|
||||
simplifier_->UpdateLayout(&shape);
|
||||
return computation_->AddInstruction(HloInstruction::CreateReduce(
|
||||
shape, hlo, zero, dims, AddReduce_computation));
|
||||
}
|
||||
@ -403,6 +412,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
HloComputation::Builder b("scalar_add_computation");
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {});
|
||||
simplifier_->UpdateLayout(&shape);
|
||||
auto scalar_lhs = b.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
|
||||
auto scalar_rhs = b.AddInstruction(
|
||||
@ -440,13 +450,16 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
// Cached computation for adding two scalar F32.
|
||||
HloComputation* scalar_add_computation_ = nullptr;
|
||||
|
||||
AlgebraicSimplifier* simplifier_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
bool AlgebraicSimplifierVisitor::Run(
|
||||
HloComputation* computation, const AlgebraicSimplifierOptions& options) {
|
||||
AlgebraicSimplifierVisitor visitor(computation, options);
|
||||
bool AlgebraicSimplifierVisitor::Run(HloComputation* computation,
|
||||
const AlgebraicSimplifierOptions& options,
|
||||
AlgebraicSimplifier* simplifier) {
|
||||
AlgebraicSimplifierVisitor visitor(computation, options, simplifier);
|
||||
TF_CHECK_OK(computation->Accept(&visitor));
|
||||
return visitor.changed_;
|
||||
}
|
||||
@ -713,6 +726,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
|
||||
new_slice_shape.set_dimensions(
|
||||
concatenate_dimension,
|
||||
slice_end - operands[i]->slice_starts(concatenate_dimension));
|
||||
simplifier_->UpdateLayout(&new_slice_shape);
|
||||
auto new_limit_indices = operands[i]->slice_limits();
|
||||
new_limit_indices[concatenate_dimension] = slice_end;
|
||||
auto new_slice_op =
|
||||
@ -775,18 +789,19 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
|
||||
}
|
||||
|
||||
static HloInstruction* BuildTupleConstant(HloComputation* computation,
|
||||
const LiteralSlice& literal) {
|
||||
const LiteralSlice& literal,
|
||||
AlgebraicSimplifier* simplifier) {
|
||||
if (literal.shape().IsTuple()) {
|
||||
std::vector<HloInstruction*> elems;
|
||||
elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
|
||||
for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
|
||||
elems.push_back(
|
||||
BuildTupleConstant(computation, LiteralSlice(literal, {i})));
|
||||
elems.push_back(BuildTupleConstant(
|
||||
computation, LiteralSlice(literal, {i}), simplifier));
|
||||
}
|
||||
return computation->AddInstruction(HloInstruction::CreateTuple(elems));
|
||||
} else {
|
||||
return computation->AddInstruction(
|
||||
HloInstruction::CreateConstant(literal.Clone()));
|
||||
simplifier->CreateConstantWithLayoutUpdated(literal.Clone()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -795,7 +810,8 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
|
||||
// explicit Tuple instructions.
|
||||
if (constant->shape().IsTuple()) {
|
||||
return ReplaceInstruction(
|
||||
constant, BuildTupleConstant(computation_, constant->literal()));
|
||||
constant,
|
||||
BuildTupleConstant(computation_, constant->literal(), simplifier_));
|
||||
}
|
||||
|
||||
if (constant->shape().element_type() == TOKEN) {
|
||||
@ -808,7 +824,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
|
||||
Literal unique_scalar(
|
||||
LiteralUtil::GetFirstScalarLiteral(constant->literal()));
|
||||
HloInstruction* scalar = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(std::move(unique_scalar)));
|
||||
simplifier_->CreateConstantWithLayoutUpdated(std::move(unique_scalar)));
|
||||
return ReplaceWithNewInstruction(
|
||||
constant,
|
||||
HloInstruction::CreateBroadcast(constant->shape(), scalar, {}));
|
||||
@ -854,8 +870,9 @@ Status InvertConstant(const HloInstruction& constant, Literal* result) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
|
||||
HloComputation* computation) {
|
||||
std::unique_ptr<HloInstruction> TryDivideToShift(
|
||||
HloInstruction* divide, HloComputation* computation,
|
||||
AlgebraicSimplifier* simplifier) {
|
||||
HloInstruction *a, *b, *c;
|
||||
CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
|
||||
|
||||
@ -872,10 +889,11 @@ std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
|
||||
HloInstruction* zero_like_a = BroadcastZeros(
|
||||
computation, a->shape().element_type(), a->shape().dimensions());
|
||||
|
||||
Shape changed_shape = ShapeUtil::ChangeElementType(a->shape(), PRED);
|
||||
simplifier->UpdateLayout(&changed_shape);
|
||||
auto* dividend_is_negative =
|
||||
computation->AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a,
|
||||
ComparisonDirection::kLt));
|
||||
changed_shape, a, zero_like_a, ComparisonDirection::kLt));
|
||||
|
||||
auto* negated_dividend = computation->AddInstruction(
|
||||
HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
|
||||
@ -887,8 +905,8 @@ std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
|
||||
|
||||
int log2_abs_b_value = tensorflow::Log2Floor64(b_value);
|
||||
|
||||
auto* shift_amount =
|
||||
computation->AddInstruction(HloInstruction::CreateConstant(
|
||||
auto* shift_amount = computation->AddInstruction(
|
||||
simplifier->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::CreateR0<T>(log2_abs_b_value)));
|
||||
if (!ShapeUtil::IsScalar(b->shape())) {
|
||||
shift_amount = computation->AddInstruction(
|
||||
@ -911,8 +929,8 @@ std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
|
||||
uint64 b_value = c->literal().GetFirstElement<T>();
|
||||
if (IsPowerOfTwo(b_value)) {
|
||||
int log2_abs_b_value = tensorflow::Log2Floor64(b_value);
|
||||
HloInstruction* shift_amount =
|
||||
computation->AddInstruction(HloInstruction::CreateConstant(
|
||||
HloInstruction* shift_amount = computation->AddInstruction(
|
||||
simplifier->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::CreateR0<T>(log2_abs_b_value)));
|
||||
if (!ShapeUtil::IsScalar(b->shape())) {
|
||||
shift_amount = computation->AddInstruction(
|
||||
@ -940,49 +958,49 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
||||
switch (divide->shape().element_type()) {
|
||||
case S8:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryDivideToShift<int8>(divide, computation_)) {
|
||||
TryDivideToShift<int8>(divide, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(divide, std::move(shift));
|
||||
}
|
||||
break;
|
||||
case S16:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryDivideToShift<int16>(divide, computation_)) {
|
||||
TryDivideToShift<int16>(divide, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(divide, std::move(shift));
|
||||
}
|
||||
break;
|
||||
case S32:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryDivideToShift<int32>(divide, computation_)) {
|
||||
TryDivideToShift<int32>(divide, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(divide, std::move(shift));
|
||||
}
|
||||
break;
|
||||
case S64:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryDivideToShift<int64>(divide, computation_)) {
|
||||
TryDivideToShift<int64>(divide, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(divide, std::move(shift));
|
||||
}
|
||||
break;
|
||||
case U8:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryDivideToShift<uint8>(divide, computation_)) {
|
||||
TryDivideToShift<uint8>(divide, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(divide, std::move(shift));
|
||||
}
|
||||
break;
|
||||
case U16:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryDivideToShift<uint16>(divide, computation_)) {
|
||||
TryDivideToShift<uint16>(divide, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(divide, std::move(shift));
|
||||
}
|
||||
break;
|
||||
case U32:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryDivideToShift<uint32>(divide, computation_)) {
|
||||
TryDivideToShift<uint32>(divide, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(divide, std::move(shift));
|
||||
}
|
||||
break;
|
||||
case U64:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryDivideToShift<uint64>(divide, computation_)) {
|
||||
TryDivideToShift<uint64>(divide, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(divide, std::move(shift));
|
||||
}
|
||||
break;
|
||||
@ -1084,7 +1102,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
||||
return Status::OK();
|
||||
}
|
||||
auto inverse = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant((new_literal.Clone())));
|
||||
simplifier_->CreateConstantWithLayoutUpdated((new_literal.Clone())));
|
||||
TF_ASSIGN_OR_RETURN(auto new_divide,
|
||||
MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
|
||||
return ReplaceInstruction(divide, new_divide);
|
||||
@ -1456,6 +1474,7 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
|
||||
int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim);
|
||||
Shape rhs_slice_shape(rhs->shape());
|
||||
rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k);
|
||||
simplifier_->UpdateLayout(&rhs_slice_shape);
|
||||
|
||||
std::array<int64, 2> start_indices;
|
||||
start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset;
|
||||
@ -1591,6 +1610,7 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
|
||||
right_operand->shape().dimensions(1 - rhs_contracting_dimension);
|
||||
auto memoized_shape =
|
||||
ShapeUtil::MakeShape(dot->shape().element_type(), {m, n});
|
||||
simplifier_->UpdateLayout(&memoized_shape);
|
||||
auto* memoized_inst = computation_->AddInstruction(
|
||||
HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
|
||||
dnums, dot->precision_config()));
|
||||
@ -1605,7 +1625,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
|
||||
// Slice out start and 0 components and reorder if necessary.
|
||||
auto indices_type = dynamic_slice->operand(1)->shape().element_type();
|
||||
Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
|
||||
simplifier_->UpdateLayout(&s_shape);
|
||||
Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
|
||||
simplifier_->UpdateLayout(&d_shape);
|
||||
HloInstruction* non_zero_start =
|
||||
dynamic_slice->mutable_operand(1 + index_of_non_zero_start);
|
||||
HloInstruction* zero_start =
|
||||
@ -1638,7 +1660,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
|
||||
if (ShapeUtil::IsZeroElementArray(dot->shape()) ||
|
||||
ShapeUtil::IsZeroElementArray(lhs->shape()) ||
|
||||
ShapeUtil::IsZeroElementArray(rhs->shape())) {
|
||||
auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
|
||||
auto zero = computation_->AddInstruction(
|
||||
simplifier_->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::Zero(dot->shape().element_type())));
|
||||
return ReplaceWithNewInstruction(
|
||||
dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
|
||||
@ -2183,7 +2206,8 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
|
||||
// zero.
|
||||
auto* iota = Cast<HloIotaInstruction>(instruction);
|
||||
if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
|
||||
auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
|
||||
auto zero = computation_->AddInstruction(
|
||||
simplifier_->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::Zero(iota->shape().element_type()).Clone()));
|
||||
return ReplaceWithNewInstruction(
|
||||
iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
|
||||
@ -2307,7 +2331,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
|
||||
HloInstruction *lhs, *rhs;
|
||||
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
|
||||
if (IsAll(rhs, 0)) {
|
||||
auto one = HloInstruction::CreateConstant(
|
||||
auto one = simplifier_->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::One(power->shape().element_type()).Clone());
|
||||
std::unique_ptr<HloInstruction> ones;
|
||||
if (ShapeUtil::IsScalar(power->shape())) {
|
||||
@ -2342,7 +2366,8 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
|
||||
|
||||
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
|
||||
if (IsAll(rhs, -1)) {
|
||||
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
|
||||
auto* one = computation_->AddInstruction(
|
||||
simplifier_->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::One(rhs->shape().element_type()).Clone()));
|
||||
|
||||
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
|
||||
@ -2422,14 +2447,16 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
||||
std::vector<HloInstruction*> new_operands;
|
||||
new_operands.reserve(user->operand_count());
|
||||
|
||||
Shape changed_shape;
|
||||
for (HloInstruction* user_operand : user->operands()) {
|
||||
if (user_operand->opcode() == HloOpcode::kBroadcast &&
|
||||
ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
|
||||
changed_shape = ShapeUtil::ChangeElementType(
|
||||
operand->shape(), user_operand->shape().element_type());
|
||||
simplifier_->UpdateLayout(&changed_shape);
|
||||
new_operands.push_back(
|
||||
computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
ShapeUtil::ChangeElementType(
|
||||
operand->shape(), user_operand->shape().element_type()),
|
||||
user_operand->mutable_operand(0), {})));
|
||||
changed_shape, user_operand->mutable_operand(0), {})));
|
||||
} else {
|
||||
CHECK_EQ(broadcast, user_operand);
|
||||
new_operands.push_back(operand);
|
||||
@ -2438,11 +2465,11 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
||||
VLOG(4) << "Sinking broadcast after user:";
|
||||
VLOG(4) << " old broadcast: " << broadcast->ToString();
|
||||
VLOG(4) << " old user: " << user->ToString();
|
||||
HloInstruction* new_user =
|
||||
computation_->AddInstruction(user->CloneWithNewOperands(
|
||||
ShapeUtil::ChangeElementType(operand->shape(),
|
||||
user->shape().element_type()),
|
||||
new_operands));
|
||||
changed_shape = ShapeUtil::ChangeElementType(operand->shape(),
|
||||
user->shape().element_type());
|
||||
simplifier_->UpdateLayout(&changed_shape);
|
||||
HloInstruction* new_user = computation_->AddInstruction(
|
||||
user->CloneWithNewOperands(changed_shape, new_operands));
|
||||
VLOG(4) << " new user: " << new_user->ToString();
|
||||
HloInstruction* new_broadcast =
|
||||
computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
@ -2456,8 +2483,9 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
std::unique_ptr<HloInstruction> TryRemainderToAnd(HloInstruction* remainder,
|
||||
HloComputation* computation) {
|
||||
std::unique_ptr<HloInstruction> TryRemainderToAnd(
|
||||
HloInstruction* remainder, HloComputation* computation,
|
||||
AlgebraicSimplifier* simplifier) {
|
||||
HloInstruction *a, *b, *c;
|
||||
CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
|
||||
|
||||
@ -2487,8 +2515,8 @@ std::unique_ptr<HloInstruction> TryRemainderToAnd(HloInstruction* remainder,
|
||||
a->shape(), HloOpcode::kSelect, dividend_is_negative,
|
||||
negated_dividend, a));
|
||||
|
||||
auto* mask_amount =
|
||||
computation->AddInstruction(HloInstruction::CreateConstant(
|
||||
auto* mask_amount = computation->AddInstruction(
|
||||
simplifier->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::CreateR0<T>(b_value - 1)));
|
||||
if (!ShapeUtil::IsScalar(b->shape())) {
|
||||
mask_amount = computation->AddInstruction(
|
||||
@ -2509,8 +2537,8 @@ std::unique_ptr<HloInstruction> TryRemainderToAnd(HloInstruction* remainder,
|
||||
} else {
|
||||
uint64 b_value = c->literal().GetFirstElement<T>();
|
||||
if (IsPowerOfTwo(b_value)) {
|
||||
HloInstruction* mask_amount =
|
||||
computation->AddInstruction(HloInstruction::CreateConstant(
|
||||
HloInstruction* mask_amount = computation->AddInstruction(
|
||||
simplifier->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::CreateR0<T>(b_value - 1)));
|
||||
if (!ShapeUtil::IsScalar(b->shape())) {
|
||||
mask_amount = computation->AddInstruction(
|
||||
@ -2532,49 +2560,49 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
|
||||
switch (remainder->shape().element_type()) {
|
||||
case S8:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryRemainderToAnd<int8>(remainder, computation_)) {
|
||||
TryRemainderToAnd<int8>(remainder, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(remainder, std::move(shift));
|
||||
}
|
||||
break;
|
||||
case S16:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryRemainderToAnd<int16>(remainder, computation_)) {
|
||||
TryRemainderToAnd<int16>(remainder, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(remainder, std::move(shift));
|
||||
}
|
||||
break;
|
||||
case S32:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryRemainderToAnd<int32>(remainder, computation_)) {
|
||||
TryRemainderToAnd<int32>(remainder, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(remainder, std::move(shift));
|
||||
}
|
||||
break;
|
||||
case S64:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryRemainderToAnd<int64>(remainder, computation_)) {
|
||||
TryRemainderToAnd<int64>(remainder, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(remainder, std::move(shift));
|
||||
}
|
||||
break;
|
||||
case U8:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryRemainderToAnd<uint8>(remainder, computation_)) {
|
||||
TryRemainderToAnd<uint8>(remainder, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(remainder, std::move(shift));
|
||||
}
|
||||
break;
|
||||
case U16:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryRemainderToAnd<uint16>(remainder, computation_)) {
|
||||
TryRemainderToAnd<uint16>(remainder, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(remainder, std::move(shift));
|
||||
}
|
||||
break;
|
||||
case U32:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryRemainderToAnd<uint32>(remainder, computation_)) {
|
||||
TryRemainderToAnd<uint32>(remainder, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(remainder, std::move(shift));
|
||||
}
|
||||
break;
|
||||
case U64:
|
||||
if (std::unique_ptr<HloInstruction> shift =
|
||||
TryRemainderToAnd<uint64>(remainder, computation_)) {
|
||||
TryRemainderToAnd<uint64>(remainder, computation_, simplifier_)) {
|
||||
return ReplaceWithNewInstruction(remainder, std::move(shift));
|
||||
}
|
||||
break;
|
||||
@ -2597,7 +2625,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
|
||||
if (!LayoutUtil::HasLayout(reshaped_shape)) {
|
||||
LayoutUtil::SetToDefaultLayout(&reshaped_shape);
|
||||
}
|
||||
auto empty_constant = HloInstruction::CreateConstant(
|
||||
auto empty_constant = simplifier_->CreateConstantWithLayoutUpdated(
|
||||
Literal::CreateFromShape(reshaped_shape));
|
||||
|
||||
return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
|
||||
@ -2810,6 +2838,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
|
||||
new_slice_limits),
|
||||
new_slice_operand, new_slice_starts, new_slice_limits,
|
||||
new_slice_stides));
|
||||
simplifier_->UpdateLayout(new_slice->mutable_shape());
|
||||
TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
|
||||
slice, HloInstruction::CreateReshape(slice->shape(), new_slice)));
|
||||
return true;
|
||||
@ -2914,8 +2943,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
|
||||
// representative.
|
||||
auto arg = reduce->inputs()[0];
|
||||
auto init_value = reduce->init_values()[0];
|
||||
const Shape& reduce_result_shape =
|
||||
multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape();
|
||||
Shape& reduce_result_shape = const_cast<Shape&>(
|
||||
multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape());
|
||||
|
||||
absl::Span<const int64> dimensions(reduce->dimensions());
|
||||
HloComputation* function = reduce->to_apply();
|
||||
@ -3136,6 +3165,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
|
||||
return !absl::c_linear_search(effective_reduce_dims, dim);
|
||||
},
|
||||
reduce_window->shape());
|
||||
simplifier_->UpdateLayout(&reduce_shape);
|
||||
HloInstruction* reduce =
|
||||
computation_->AddInstruction(HloInstruction::CreateReduce(
|
||||
/*shape=*/reduce_shape,
|
||||
@ -3261,11 +3291,11 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
|
||||
|
||||
HloInstruction* new_reduce_window_operand;
|
||||
if (convert != nullptr) {
|
||||
new_reduce_window_operand =
|
||||
computation_->AddInstruction(HloInstruction::CreateConvert(
|
||||
ShapeUtil::ChangeElementType(pad_operand->shape(),
|
||||
convert->shape().element_type()),
|
||||
pad_operand));
|
||||
Shape changed_shape = ShapeUtil::ChangeElementType(
|
||||
pad_operand->shape(), convert->shape().element_type());
|
||||
simplifier_->UpdateLayout(&changed_shape);
|
||||
new_reduce_window_operand = computation_->AddInstruction(
|
||||
HloInstruction::CreateConvert(changed_shape, pad_operand));
|
||||
} else {
|
||||
new_reduce_window_operand = pad_operand;
|
||||
}
|
||||
@ -3614,15 +3644,18 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
|
||||
|
||||
// We already checked feature_dimension is most minor, so data in input_shape
|
||||
// and row-major {conv_width,input_channels} are bitwise identical.
|
||||
const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
input_shape.element_type(), {conv_width, input_channels});
|
||||
simplifier_->UpdateLayout(&new_input_shape);
|
||||
// We already checked input_feature_dimension is more major than
|
||||
// output_feature_dimension, so data in filter_shape and row-major
|
||||
// {input_channels,output_channels} are bitwise identical.
|
||||
const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
filter_shape.element_type(), {input_channels, output_channels});
|
||||
const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
simplifier_->UpdateLayout(&new_filter_shape);
|
||||
Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
convolution_shape.element_type(), {conv_width, output_channels});
|
||||
simplifier_->UpdateLayout(&dot_output_shape);
|
||||
|
||||
auto new_lhs = add_bitcast(new_input_shape, lhs);
|
||||
auto new_rhs = add_bitcast(new_filter_shape, rhs);
|
||||
@ -3647,7 +3680,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
|
||||
convolution,
|
||||
HloInstruction::CreateBroadcast(
|
||||
convolution->shape(),
|
||||
computation_->AddInstruction(HloInstruction::CreateConstant(
|
||||
computation_->AddInstruction(
|
||||
simplifier_->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::Zero(convolution->shape().element_type()))),
|
||||
{}));
|
||||
}
|
||||
@ -3731,7 +3765,7 @@ StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
|
||||
"AlgebraicSimplifier::Run(), before:\n" + module->ToString());
|
||||
bool changed = false;
|
||||
for (auto* comp : module->MakeNonfusionComputations()) {
|
||||
if (AlgebraicSimplifierVisitor::Run(comp, options_)) {
|
||||
if (AlgebraicSimplifierVisitor::Run(comp, options_, this)) {
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
@ -105,6 +105,15 @@ class AlgebraicSimplifier : public HloModulePass {
|
||||
// computation was changed.
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
// Create constant from literal with tiles and element size updated in the
|
||||
// constant's layout.
|
||||
std::unique_ptr<HloInstruction> CreateConstantWithLayoutUpdated(
|
||||
Literal literal) {
|
||||
auto constant = HloInstruction::CreateConstant(std::move(literal));
|
||||
UpdateLayout(constant->mutable_shape());
|
||||
return constant;
|
||||
}
|
||||
|
||||
private:
|
||||
AlgebraicSimplifierOptions options_;
|
||||
};
|
||||
|
@ -29,8 +29,11 @@ namespace xla {
|
||||
class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
|
||||
public:
|
||||
explicit BFloat16ConversionFoldingVisitor(
|
||||
HloComputation* computation, const BFloat16Support* bfloat16_support)
|
||||
: computation_(computation), bfloat16_support_(bfloat16_support) {}
|
||||
HloComputation* computation, const BFloat16Support* bfloat16_support,
|
||||
BFloat16ConversionFolding* bfloat16_conversion_folding)
|
||||
: computation_(computation),
|
||||
bfloat16_support_(bfloat16_support),
|
||||
bfloat16_conversion_folding_(bfloat16_conversion_folding) {}
|
||||
|
||||
Status DefaultAction(HloInstruction* hlo) override;
|
||||
|
||||
@ -38,8 +41,10 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
|
||||
Status HandleAllReduce(HloInstruction* crs) override;
|
||||
|
||||
static bool Run(HloComputation* computation,
|
||||
const BFloat16Support* bfloat16_support) {
|
||||
BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support);
|
||||
const BFloat16Support* bfloat16_support,
|
||||
BFloat16ConversionFolding* bfloat16_conversion_folding) {
|
||||
BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support,
|
||||
bfloat16_conversion_folding);
|
||||
TF_CHECK_OK(computation->Accept(&visitor));
|
||||
return visitor.changed_;
|
||||
}
|
||||
@ -61,6 +66,7 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
HloComputation* computation_;
|
||||
const BFloat16Support* bfloat16_support_;
|
||||
BFloat16ConversionFolding* bfloat16_conversion_folding_;
|
||||
bool changed_ = false;
|
||||
};
|
||||
|
||||
@ -68,6 +74,7 @@ Status BFloat16ConversionFoldingVisitor::FoldOutputConversions(
|
||||
HloInstruction* hlo) {
|
||||
std::vector<HloInstruction*> materialized_users = hlo->users();
|
||||
hlo->mutable_shape()->set_element_type(BF16);
|
||||
bfloat16_conversion_folding_->UpdateLayout(hlo->mutable_shape());
|
||||
for (auto user : materialized_users) {
|
||||
CHECK_EQ(user->opcode(), HloOpcode::kConvert);
|
||||
TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
|
||||
@ -228,6 +235,8 @@ Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) {
|
||||
|
||||
ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i})
|
||||
->set_element_type(BF16);
|
||||
bfloat16_conversion_folding_->UpdateLayout(
|
||||
ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}));
|
||||
for (auto gte : per_tuple_element_gtes[i]) {
|
||||
TF_RETURN_IF_ERROR(FoldOutputConversions(gte));
|
||||
}
|
||||
@ -241,7 +250,7 @@ StatusOr<bool> BFloat16ConversionFolding::Run(HloModule* module) {
|
||||
2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString());
|
||||
bool changed = false;
|
||||
for (auto* comp : module->MakeNonfusionComputations()) {
|
||||
if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) {
|
||||
if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_, this)) {
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
@ -29,15 +29,20 @@ namespace xla {
|
||||
|
||||
class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
|
||||
public:
|
||||
explicit BFloat16NormalizationVisitor(HloComputation* computation,
|
||||
const BFloat16Support* bfloat16_support)
|
||||
: computation_(computation), bfloat16_support_(bfloat16_support) {}
|
||||
explicit BFloat16NormalizationVisitor(
|
||||
HloComputation* computation, const BFloat16Support* bfloat16_support,
|
||||
BFloat16Normalization* bfloat16_normalization)
|
||||
: computation_(computation),
|
||||
bfloat16_support_(bfloat16_support),
|
||||
bfloat16_normalization_(bfloat16_normalization) {}
|
||||
|
||||
Status DefaultAction(HloInstruction* hlo) override;
|
||||
|
||||
static bool Run(HloComputation* computation,
|
||||
const BFloat16Support* bfloat16_support) {
|
||||
BFloat16NormalizationVisitor visitor(computation, bfloat16_support);
|
||||
const BFloat16Support* bfloat16_support,
|
||||
BFloat16Normalization* bfloat16_normalization) {
|
||||
BFloat16NormalizationVisitor visitor(computation, bfloat16_support,
|
||||
bfloat16_normalization);
|
||||
TF_CHECK_OK(computation->Accept(&visitor));
|
||||
return visitor.changed_;
|
||||
}
|
||||
@ -73,6 +78,7 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
HloComputation* computation_;
|
||||
const BFloat16Support* bfloat16_support_;
|
||||
BFloat16Normalization* bfloat16_normalization_;
|
||||
bool changed_ = false;
|
||||
};
|
||||
|
||||
@ -95,6 +101,7 @@ Status BFloat16NormalizationVisitor::InsertConvertAfterOutput(
|
||||
computation->set_root_instruction(convert);
|
||||
}
|
||||
convert->mutable_shape()->set_element_type(to);
|
||||
bfloat16_normalization_->UpdateLayout(convert->mutable_shape());
|
||||
changed_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -103,6 +110,7 @@ Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack(
|
||||
HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
|
||||
auto original_type = hlo->shape().element_type();
|
||||
hlo->mutable_shape()->set_element_type(to);
|
||||
bfloat16_normalization_->UpdateLayout(hlo->mutable_shape());
|
||||
return InsertConvertAfterOutput(hlo, original_type, computation);
|
||||
}
|
||||
|
||||
@ -110,8 +118,10 @@ Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand(
|
||||
HloInstruction* hlo, int64 operand_idx, PrimitiveType to,
|
||||
HloComputation* computation) {
|
||||
auto operand = hlo->mutable_operand(operand_idx);
|
||||
auto convert = computation->AddInstruction(HloInstruction::CreateConvert(
|
||||
ShapeUtil::ChangeElementType(operand->shape(), to), operand));
|
||||
auto shape = ShapeUtil::ChangeElementType(operand->shape(), to);
|
||||
bfloat16_normalization_->UpdateLayout(&shape);
|
||||
auto convert = computation->AddInstruction(
|
||||
HloInstruction::CreateConvert(shape, operand));
|
||||
TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert));
|
||||
changed_ = true;
|
||||
return Status::OK();
|
||||
@ -243,11 +253,13 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
|
||||
continue;
|
||||
}
|
||||
subshape->set_element_type(F32);
|
||||
bfloat16_normalization_->UpdateLayout(subshape);
|
||||
auto gte = computation_->AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(*subshape, hlo, i));
|
||||
auto shape = ShapeUtil::ChangeElementType(*subshape, BF16);
|
||||
bfloat16_normalization_->UpdateLayout(&shape);
|
||||
output_elements[i] =
|
||||
computation_->AddInstruction(HloInstruction::CreateConvert(
|
||||
ShapeUtil::ChangeElementType(*subshape, BF16), gte));
|
||||
computation_->AddInstruction(HloInstruction::CreateConvert(shape, gte));
|
||||
}
|
||||
auto tuple = computation_->AddInstruction(
|
||||
HloInstruction::CreateTuple(output_elements));
|
||||
@ -401,7 +413,7 @@ StatusOr<bool> BFloat16Normalization::Run(HloModule* module) {
|
||||
2, "BFloat16Normalization::Run(), before:\n" + module->ToString());
|
||||
bool changed = false;
|
||||
for (auto* comp : module->MakeComputationPostOrder()) {
|
||||
if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_)) {
|
||||
if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_, this)) {
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
@ -674,11 +674,13 @@ Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) {
|
||||
if (hlo->opcode() != HloOpcode::kConstant) {
|
||||
continue;
|
||||
}
|
||||
if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) {
|
||||
if (!Shape::Equal().MinorToMajorOnlyInLayout()(hlo->literal().shape(),
|
||||
hlo->shape())) {
|
||||
TF_ASSIGN_OR_RETURN(auto converted_literal,
|
||||
hlo->literal().ConvertToShape(hlo->shape()));
|
||||
auto new_constant = computation->AddInstruction(
|
||||
HloInstruction::CreateConstant(std::move(converted_literal)));
|
||||
UpdateLayout(new_constant->mutable_shape());
|
||||
TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
|
||||
}
|
||||
}
|
||||
@ -797,6 +799,7 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
|
||||
auto subshape = entry.first;
|
||||
CHECK_EQ(subshape->element_type(), F32);
|
||||
subshape->set_element_type(BF16);
|
||||
UpdateLayout(subshape);
|
||||
changed_ = true;
|
||||
}
|
||||
}
|
||||
|
@ -716,7 +716,8 @@ ProgramShape HloComputation::ComputeProgramShape() const {
|
||||
return program_shape;
|
||||
}
|
||||
|
||||
bool HloComputation::operator==(const HloComputation& other) const {
|
||||
bool HloComputation::Equal(const HloComputation& other,
|
||||
bool is_layout_sensitive) const {
|
||||
if (this == &other) {
|
||||
return true;
|
||||
}
|
||||
@ -741,7 +742,8 @@ bool HloComputation::operator==(const HloComputation& other) const {
|
||||
[](const HloInstruction*, const HloInstruction*) { return true; },
|
||||
[](const HloComputation* a, const HloComputation* b) {
|
||||
return *a == *b;
|
||||
});
|
||||
},
|
||||
is_layout_sensitive);
|
||||
if (!identical_ignoring_operands) {
|
||||
return false;
|
||||
}
|
||||
|
@ -270,7 +270,12 @@ class HloComputation {
|
||||
ProgramShape ComputeProgramShape() const;
|
||||
|
||||
// Return whether `*this` and `other` are functionally equivalent.
|
||||
bool operator==(const HloComputation& other) const;
|
||||
bool Equal(const HloComputation& other, bool is_layout_sensitive) const;
|
||||
|
||||
// Return whether `*this` and `other` are functionally equivalent.
|
||||
bool operator==(const HloComputation& other) const {
|
||||
return Equal(other, true);
|
||||
}
|
||||
|
||||
// Replaces old instruction with newly created instruction. Removes old
|
||||
// instruction from computation. Updates uses and root instruction.
|
||||
|
@ -1804,8 +1804,9 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) {
|
||||
<< "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
|
||||
// Out of convenience the literal may have been produced with a different
|
||||
// layout. Relayout as indicated by the HLO instruction.
|
||||
if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(),
|
||||
hlo->shape())) {
|
||||
if (!Layout::Equal().MinorToMajorOnly()(
|
||||
GetEvaluatedLiteralFor(hlo).shape().layout(),
|
||||
hlo->shape().layout())) {
|
||||
evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -303,6 +303,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
TF_ASSIGN_OR_RETURN(auto literal,
|
||||
Literal::CreateFromProto(proto.literal()));
|
||||
instruction = CreateConstant(std::move(literal));
|
||||
// Literal's shape may have no/different tiling info.
|
||||
CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(instruction->shape(),
|
||||
shape));
|
||||
*instruction->mutable_shape() = shape;
|
||||
} else {
|
||||
instruction = absl::make_unique<HloConstantInstruction>(shape);
|
||||
}
|
||||
|
@ -1018,6 +1018,11 @@ HloConstantInstruction::HloConstantInstruction(Literal literal)
|
||||
: HloInstruction(HloOpcode::kConstant, literal.shape()),
|
||||
literal_(std::move(literal)) {}
|
||||
|
||||
HloConstantInstruction::HloConstantInstruction(Literal literal,
|
||||
const Shape& shape)
|
||||
: HloInstruction(HloOpcode::kConstant, shape),
|
||||
literal_(std::move(literal)) {}
|
||||
|
||||
HloConstantInstruction::HloConstantInstruction(const Shape& shape)
|
||||
: HloInstruction(HloOpcode::kConstant, shape) {}
|
||||
|
||||
@ -1063,7 +1068,12 @@ HloConstantInstruction::CloneWithNewOperandsImpl(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||
HloCloneContext* context) const {
|
||||
CHECK(literal_.has_value());
|
||||
return absl::make_unique<HloConstantInstruction>(literal_->Clone());
|
||||
// Literal's shape may have no/different tiling info. Use this instruction's
|
||||
// shape instead.
|
||||
CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(literal_->shape(),
|
||||
this->shape()));
|
||||
return absl::make_unique<HloConstantInstruction>(literal_->Clone(),
|
||||
this->shape());
|
||||
}
|
||||
|
||||
string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
|
||||
|
@ -650,6 +650,7 @@ class HloSliceInstruction : public HloInstruction {
|
||||
class HloConstantInstruction : public HloInstruction {
|
||||
public:
|
||||
explicit HloConstantInstruction(Literal literal);
|
||||
explicit HloConstantInstruction(Literal literal, const Shape& shape);
|
||||
// Used when the literal is too large and dropped.
|
||||
explicit HloConstantInstruction(const Shape& shape);
|
||||
// Returns the literal associated with this instruction.
|
||||
|
@ -56,6 +56,14 @@ class HloModulePass : public HloPassInterface {
|
||||
}
|
||||
return changed;
|
||||
};
|
||||
|
||||
// Update the layout of a Shape to one that is supported by a given backend.
|
||||
// One can call this function after modifying the Shape in case that modifying
|
||||
// the Shape requires changes to the layout for the given Backend.
|
||||
//
|
||||
// TODO(b/129084868): Make this Backend dependent instead of requiring
|
||||
// deriving from the pass the and overriding this function.
|
||||
virtual void UpdateLayout(Shape* shape) {}
|
||||
};
|
||||
|
||||
// Base class for passes which are module-group scoped. These passes cannot run
|
||||
|
@ -343,7 +343,7 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) {
|
||||
|
||||
// Check that the 'compare' computation returns a PRED.
|
||||
Shape compare_shape = compare->root_instruction()->shape();
|
||||
if (!ShapesSame(compare_shape, ShapeUtil::MakeShape(PRED, {}))) {
|
||||
if (!ShapeUtil::Compatible(compare_shape, ShapeUtil::MakeShape(PRED, {}))) {
|
||||
return InternalError(
|
||||
"The Sort compare computation shape does not lead to a scalar "
|
||||
"predicate shape: %s",
|
||||
@ -393,7 +393,8 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
|
||||
return InternalError("Constant is required to have a valid literal: %s",
|
||||
constant->ToString());
|
||||
}
|
||||
return CheckShape(constant, constant->literal().shape());
|
||||
return CheckShape(constant, constant->literal().shape(),
|
||||
/*only_compare_minor_to_major_in_layout=*/true);
|
||||
}
|
||||
|
||||
Status ShapeVerifier::HandleIota(HloInstruction* instruction) {
|
||||
@ -654,7 +655,8 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
|
||||
CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
|
||||
const Shape& conditional_shape =
|
||||
xla_while->while_condition()->root_instruction()->shape();
|
||||
if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) {
|
||||
if (!ShapeUtil::Compatible(conditional_shape,
|
||||
ShapeUtil::MakeShape(PRED, {}))) {
|
||||
return InternalError(
|
||||
"Conditional computation shape does not lead to a scalar predicate "
|
||||
"shape: %s",
|
||||
@ -696,7 +698,8 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) {
|
||||
return CheckShape(send,
|
||||
ShapeUtil::MakeTupleShape({send->operand(0)->shape(),
|
||||
ShapeUtil::MakeShape(U32, {}),
|
||||
ShapeUtil::MakeTokenShape()}));
|
||||
ShapeUtil::MakeTokenShape()}),
|
||||
/*only_compare_minor_to_major_in_layout=*/true);
|
||||
}
|
||||
|
||||
Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
|
||||
@ -705,9 +708,11 @@ Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
|
||||
|
||||
Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
|
||||
return CheckShape(
|
||||
recv, ShapeUtil::MakeTupleShape(
|
||||
recv,
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::GetTupleElementShape(recv->shape(), 0),
|
||||
ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}));
|
||||
ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}),
|
||||
/*only_compare_minor_to_major_in_layout=*/true);
|
||||
}
|
||||
|
||||
Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
|
||||
@ -844,7 +849,8 @@ Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) {
|
||||
}
|
||||
|
||||
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
|
||||
const Shape& inferred_shape) {
|
||||
const Shape& inferred_shape,
|
||||
bool only_compare_minor_to_major_in_layout) {
|
||||
// If allow_mixed_precision_ is false, check if there are operands with
|
||||
// different precisions. We need this check because ShapeInference allows
|
||||
// mixed precision inputs.
|
||||
@ -878,7 +884,8 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
|
||||
case HloOpcode::kTuple:
|
||||
case HloOpcode::kTupleSelect:
|
||||
case HloOpcode::kWhile:
|
||||
return ShapesSame(instruction->shape(), inferred_shape);
|
||||
return ShapesSame(instruction->shape(), inferred_shape,
|
||||
only_compare_minor_to_major_in_layout);
|
||||
|
||||
// We allow arbitrary layout and f32->bf16 transformations on all other
|
||||
// instructions, although this may be made more strict pending discussion
|
||||
|
@ -106,7 +106,8 @@ class ShapeVerifier : public DfsHloVisitor {
|
||||
// Check the instruction's shape against the shape given by ShapeInference
|
||||
// and return an appropriate error if there is a mismatch.
|
||||
Status CheckShape(const HloInstruction* instruction,
|
||||
const Shape& inferred_shape);
|
||||
const Shape& inferred_shape,
|
||||
bool only_compare_minor_to_major_in_layout = false);
|
||||
|
||||
// Overload which takes a StatusOr to reduce boilerplate in the caller.
|
||||
Status CheckShape(const HloInstruction* instruction,
|
||||
@ -120,14 +121,31 @@ class ShapeVerifier : public DfsHloVisitor {
|
||||
|
||||
private:
|
||||
// Helpers that switch on layout_sensitive_.
|
||||
bool ShapesSame(const Shape& a, const Shape& b) {
|
||||
return layout_sensitive_ ? ShapeUtil::Equal(a, b)
|
||||
: ShapeUtil::Compatible(a, b);
|
||||
bool ShapesSame(const Shape& a, const Shape& b,
|
||||
bool minor_to_major_only = false) {
|
||||
if (!layout_sensitive_) {
|
||||
return ShapeUtil::Compatible(a, b);
|
||||
}
|
||||
bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) {
|
||||
return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b)
|
||||
: ShapeUtil::CompatibleIgnoringFpPrecision(a, b);
|
||||
Shape::Equal equal;
|
||||
if (minor_to_major_only) {
|
||||
equal.MinorToMajorOnlyInLayout();
|
||||
}
|
||||
return equal(a, b);
|
||||
}
|
||||
|
||||
bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b,
|
||||
bool minor_to_major_only = false) {
|
||||
if (!layout_sensitive_) {
|
||||
return ShapeUtil::CompatibleIgnoringFpPrecision(a, b);
|
||||
}
|
||||
Shape::Equal equal;
|
||||
if (minor_to_major_only) {
|
||||
equal.MinorToMajorOnlyInLayout();
|
||||
}
|
||||
equal.IgnoreFpPrecision();
|
||||
return equal(a, b);
|
||||
}
|
||||
|
||||
string StringifyShape(const Shape& s) {
|
||||
return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s)
|
||||
: ShapeUtil::HumanString(s);
|
||||
|
@ -173,7 +173,7 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
|
||||
auto iter = buffer_constraints_.find(&buffer);
|
||||
if (iter != buffer_constraints_.end()) {
|
||||
const BufferLayoutConstraint& curr_constraint = iter->second;
|
||||
if (LayoutUtil::Equal(curr_constraint.layout(), layout)) {
|
||||
if (Layout::Equal().MinorToMajorOnly()(curr_constraint.layout(), layout)) {
|
||||
// New constraint matches existing constraint. Nothing to do.
|
||||
return Status::OK();
|
||||
}
|
||||
@ -210,7 +210,7 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
|
||||
GetOperandLayoutConstraint(instruction, operand_no);
|
||||
if (curr_shape_layout != nullptr) {
|
||||
if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
|
||||
shape_with_layout)) {
|
||||
shape_with_layout, /*minor_to_major_only=*/true)) {
|
||||
// New constraint matches existing constraint. Nothing to do.
|
||||
return Status::OK();
|
||||
}
|
||||
@ -269,7 +269,8 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
|
||||
|
||||
const ShapeLayout* curr_shape_layout = ResultLayout();
|
||||
if (curr_shape_layout != nullptr) {
|
||||
if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) {
|
||||
if (!curr_shape_layout->MatchesLayoutInShape(
|
||||
shape_with_layout, /*minor_to_major_only=*/true)) {
|
||||
return FailedPrecondition(
|
||||
"Result of computation %s already has the layout constraint %s, "
|
||||
"cannot add incompatible constraint %s",
|
||||
@ -647,6 +648,10 @@ Status LayoutAssignment::AddMandatoryConstraints(
|
||||
|
||||
namespace {
|
||||
|
||||
bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) {
|
||||
return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout());
|
||||
}
|
||||
|
||||
// The operands of a call must match the layouts of parameters in the
|
||||
// ComputationLayout, and the call instruction itself must match the result
|
||||
// layout in the ComputationLayout.
|
||||
@ -656,10 +661,10 @@ Status CheckCallLayout(HloInstruction* call,
|
||||
TF_RET_CHECK(computation->num_parameters() == call->operand_count());
|
||||
for (int64 i = 0; i < computation->num_parameters(); ++i) {
|
||||
TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
|
||||
call->operand(i)->shape()));
|
||||
call->operand(i)->shape(), /*minor_to_major_only=*/true));
|
||||
}
|
||||
TF_RET_CHECK(
|
||||
computation_layout.result_layout().MatchesLayoutInShape(call->shape()));
|
||||
TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape(
|
||||
call->shape(), /*minor_to_major_only=*/true));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -670,8 +675,8 @@ Status CheckCustomCallLayout(HloInstruction* instruction) {
|
||||
const HloCustomCallInstruction* custom_call =
|
||||
DynCast<HloCustomCallInstruction>(instruction);
|
||||
for (int64 i = 0; i < custom_call->operand_count(); ++i) {
|
||||
TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
|
||||
custom_call->operand(i)->shape(),
|
||||
TF_RET_CHECK(
|
||||
LayoutsInShapesEqual(custom_call->operand(i)->shape(),
|
||||
custom_call->operand_shapes_with_layout()[i]));
|
||||
}
|
||||
}
|
||||
@ -690,13 +695,12 @@ Status CheckWhileLayout(HloInstruction* while_inst,
|
||||
auto init_shape = while_inst->operand(0)->shape();
|
||||
TF_RET_CHECK(
|
||||
condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
|
||||
init_shape));
|
||||
init_shape, /*minor_to_major_only=*/true));
|
||||
TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
|
||||
init_shape));
|
||||
TF_RET_CHECK(
|
||||
body_computation_layout.result_layout().MatchesLayoutInShape(init_shape));
|
||||
TF_RET_CHECK(
|
||||
LayoutUtil::LayoutsInShapesEqual(init_shape, while_inst->shape()));
|
||||
init_shape, /*minor_to_major_only=*/true));
|
||||
TF_RET_CHECK(body_computation_layout.result_layout().MatchesLayoutInShape(
|
||||
init_shape, /*minor_to_major_only=*/true));
|
||||
TF_RET_CHECK(LayoutsInShapesEqual(init_shape, while_inst->shape()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -709,13 +713,14 @@ Status CheckConditionalLayout(
|
||||
branch_computation_layouts[j].result_layout());
|
||||
TF_RET_CHECK(
|
||||
branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
|
||||
instruction->shape()));
|
||||
instruction->shape(), /*minor_to_major_only=*/true));
|
||||
TF_RET_CHECK(
|
||||
branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
|
||||
instruction->branch_computation(j)->root_instruction()->shape()));
|
||||
instruction->branch_computation(j)->root_instruction()->shape(),
|
||||
/*minor_to_major_only=*/true));
|
||||
TF_RET_CHECK(
|
||||
branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape(
|
||||
branch_operand->shape()));
|
||||
branch_operand->shape(), /*minor_to_major_only=*/true));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -726,11 +731,11 @@ Status CheckConditionalLayout(
|
||||
Status CheckFusionLayout(HloInstruction* fusion) {
|
||||
TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
|
||||
|
||||
TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
|
||||
fusion->shape(), fusion->fused_expression_root()->shape()));
|
||||
TF_RET_CHECK(LayoutsInShapesEqual(fusion->shape(),
|
||||
fusion->fused_expression_root()->shape()));
|
||||
for (int64 i = 0; i < fusion->operand_count(); ++i) {
|
||||
TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
|
||||
fusion->fused_parameter(i)->shape(), fusion->operand(i)->shape()));
|
||||
TF_RET_CHECK(LayoutsInShapesEqual(fusion->fused_parameter(i)->shape(),
|
||||
fusion->operand(i)->shape()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -742,7 +747,8 @@ Status CheckParameterLayout(HloInstruction* parameter,
|
||||
const ShapeLayout& parameter_layout =
|
||||
computation_layout.parameter_layout(parameter->parameter_number());
|
||||
if (parameter_layout.LayoutIsSet() &&
|
||||
!parameter_layout.MatchesLayoutInShape(parameter->shape())) {
|
||||
!parameter_layout.MatchesLayoutInShape(parameter->shape(),
|
||||
/*minor_to_major_only=*/true)) {
|
||||
return InternalError(
|
||||
"parameter instruction %s does not match layout of computation "
|
||||
"shape: %s",
|
||||
@ -753,8 +759,7 @@ Status CheckParameterLayout(HloInstruction* parameter,
|
||||
|
||||
// The layout of a constant instruction must match the layout of its literal.
|
||||
Status CheckConstantLayout(HloInstruction* constant) {
|
||||
if (!LayoutUtil::LayoutsInShapesEqual(constant->literal().shape(),
|
||||
constant->shape())) {
|
||||
if (!LayoutsInShapesEqual(constant->literal().shape(), constant->shape())) {
|
||||
return InternalError(
|
||||
"constant instruction %s does not match the layout of its literal %s",
|
||||
constant->ToString(),
|
||||
@ -785,7 +790,8 @@ StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
|
||||
HloInstruction* gte = instruction->parent()->AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(instr_shape, instruction, i));
|
||||
|
||||
if (ShapeUtil::Equal(target_shape, instr_shape)) {
|
||||
if (Shape::Equal().MinorToMajorOnlyInLayout()(target_shape,
|
||||
instr_shape)) {
|
||||
// Shapes and layouts are equal, no need to copy.
|
||||
element_copies.push_back(gte);
|
||||
} else {
|
||||
@ -831,7 +837,8 @@ Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
|
||||
TF_RET_CHECK(operand_layout.LayoutIsSet());
|
||||
TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
|
||||
|
||||
if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
|
||||
if (Shape::Equal().MinorToMajorOnlyInLayout()(operand_layout.shape(),
|
||||
operand->shape())) {
|
||||
VLOG(5) << "Operand " << operand->ToString() << " layout matches in "
|
||||
<< instruction->ToString();
|
||||
// Operand layout already matches our constraint. Nothing to do.
|
||||
@ -892,7 +899,8 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
|
||||
const Shape& instruction_subshape =
|
||||
ShapeUtil::GetSubshape(instruction->shape(), index);
|
||||
for (const LogicalBuffer* buffer : buffers) {
|
||||
if (!ShapeUtil::Equal(instruction_subshape, buffer->shape())) {
|
||||
if (!Shape::Equal().MinorToMajorOnlyInLayout()(
|
||||
instruction_subshape, buffer->shape())) {
|
||||
return InternalError(
|
||||
"Layout of instruction %s at index {%s} does not match "
|
||||
"source LogicalBuffer %s: %s vs %s",
|
||||
@ -954,8 +962,8 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
|
||||
FindOrDie(computation_layouts_, module->entry_computation())
|
||||
.result_layout();
|
||||
if (result_layout.LayoutIsSet()) {
|
||||
TF_RET_CHECK(
|
||||
ShapeUtil::Equal(module->result_shape(), result_layout.shape()));
|
||||
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
|
||||
module->result_shape(), result_layout.shape()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1510,8 +1518,8 @@ StatusOr<Layout> InferArrayLayout(
|
||||
|
||||
if (first_buffer_layout == nullptr) {
|
||||
first_buffer_layout = &source_buffer->shape().layout();
|
||||
} else if (!LayoutUtil::Equal(source_buffer->shape().layout(),
|
||||
*first_buffer_layout)) {
|
||||
} else if (!Layout::Equal().MinorToMajorOnly()(
|
||||
source_buffer->shape().layout(), *first_buffer_layout)) {
|
||||
// The points-to set is ambiguous for this index and the different source
|
||||
// buffers have different layouts. This case is possible in valid XLA
|
||||
// computations because we do not propagate BufferLayoutConstraints to all
|
||||
@ -1789,7 +1797,8 @@ Status LayoutAssignment::RunOnComputation(
|
||||
// layout constraint.
|
||||
if (constraints.ResultLayout() != nullptr &&
|
||||
!constraints.ResultLayout()->MatchesLayoutInShape(
|
||||
computation->root_instruction()->shape())) {
|
||||
computation->root_instruction()->shape(),
|
||||
/*minor_to_major_only=*/true)) {
|
||||
if (conditional_mismatch_.count(computation) > 0) {
|
||||
*FindOrDie(computation_layouts_, computation).mutable_result_layout() =
|
||||
FindOrDie(conditional_mismatch_, computation).result_layout();
|
||||
@ -1907,7 +1916,9 @@ Status LayoutAssignment::PropagateComputationLayouts(
|
||||
<< ": " << computed_computation_layout.result_layout().ToString();
|
||||
*result_layout = computed_computation_layout.result_layout();
|
||||
} else {
|
||||
TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout);
|
||||
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
|
||||
computed_computation_layout.result_layout().shape(),
|
||||
result_layout->shape()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -141,6 +141,11 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
|
||||
}
|
||||
}
|
||||
|
||||
if (!ShapeUtil::SameDimensions(lhs, rhs)) {
|
||||
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ignore_layout_) {
|
||||
if (lhs.layout().format() != rhs.layout().format()) {
|
||||
VLOG(3) << "CompareShapes: lhs layout format != rhs layout format";
|
||||
@ -161,11 +166,6 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
|
||||
}
|
||||
}
|
||||
|
||||
if (!ShapeUtil::SameDimensions(lhs, rhs)) {
|
||||
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ignore_dynamic_dimension_) {
|
||||
for (int i = 0; i < lhs.rank(); ++i) {
|
||||
if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) {
|
||||
|
@ -169,6 +169,11 @@ class Shape {
|
||||
ignore_element_size_in_layout_ = true;
|
||||
return *this;
|
||||
}
|
||||
Equal& MinorToMajorOnlyInLayout() {
|
||||
ignore_tiles_in_layout_ = true;
|
||||
ignore_element_size_in_layout_ = true;
|
||||
return *this;
|
||||
}
|
||||
Equal& IgnoreElementType() {
|
||||
ignore_element_type_ = true;
|
||||
return *this;
|
||||
|
@ -46,8 +46,13 @@ void ShapeLayout::SetToDefaultLayout() {
|
||||
LayoutUtil::SetToDefaultLayout(&shape_);
|
||||
}
|
||||
|
||||
bool ShapeLayout::MatchesLayoutInShape(const Shape& shape) const {
|
||||
return ShapeUtil::Equal(shape, shape_);
|
||||
bool ShapeLayout::MatchesLayoutInShape(const Shape& shape,
|
||||
bool minor_to_major_only) const {
|
||||
auto equal = Shape::Equal();
|
||||
if (minor_to_major_only) {
|
||||
equal.MinorToMajorOnlyInLayout();
|
||||
}
|
||||
return equal(shape, shape_);
|
||||
}
|
||||
|
||||
const Layout& ShapeLayout::layout() const {
|
||||
|
@ -45,7 +45,8 @@ class ShapeLayout {
|
||||
// Returns true if the Layouts in this ShapeLayout match the layouts in the
|
||||
// given shape. Returns false otherwise. If the given shape is not compatible
|
||||
// with the ShapeLayout's shape, then false is returned.
|
||||
bool MatchesLayoutInShape(const Shape& shape) const;
|
||||
bool MatchesLayoutInShape(const Shape& shape,
|
||||
bool minor_to_major_only = false) const;
|
||||
|
||||
// Copies the layout from the given shape into this ShapeLayout. 'other_shape'
|
||||
// must be compatible with the ShapeLayout's shape.
|
||||
|
@ -102,6 +102,11 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeUtil::MakeValidatedShape(element_type, dimensions));
|
||||
if (element_size_in_bits ==
|
||||
ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) {
|
||||
// Only set element_size_in_bits if it's different from the default value.
|
||||
element_size_in_bits = 0;
|
||||
}
|
||||
*shape.mutable_layout() =
|
||||
LayoutUtil::MakeLayout(minor_to_major, tiles, element_size_in_bits);
|
||||
if (!shape.has_layout()) {
|
||||
@ -219,7 +224,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
for (int i = 0; i < shape.dimensions_size(); ++i) {
|
||||
dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
|
||||
}
|
||||
return MakeShapeWithDescendingLayout(shape.element_type(), dims);
|
||||
Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims);
|
||||
// Since the physical layout is kept the same, the tiles and element size are
|
||||
// the same also.
|
||||
*new_shape.mutable_layout()->mutable_tiles() = shape.layout().tiles();
|
||||
new_shape.mutable_layout()->set_element_size_in_bits(
|
||||
shape.layout().element_size_in_bits());
|
||||
return new_shape;
|
||||
}
|
||||
|
||||
/* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type,
|
||||
|
@ -395,6 +395,8 @@ class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest {
|
||||
ParametricDotTestWithoutLayoutAssignment() {
|
||||
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
|
||||
"layout-assignment");
|
||||
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
|
||||
"tiling-assignment");
|
||||
// Disable algebraic simplification because the pass may replace a dot
|
||||
// instruction with a layout-changing multiplication instruction.
|
||||
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
|
||||
|
@ -846,16 +846,16 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) {
|
||||
|
||||
xla::ProgramShape xla_program_shape =
|
||||
XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes));
|
||||
EXPECT_TRUE(xla::LayoutUtil::Equal(
|
||||
EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
|
||||
xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
|
||||
xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
|
||||
.layout()));
|
||||
EXPECT_TRUE(xla::LayoutUtil::Equal(
|
||||
EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
|
||||
xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(),
|
||||
xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1})
|
||||
.layout()));
|
||||
EXPECT_TRUE(xla::LayoutUtil::Equal(program_shape.result().layout(),
|
||||
xla_program_shape.result().layout()));
|
||||
EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
|
||||
program_shape.result().layout(), xla_program_shape.result().layout()));
|
||||
}
|
||||
|
||||
TEST(RawApiTest, DotGeneralWithLayoutTest) {
|
||||
|
Loading…
Reference in New Issue
Block a user