Populate tiling info in Layout data.

PiperOrigin-RevId: 240596416
This commit is contained in:
A. Unique TensorFlower 2019-03-27 10:51:52 -07:00 committed by TensorFlower Gardener
parent c510b79d5c
commit f366225ec6
24 changed files with 336 additions and 167 deletions

View File

@ -96,8 +96,13 @@ string Layout::ToString() const {
} }
bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) { bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) {
if (lhs.format() != rhs.format() || if (lhs.format() != rhs.format()) {
lhs.minor_to_major() != rhs.minor_to_major() || return false;
}
if (lhs.format() == DENSE && lhs.minor_to_major() != rhs.minor_to_major()) {
return false;
}
if (lhs.format() == SPARSE &&
lhs.max_sparse_elements() != rhs.max_sparse_elements()) { lhs.max_sparse_elements() != rhs.max_sparse_elements()) {
return false; return false;
} }

View File

@ -127,6 +127,12 @@ class Layout {
return *this; return *this;
} }
Equal& MinorToMajorOnly() {
ignore_tiles_ = true;
ignore_element_size_ = true;
return *this;
}
private: private:
bool ignore_tiles_ = false; bool ignore_tiles_ = false;
bool ignore_element_size_ = false; bool ignore_element_size_ = false;

View File

@ -250,12 +250,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Runs the visitor on a computation. // Runs the visitor on a computation.
static bool Run(HloComputation* computation, static bool Run(HloComputation* computation,
const AlgebraicSimplifierOptions& options); const AlgebraicSimplifierOptions& options,
AlgebraicSimplifier* simplifier);
private: private:
explicit AlgebraicSimplifierVisitor(HloComputation* computation, explicit AlgebraicSimplifierVisitor(HloComputation* computation,
const AlgebraicSimplifierOptions& options) const AlgebraicSimplifierOptions& options,
: computation_(computation), options_(options) {} AlgebraicSimplifier* simplifier)
: computation_(computation), options_(options), simplifier_(simplifier) {}
// Transforms Dots where at least one input is a vector or has a degenerate // Transforms Dots where at least one input is a vector or has a degenerate
// dimension and converts it into a multiply and reduce. This should enable // dimension and converts it into a multiply and reduce. This should enable
@ -274,10 +276,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
if (hlo->shape().rank() == 1) { if (hlo->shape().rank() == 1) {
return hlo; return hlo;
} }
return computation_->AddInstruction(HloInstruction::CreateReshape( auto hlo_instruction =
computation_->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(hlo->shape().element_type(), ShapeUtil::MakeShape(hlo->shape().element_type(),
{ShapeUtil::ElementsIn(hlo->shape())}), {ShapeUtil::ElementsIn(hlo->shape())}),
hlo)); hlo));
simplifier_->UpdateLayout(hlo_instruction->mutable_shape());
return hlo_instruction;
} }
// Converts to primitive type if the input hlo is not that type, otherwise // Converts to primitive type if the input hlo is not that type, otherwise
@ -287,8 +292,11 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
if (hlo->shape().element_type() == element_type) { if (hlo->shape().element_type() == element_type) {
return hlo; return hlo;
} }
return computation_->AddInstruction(HloInstruction::CreateConvert( Shape changed_shape =
ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); ShapeUtil::ChangeElementType(hlo->shape(), element_type);
simplifier_->UpdateLayout(&changed_shape);
return computation_->AddInstruction(
HloInstruction::CreateConvert(changed_shape, hlo));
} }
// Transposes a dot operand such that the batch dimensions are the msot major, // Transposes a dot operand such that the batch dimensions are the msot major,
@ -312,13 +320,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Helper method to perform and add reduction on a list of dimensions. // Helper method to perform and add reduction on a list of dimensions.
HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64> dims) { HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64> dims) {
HloInstruction* zero = HloInstruction* zero = computation_->AddInstruction(
computation_->AddInstruction(HloInstruction::CreateConstant( simplifier_->CreateConstantWithLayoutUpdated(
LiteralUtil::Zero(hlo->shape().element_type()).Clone())); LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::FilterDimensions( Shape shape = ShapeUtil::FilterDimensions(
[&](int64 dim) { return !absl::c_linear_search(dims, dim); }, [&](int64 dim) { return !absl::c_linear_search(dims, dim); },
hlo->shape()); hlo->shape());
simplifier_->UpdateLayout(&shape);
return computation_->AddInstruction(HloInstruction::CreateReduce( return computation_->AddInstruction(HloInstruction::CreateReduce(
shape, hlo, zero, dims, AddReduce_computation)); shape, hlo, zero, dims, AddReduce_computation));
} }
@ -403,6 +412,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloComputation::Builder b("scalar_add_computation"); HloComputation::Builder b("scalar_add_computation");
Shape shape = ShapeUtil::MakeShape(F32, {}); Shape shape = ShapeUtil::MakeShape(F32, {});
simplifier_->UpdateLayout(&shape);
auto scalar_lhs = b.AddInstruction( auto scalar_lhs = b.AddInstruction(
HloInstruction::CreateParameter(0, shape, "scalar_lhs")); HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
auto scalar_rhs = b.AddInstruction( auto scalar_rhs = b.AddInstruction(
@ -440,13 +450,16 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Cached computation for adding two scalar F32. // Cached computation for adding two scalar F32.
HloComputation* scalar_add_computation_ = nullptr; HloComputation* scalar_add_computation_ = nullptr;
AlgebraicSimplifier* simplifier_ = nullptr;
}; };
} // namespace } // namespace
bool AlgebraicSimplifierVisitor::Run( bool AlgebraicSimplifierVisitor::Run(HloComputation* computation,
HloComputation* computation, const AlgebraicSimplifierOptions& options) { const AlgebraicSimplifierOptions& options,
AlgebraicSimplifierVisitor visitor(computation, options); AlgebraicSimplifier* simplifier) {
AlgebraicSimplifierVisitor visitor(computation, options, simplifier);
TF_CHECK_OK(computation->Accept(&visitor)); TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_; return visitor.changed_;
} }
@ -713,6 +726,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
new_slice_shape.set_dimensions( new_slice_shape.set_dimensions(
concatenate_dimension, concatenate_dimension,
slice_end - operands[i]->slice_starts(concatenate_dimension)); slice_end - operands[i]->slice_starts(concatenate_dimension));
simplifier_->UpdateLayout(&new_slice_shape);
auto new_limit_indices = operands[i]->slice_limits(); auto new_limit_indices = operands[i]->slice_limits();
new_limit_indices[concatenate_dimension] = slice_end; new_limit_indices[concatenate_dimension] = slice_end;
auto new_slice_op = auto new_slice_op =
@ -775,18 +789,19 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
} }
static HloInstruction* BuildTupleConstant(HloComputation* computation, static HloInstruction* BuildTupleConstant(HloComputation* computation,
const LiteralSlice& literal) { const LiteralSlice& literal,
AlgebraicSimplifier* simplifier) {
if (literal.shape().IsTuple()) { if (literal.shape().IsTuple()) {
std::vector<HloInstruction*> elems; std::vector<HloInstruction*> elems;
elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
elems.push_back( elems.push_back(BuildTupleConstant(
BuildTupleConstant(computation, LiteralSlice(literal, {i}))); computation, LiteralSlice(literal, {i}), simplifier));
} }
return computation->AddInstruction(HloInstruction::CreateTuple(elems)); return computation->AddInstruction(HloInstruction::CreateTuple(elems));
} else { } else {
return computation->AddInstruction( return computation->AddInstruction(
HloInstruction::CreateConstant(literal.Clone())); simplifier->CreateConstantWithLayoutUpdated(literal.Clone()));
} }
} }
@ -795,7 +810,8 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
// explicit Tuple instructions. // explicit Tuple instructions.
if (constant->shape().IsTuple()) { if (constant->shape().IsTuple()) {
return ReplaceInstruction( return ReplaceInstruction(
constant, BuildTupleConstant(computation_, constant->literal())); constant,
BuildTupleConstant(computation_, constant->literal(), simplifier_));
} }
if (constant->shape().element_type() == TOKEN) { if (constant->shape().element_type() == TOKEN) {
@ -808,7 +824,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
Literal unique_scalar( Literal unique_scalar(
LiteralUtil::GetFirstScalarLiteral(constant->literal())); LiteralUtil::GetFirstScalarLiteral(constant->literal()));
HloInstruction* scalar = computation_->AddInstruction( HloInstruction* scalar = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(unique_scalar))); simplifier_->CreateConstantWithLayoutUpdated(std::move(unique_scalar)));
return ReplaceWithNewInstruction( return ReplaceWithNewInstruction(
constant, constant,
HloInstruction::CreateBroadcast(constant->shape(), scalar, {})); HloInstruction::CreateBroadcast(constant->shape(), scalar, {}));
@ -854,8 +870,9 @@ Status InvertConstant(const HloInstruction& constant, Literal* result) {
} }
template <typename T> template <typename T>
std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide, std::unique_ptr<HloInstruction> TryDivideToShift(
HloComputation* computation) { HloInstruction* divide, HloComputation* computation,
AlgebraicSimplifier* simplifier) {
HloInstruction *a, *b, *c; HloInstruction *a, *b, *c;
CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
@ -872,10 +889,11 @@ std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
HloInstruction* zero_like_a = BroadcastZeros( HloInstruction* zero_like_a = BroadcastZeros(
computation, a->shape().element_type(), a->shape().dimensions()); computation, a->shape().element_type(), a->shape().dimensions());
Shape changed_shape = ShapeUtil::ChangeElementType(a->shape(), PRED);
simplifier->UpdateLayout(&changed_shape);
auto* dividend_is_negative = auto* dividend_is_negative =
computation->AddInstruction(HloInstruction::CreateCompare( computation->AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a, changed_shape, a, zero_like_a, ComparisonDirection::kLt));
ComparisonDirection::kLt));
auto* negated_dividend = computation->AddInstruction( auto* negated_dividend = computation->AddInstruction(
HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
@ -887,8 +905,8 @@ std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
int log2_abs_b_value = tensorflow::Log2Floor64(b_value); int log2_abs_b_value = tensorflow::Log2Floor64(b_value);
auto* shift_amount = auto* shift_amount = computation->AddInstruction(
computation->AddInstruction(HloInstruction::CreateConstant( simplifier->CreateConstantWithLayoutUpdated(
LiteralUtil::CreateR0<T>(log2_abs_b_value))); LiteralUtil::CreateR0<T>(log2_abs_b_value)));
if (!ShapeUtil::IsScalar(b->shape())) { if (!ShapeUtil::IsScalar(b->shape())) {
shift_amount = computation->AddInstruction( shift_amount = computation->AddInstruction(
@ -911,8 +929,8 @@ std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
uint64 b_value = c->literal().GetFirstElement<T>(); uint64 b_value = c->literal().GetFirstElement<T>();
if (IsPowerOfTwo(b_value)) { if (IsPowerOfTwo(b_value)) {
int log2_abs_b_value = tensorflow::Log2Floor64(b_value); int log2_abs_b_value = tensorflow::Log2Floor64(b_value);
HloInstruction* shift_amount = HloInstruction* shift_amount = computation->AddInstruction(
computation->AddInstruction(HloInstruction::CreateConstant( simplifier->CreateConstantWithLayoutUpdated(
LiteralUtil::CreateR0<T>(log2_abs_b_value))); LiteralUtil::CreateR0<T>(log2_abs_b_value)));
if (!ShapeUtil::IsScalar(b->shape())) { if (!ShapeUtil::IsScalar(b->shape())) {
shift_amount = computation->AddInstruction( shift_amount = computation->AddInstruction(
@ -940,49 +958,49 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
switch (divide->shape().element_type()) { switch (divide->shape().element_type()) {
case S8: case S8:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<int8>(divide, computation_)) { TryDivideToShift<int8>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift)); return ReplaceWithNewInstruction(divide, std::move(shift));
} }
break; break;
case S16: case S16:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<int16>(divide, computation_)) { TryDivideToShift<int16>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift)); return ReplaceWithNewInstruction(divide, std::move(shift));
} }
break; break;
case S32: case S32:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<int32>(divide, computation_)) { TryDivideToShift<int32>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift)); return ReplaceWithNewInstruction(divide, std::move(shift));
} }
break; break;
case S64: case S64:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<int64>(divide, computation_)) { TryDivideToShift<int64>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift)); return ReplaceWithNewInstruction(divide, std::move(shift));
} }
break; break;
case U8: case U8:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<uint8>(divide, computation_)) { TryDivideToShift<uint8>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift)); return ReplaceWithNewInstruction(divide, std::move(shift));
} }
break; break;
case U16: case U16:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<uint16>(divide, computation_)) { TryDivideToShift<uint16>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift)); return ReplaceWithNewInstruction(divide, std::move(shift));
} }
break; break;
case U32: case U32:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<uint32>(divide, computation_)) { TryDivideToShift<uint32>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift)); return ReplaceWithNewInstruction(divide, std::move(shift));
} }
break; break;
case U64: case U64:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<uint64>(divide, computation_)) { TryDivideToShift<uint64>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift)); return ReplaceWithNewInstruction(divide, std::move(shift));
} }
break; break;
@ -1084,7 +1102,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
return Status::OK(); return Status::OK();
} }
auto inverse = computation_->AddInstruction( auto inverse = computation_->AddInstruction(
HloInstruction::CreateConstant((new_literal.Clone()))); simplifier_->CreateConstantWithLayoutUpdated((new_literal.Clone())));
TF_ASSIGN_OR_RETURN(auto new_divide, TF_ASSIGN_OR_RETURN(auto new_divide,
MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
return ReplaceInstruction(divide, new_divide); return ReplaceInstruction(divide, new_divide);
@ -1456,6 +1474,7 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim); int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim);
Shape rhs_slice_shape(rhs->shape()); Shape rhs_slice_shape(rhs->shape());
rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k); rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k);
simplifier_->UpdateLayout(&rhs_slice_shape);
std::array<int64, 2> start_indices; std::array<int64, 2> start_indices;
start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset; start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset;
@ -1591,6 +1610,7 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
right_operand->shape().dimensions(1 - rhs_contracting_dimension); right_operand->shape().dimensions(1 - rhs_contracting_dimension);
auto memoized_shape = auto memoized_shape =
ShapeUtil::MakeShape(dot->shape().element_type(), {m, n}); ShapeUtil::MakeShape(dot->shape().element_type(), {m, n});
simplifier_->UpdateLayout(&memoized_shape);
auto* memoized_inst = computation_->AddInstruction( auto* memoized_inst = computation_->AddInstruction(
HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
dnums, dot->precision_config())); dnums, dot->precision_config()));
@ -1605,7 +1625,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
// Slice out start and 0 components and reorder if necessary. // Slice out start and 0 components and reorder if necessary.
auto indices_type = dynamic_slice->operand(1)->shape().element_type(); auto indices_type = dynamic_slice->operand(1)->shape().element_type();
Shape s_shape = ShapeUtil::MakeShape(indices_type, {1}); Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
simplifier_->UpdateLayout(&s_shape);
Shape d_shape = ShapeUtil::MakeShape(indices_type, {2}); Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
simplifier_->UpdateLayout(&d_shape);
HloInstruction* non_zero_start = HloInstruction* non_zero_start =
dynamic_slice->mutable_operand(1 + index_of_non_zero_start); dynamic_slice->mutable_operand(1 + index_of_non_zero_start);
HloInstruction* zero_start = HloInstruction* zero_start =
@ -1638,7 +1660,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
if (ShapeUtil::IsZeroElementArray(dot->shape()) || if (ShapeUtil::IsZeroElementArray(dot->shape()) ||
ShapeUtil::IsZeroElementArray(lhs->shape()) || ShapeUtil::IsZeroElementArray(lhs->shape()) ||
ShapeUtil::IsZeroElementArray(rhs->shape())) { ShapeUtil::IsZeroElementArray(rhs->shape())) {
auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( auto zero = computation_->AddInstruction(
simplifier_->CreateConstantWithLayoutUpdated(
LiteralUtil::Zero(dot->shape().element_type()))); LiteralUtil::Zero(dot->shape().element_type())));
return ReplaceWithNewInstruction( return ReplaceWithNewInstruction(
dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
@ -2183,7 +2206,8 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
// zero. // zero.
auto* iota = Cast<HloIotaInstruction>(instruction); auto* iota = Cast<HloIotaInstruction>(instruction);
if (iota->shape().dimensions(iota->iota_dimension()) <= 1) { if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( auto zero = computation_->AddInstruction(
simplifier_->CreateConstantWithLayoutUpdated(
LiteralUtil::Zero(iota->shape().element_type()).Clone())); LiteralUtil::Zero(iota->shape().element_type()).Clone()));
return ReplaceWithNewInstruction( return ReplaceWithNewInstruction(
iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {})); iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
@ -2307,7 +2331,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
HloInstruction *lhs, *rhs; HloInstruction *lhs, *rhs;
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
if (IsAll(rhs, 0)) { if (IsAll(rhs, 0)) {
auto one = HloInstruction::CreateConstant( auto one = simplifier_->CreateConstantWithLayoutUpdated(
LiteralUtil::One(power->shape().element_type()).Clone()); LiteralUtil::One(power->shape().element_type()).Clone());
std::unique_ptr<HloInstruction> ones; std::unique_ptr<HloInstruction> ones;
if (ShapeUtil::IsScalar(power->shape())) { if (ShapeUtil::IsScalar(power->shape())) {
@ -2342,7 +2366,8 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) { if (IsAll(rhs, -1)) {
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( auto* one = computation_->AddInstruction(
simplifier_->CreateConstantWithLayoutUpdated(
LiteralUtil::One(rhs->shape().element_type()).Clone())); LiteralUtil::One(rhs->shape().element_type()).Clone()));
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit // Explicitly broadcast scalar 1 to the output shape, to avoid implicit
@ -2422,14 +2447,16 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
std::vector<HloInstruction*> new_operands; std::vector<HloInstruction*> new_operands;
new_operands.reserve(user->operand_count()); new_operands.reserve(user->operand_count());
Shape changed_shape;
for (HloInstruction* user_operand : user->operands()) { for (HloInstruction* user_operand : user->operands()) {
if (user_operand->opcode() == HloOpcode::kBroadcast && if (user_operand->opcode() == HloOpcode::kBroadcast &&
ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
changed_shape = ShapeUtil::ChangeElementType(
operand->shape(), user_operand->shape().element_type());
simplifier_->UpdateLayout(&changed_shape);
new_operands.push_back( new_operands.push_back(
computation_->AddInstruction(HloInstruction::CreateBroadcast( computation_->AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::ChangeElementType( changed_shape, user_operand->mutable_operand(0), {})));
operand->shape(), user_operand->shape().element_type()),
user_operand->mutable_operand(0), {})));
} else { } else {
CHECK_EQ(broadcast, user_operand); CHECK_EQ(broadcast, user_operand);
new_operands.push_back(operand); new_operands.push_back(operand);
@ -2438,11 +2465,11 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
VLOG(4) << "Sinking broadcast after user:"; VLOG(4) << "Sinking broadcast after user:";
VLOG(4) << " old broadcast: " << broadcast->ToString(); VLOG(4) << " old broadcast: " << broadcast->ToString();
VLOG(4) << " old user: " << user->ToString(); VLOG(4) << " old user: " << user->ToString();
HloInstruction* new_user = changed_shape = ShapeUtil::ChangeElementType(operand->shape(),
computation_->AddInstruction(user->CloneWithNewOperands( user->shape().element_type());
ShapeUtil::ChangeElementType(operand->shape(), simplifier_->UpdateLayout(&changed_shape);
user->shape().element_type()), HloInstruction* new_user = computation_->AddInstruction(
new_operands)); user->CloneWithNewOperands(changed_shape, new_operands));
VLOG(4) << " new user: " << new_user->ToString(); VLOG(4) << " new user: " << new_user->ToString();
HloInstruction* new_broadcast = HloInstruction* new_broadcast =
computation_->AddInstruction(HloInstruction::CreateBroadcast( computation_->AddInstruction(HloInstruction::CreateBroadcast(
@ -2456,8 +2483,9 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
namespace { namespace {
template <typename T> template <typename T>
std::unique_ptr<HloInstruction> TryRemainderToAnd(HloInstruction* remainder, std::unique_ptr<HloInstruction> TryRemainderToAnd(
HloComputation* computation) { HloInstruction* remainder, HloComputation* computation,
AlgebraicSimplifier* simplifier) {
HloInstruction *a, *b, *c; HloInstruction *a, *b, *c;
CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
@ -2487,8 +2515,8 @@ std::unique_ptr<HloInstruction> TryRemainderToAnd(HloInstruction* remainder,
a->shape(), HloOpcode::kSelect, dividend_is_negative, a->shape(), HloOpcode::kSelect, dividend_is_negative,
negated_dividend, a)); negated_dividend, a));
auto* mask_amount = auto* mask_amount = computation->AddInstruction(
computation->AddInstruction(HloInstruction::CreateConstant( simplifier->CreateConstantWithLayoutUpdated(
LiteralUtil::CreateR0<T>(b_value - 1))); LiteralUtil::CreateR0<T>(b_value - 1)));
if (!ShapeUtil::IsScalar(b->shape())) { if (!ShapeUtil::IsScalar(b->shape())) {
mask_amount = computation->AddInstruction( mask_amount = computation->AddInstruction(
@ -2509,8 +2537,8 @@ std::unique_ptr<HloInstruction> TryRemainderToAnd(HloInstruction* remainder,
} else { } else {
uint64 b_value = c->literal().GetFirstElement<T>(); uint64 b_value = c->literal().GetFirstElement<T>();
if (IsPowerOfTwo(b_value)) { if (IsPowerOfTwo(b_value)) {
HloInstruction* mask_amount = HloInstruction* mask_amount = computation->AddInstruction(
computation->AddInstruction(HloInstruction::CreateConstant( simplifier->CreateConstantWithLayoutUpdated(
LiteralUtil::CreateR0<T>(b_value - 1))); LiteralUtil::CreateR0<T>(b_value - 1)));
if (!ShapeUtil::IsScalar(b->shape())) { if (!ShapeUtil::IsScalar(b->shape())) {
mask_amount = computation->AddInstruction( mask_amount = computation->AddInstruction(
@ -2532,49 +2560,49 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
switch (remainder->shape().element_type()) { switch (remainder->shape().element_type()) {
case S8: case S8:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<int8>(remainder, computation_)) { TryRemainderToAnd<int8>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift)); return ReplaceWithNewInstruction(remainder, std::move(shift));
} }
break; break;
case S16: case S16:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<int16>(remainder, computation_)) { TryRemainderToAnd<int16>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift)); return ReplaceWithNewInstruction(remainder, std::move(shift));
} }
break; break;
case S32: case S32:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<int32>(remainder, computation_)) { TryRemainderToAnd<int32>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift)); return ReplaceWithNewInstruction(remainder, std::move(shift));
} }
break; break;
case S64: case S64:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<int64>(remainder, computation_)) { TryRemainderToAnd<int64>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift)); return ReplaceWithNewInstruction(remainder, std::move(shift));
} }
break; break;
case U8: case U8:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<uint8>(remainder, computation_)) { TryRemainderToAnd<uint8>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift)); return ReplaceWithNewInstruction(remainder, std::move(shift));
} }
break; break;
case U16: case U16:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<uint16>(remainder, computation_)) { TryRemainderToAnd<uint16>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift)); return ReplaceWithNewInstruction(remainder, std::move(shift));
} }
break; break;
case U32: case U32:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<uint32>(remainder, computation_)) { TryRemainderToAnd<uint32>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift)); return ReplaceWithNewInstruction(remainder, std::move(shift));
} }
break; break;
case U64: case U64:
if (std::unique_ptr<HloInstruction> shift = if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<uint64>(remainder, computation_)) { TryRemainderToAnd<uint64>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift)); return ReplaceWithNewInstruction(remainder, std::move(shift));
} }
break; break;
@ -2597,7 +2625,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
if (!LayoutUtil::HasLayout(reshaped_shape)) { if (!LayoutUtil::HasLayout(reshaped_shape)) {
LayoutUtil::SetToDefaultLayout(&reshaped_shape); LayoutUtil::SetToDefaultLayout(&reshaped_shape);
} }
auto empty_constant = HloInstruction::CreateConstant( auto empty_constant = simplifier_->CreateConstantWithLayoutUpdated(
Literal::CreateFromShape(reshaped_shape)); Literal::CreateFromShape(reshaped_shape));
return ReplaceWithNewInstruction(reshape, std::move(empty_constant)); return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
@ -2810,6 +2838,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
new_slice_limits), new_slice_limits),
new_slice_operand, new_slice_starts, new_slice_limits, new_slice_operand, new_slice_starts, new_slice_limits,
new_slice_stides)); new_slice_stides));
simplifier_->UpdateLayout(new_slice->mutable_shape());
TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
slice, HloInstruction::CreateReshape(slice->shape(), new_slice))); slice, HloInstruction::CreateReshape(slice->shape(), new_slice)));
return true; return true;
@ -2914,8 +2943,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
// representative. // representative.
auto arg = reduce->inputs()[0]; auto arg = reduce->inputs()[0];
auto init_value = reduce->init_values()[0]; auto init_value = reduce->init_values()[0];
const Shape& reduce_result_shape = Shape& reduce_result_shape = const_cast<Shape&>(
multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape(); multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape());
absl::Span<const int64> dimensions(reduce->dimensions()); absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply(); HloComputation* function = reduce->to_apply();
@ -3136,6 +3165,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
return !absl::c_linear_search(effective_reduce_dims, dim); return !absl::c_linear_search(effective_reduce_dims, dim);
}, },
reduce_window->shape()); reduce_window->shape());
simplifier_->UpdateLayout(&reduce_shape);
HloInstruction* reduce = HloInstruction* reduce =
computation_->AddInstruction(HloInstruction::CreateReduce( computation_->AddInstruction(HloInstruction::CreateReduce(
/*shape=*/reduce_shape, /*shape=*/reduce_shape,
@ -3261,11 +3291,11 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
HloInstruction* new_reduce_window_operand; HloInstruction* new_reduce_window_operand;
if (convert != nullptr) { if (convert != nullptr) {
new_reduce_window_operand = Shape changed_shape = ShapeUtil::ChangeElementType(
computation_->AddInstruction(HloInstruction::CreateConvert( pad_operand->shape(), convert->shape().element_type());
ShapeUtil::ChangeElementType(pad_operand->shape(), simplifier_->UpdateLayout(&changed_shape);
convert->shape().element_type()), new_reduce_window_operand = computation_->AddInstruction(
pad_operand)); HloInstruction::CreateConvert(changed_shape, pad_operand));
} else { } else {
new_reduce_window_operand = pad_operand; new_reduce_window_operand = pad_operand;
} }
@ -3614,15 +3644,18 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
// We already checked feature_dimension is most minor, so data in input_shape // We already checked feature_dimension is most minor, so data in input_shape
// and row-major {conv_width,input_channels} are bitwise identical. // and row-major {conv_width,input_channels} are bitwise identical.
const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout( Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
input_shape.element_type(), {conv_width, input_channels}); input_shape.element_type(), {conv_width, input_channels});
simplifier_->UpdateLayout(&new_input_shape);
// We already checked input_feature_dimension is more major than // We already checked input_feature_dimension is more major than
// output_feature_dimension, so data in filter_shape and row-major // output_feature_dimension, so data in filter_shape and row-major
// {input_channels,output_channels} are bitwise identical. // {input_channels,output_channels} are bitwise identical.
const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout( Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
filter_shape.element_type(), {input_channels, output_channels}); filter_shape.element_type(), {input_channels, output_channels});
const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout( simplifier_->UpdateLayout(&new_filter_shape);
Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
convolution_shape.element_type(), {conv_width, output_channels}); convolution_shape.element_type(), {conv_width, output_channels});
simplifier_->UpdateLayout(&dot_output_shape);
auto new_lhs = add_bitcast(new_input_shape, lhs); auto new_lhs = add_bitcast(new_input_shape, lhs);
auto new_rhs = add_bitcast(new_filter_shape, rhs); auto new_rhs = add_bitcast(new_filter_shape, rhs);
@ -3647,7 +3680,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
convolution, convolution,
HloInstruction::CreateBroadcast( HloInstruction::CreateBroadcast(
convolution->shape(), convolution->shape(),
computation_->AddInstruction(HloInstruction::CreateConstant( computation_->AddInstruction(
simplifier_->CreateConstantWithLayoutUpdated(
LiteralUtil::Zero(convolution->shape().element_type()))), LiteralUtil::Zero(convolution->shape().element_type()))),
{})); {}));
} }
@ -3731,7 +3765,7 @@ StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
"AlgebraicSimplifier::Run(), before:\n" + module->ToString()); "AlgebraicSimplifier::Run(), before:\n" + module->ToString());
bool changed = false; bool changed = false;
for (auto* comp : module->MakeNonfusionComputations()) { for (auto* comp : module->MakeNonfusionComputations()) {
if (AlgebraicSimplifierVisitor::Run(comp, options_)) { if (AlgebraicSimplifierVisitor::Run(comp, options_, this)) {
changed = true; changed = true;
} }
} }

View File

@ -105,6 +105,15 @@ class AlgebraicSimplifier : public HloModulePass {
// computation was changed. // computation was changed.
StatusOr<bool> Run(HloModule* module) override; StatusOr<bool> Run(HloModule* module) override;
// Create constant from literal with tiles and element size updated in the
// constant's layout.
std::unique_ptr<HloInstruction> CreateConstantWithLayoutUpdated(
Literal literal) {
auto constant = HloInstruction::CreateConstant(std::move(literal));
UpdateLayout(constant->mutable_shape());
return constant;
}
private: private:
AlgebraicSimplifierOptions options_; AlgebraicSimplifierOptions options_;
}; };

View File

@ -29,8 +29,11 @@ namespace xla {
class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
public: public:
explicit BFloat16ConversionFoldingVisitor( explicit BFloat16ConversionFoldingVisitor(
HloComputation* computation, const BFloat16Support* bfloat16_support) HloComputation* computation, const BFloat16Support* bfloat16_support,
: computation_(computation), bfloat16_support_(bfloat16_support) {} BFloat16ConversionFolding* bfloat16_conversion_folding)
: computation_(computation),
bfloat16_support_(bfloat16_support),
bfloat16_conversion_folding_(bfloat16_conversion_folding) {}
Status DefaultAction(HloInstruction* hlo) override; Status DefaultAction(HloInstruction* hlo) override;
@ -38,8 +41,10 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
Status HandleAllReduce(HloInstruction* crs) override; Status HandleAllReduce(HloInstruction* crs) override;
static bool Run(HloComputation* computation, static bool Run(HloComputation* computation,
const BFloat16Support* bfloat16_support) { const BFloat16Support* bfloat16_support,
BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support); BFloat16ConversionFolding* bfloat16_conversion_folding) {
BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support,
bfloat16_conversion_folding);
TF_CHECK_OK(computation->Accept(&visitor)); TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_; return visitor.changed_;
} }
@ -61,6 +66,7 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
HloComputation* computation_; HloComputation* computation_;
const BFloat16Support* bfloat16_support_; const BFloat16Support* bfloat16_support_;
BFloat16ConversionFolding* bfloat16_conversion_folding_;
bool changed_ = false; bool changed_ = false;
}; };
@ -68,6 +74,7 @@ Status BFloat16ConversionFoldingVisitor::FoldOutputConversions(
HloInstruction* hlo) { HloInstruction* hlo) {
std::vector<HloInstruction*> materialized_users = hlo->users(); std::vector<HloInstruction*> materialized_users = hlo->users();
hlo->mutable_shape()->set_element_type(BF16); hlo->mutable_shape()->set_element_type(BF16);
bfloat16_conversion_folding_->UpdateLayout(hlo->mutable_shape());
for (auto user : materialized_users) { for (auto user : materialized_users) {
CHECK_EQ(user->opcode(), HloOpcode::kConvert); CHECK_EQ(user->opcode(), HloOpcode::kConvert);
TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo)); TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
@ -228,6 +235,8 @@ Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) {
ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}) ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i})
->set_element_type(BF16); ->set_element_type(BF16);
bfloat16_conversion_folding_->UpdateLayout(
ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}));
for (auto gte : per_tuple_element_gtes[i]) { for (auto gte : per_tuple_element_gtes[i]) {
TF_RETURN_IF_ERROR(FoldOutputConversions(gte)); TF_RETURN_IF_ERROR(FoldOutputConversions(gte));
} }
@ -241,7 +250,7 @@ StatusOr<bool> BFloat16ConversionFolding::Run(HloModule* module) {
2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString()); 2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString());
bool changed = false; bool changed = false;
for (auto* comp : module->MakeNonfusionComputations()) { for (auto* comp : module->MakeNonfusionComputations()) {
if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) { if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_, this)) {
changed = true; changed = true;
} }
} }

View File

@ -29,15 +29,20 @@ namespace xla {
class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
public: public:
explicit BFloat16NormalizationVisitor(HloComputation* computation, explicit BFloat16NormalizationVisitor(
const BFloat16Support* bfloat16_support) HloComputation* computation, const BFloat16Support* bfloat16_support,
: computation_(computation), bfloat16_support_(bfloat16_support) {} BFloat16Normalization* bfloat16_normalization)
: computation_(computation),
bfloat16_support_(bfloat16_support),
bfloat16_normalization_(bfloat16_normalization) {}
Status DefaultAction(HloInstruction* hlo) override; Status DefaultAction(HloInstruction* hlo) override;
static bool Run(HloComputation* computation, static bool Run(HloComputation* computation,
const BFloat16Support* bfloat16_support) { const BFloat16Support* bfloat16_support,
BFloat16NormalizationVisitor visitor(computation, bfloat16_support); BFloat16Normalization* bfloat16_normalization) {
BFloat16NormalizationVisitor visitor(computation, bfloat16_support,
bfloat16_normalization);
TF_CHECK_OK(computation->Accept(&visitor)); TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_; return visitor.changed_;
} }
@ -73,6 +78,7 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
HloComputation* computation_; HloComputation* computation_;
const BFloat16Support* bfloat16_support_; const BFloat16Support* bfloat16_support_;
BFloat16Normalization* bfloat16_normalization_;
bool changed_ = false; bool changed_ = false;
}; };
@ -95,6 +101,7 @@ Status BFloat16NormalizationVisitor::InsertConvertAfterOutput(
computation->set_root_instruction(convert); computation->set_root_instruction(convert);
} }
convert->mutable_shape()->set_element_type(to); convert->mutable_shape()->set_element_type(to);
bfloat16_normalization_->UpdateLayout(convert->mutable_shape());
changed_ = true; changed_ = true;
return Status::OK(); return Status::OK();
} }
@ -103,6 +110,7 @@ Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack(
HloInstruction* hlo, PrimitiveType to, HloComputation* computation) { HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
auto original_type = hlo->shape().element_type(); auto original_type = hlo->shape().element_type();
hlo->mutable_shape()->set_element_type(to); hlo->mutable_shape()->set_element_type(to);
bfloat16_normalization_->UpdateLayout(hlo->mutable_shape());
return InsertConvertAfterOutput(hlo, original_type, computation); return InsertConvertAfterOutput(hlo, original_type, computation);
} }
@ -110,8 +118,10 @@ Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand(
HloInstruction* hlo, int64 operand_idx, PrimitiveType to, HloInstruction* hlo, int64 operand_idx, PrimitiveType to,
HloComputation* computation) { HloComputation* computation) {
auto operand = hlo->mutable_operand(operand_idx); auto operand = hlo->mutable_operand(operand_idx);
auto convert = computation->AddInstruction(HloInstruction::CreateConvert( auto shape = ShapeUtil::ChangeElementType(operand->shape(), to);
ShapeUtil::ChangeElementType(operand->shape(), to), operand)); bfloat16_normalization_->UpdateLayout(&shape);
auto convert = computation->AddInstruction(
HloInstruction::CreateConvert(shape, operand));
TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert)); TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert));
changed_ = true; changed_ = true;
return Status::OK(); return Status::OK();
@ -243,11 +253,13 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
continue; continue;
} }
subshape->set_element_type(F32); subshape->set_element_type(F32);
bfloat16_normalization_->UpdateLayout(subshape);
auto gte = computation_->AddInstruction( auto gte = computation_->AddInstruction(
HloInstruction::CreateGetTupleElement(*subshape, hlo, i)); HloInstruction::CreateGetTupleElement(*subshape, hlo, i));
auto shape = ShapeUtil::ChangeElementType(*subshape, BF16);
bfloat16_normalization_->UpdateLayout(&shape);
output_elements[i] = output_elements[i] =
computation_->AddInstruction(HloInstruction::CreateConvert( computation_->AddInstruction(HloInstruction::CreateConvert(shape, gte));
ShapeUtil::ChangeElementType(*subshape, BF16), gte));
} }
auto tuple = computation_->AddInstruction( auto tuple = computation_->AddInstruction(
HloInstruction::CreateTuple(output_elements)); HloInstruction::CreateTuple(output_elements));
@ -401,7 +413,7 @@ StatusOr<bool> BFloat16Normalization::Run(HloModule* module) {
2, "BFloat16Normalization::Run(), before:\n" + module->ToString()); 2, "BFloat16Normalization::Run(), before:\n" + module->ToString());
bool changed = false; bool changed = false;
for (auto* comp : module->MakeComputationPostOrder()) { for (auto* comp : module->MakeComputationPostOrder()) {
if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_)) { if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_, this)) {
changed = true; changed = true;
} }
} }

View File

@ -674,11 +674,13 @@ Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) {
if (hlo->opcode() != HloOpcode::kConstant) { if (hlo->opcode() != HloOpcode::kConstant) {
continue; continue;
} }
if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) { if (!Shape::Equal().MinorToMajorOnlyInLayout()(hlo->literal().shape(),
hlo->shape())) {
TF_ASSIGN_OR_RETURN(auto converted_literal, TF_ASSIGN_OR_RETURN(auto converted_literal,
hlo->literal().ConvertToShape(hlo->shape())); hlo->literal().ConvertToShape(hlo->shape()));
auto new_constant = computation->AddInstruction( auto new_constant = computation->AddInstruction(
HloInstruction::CreateConstant(std::move(converted_literal))); HloInstruction::CreateConstant(std::move(converted_literal)));
UpdateLayout(new_constant->mutable_shape());
TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
} }
} }
@ -797,6 +799,7 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
auto subshape = entry.first; auto subshape = entry.first;
CHECK_EQ(subshape->element_type(), F32); CHECK_EQ(subshape->element_type(), F32);
subshape->set_element_type(BF16); subshape->set_element_type(BF16);
UpdateLayout(subshape);
changed_ = true; changed_ = true;
} }
} }

View File

@ -716,7 +716,8 @@ ProgramShape HloComputation::ComputeProgramShape() const {
return program_shape; return program_shape;
} }
bool HloComputation::operator==(const HloComputation& other) const { bool HloComputation::Equal(const HloComputation& other,
bool is_layout_sensitive) const {
if (this == &other) { if (this == &other) {
return true; return true;
} }
@ -741,7 +742,8 @@ bool HloComputation::operator==(const HloComputation& other) const {
[](const HloInstruction*, const HloInstruction*) { return true; }, [](const HloInstruction*, const HloInstruction*) { return true; },
[](const HloComputation* a, const HloComputation* b) { [](const HloComputation* a, const HloComputation* b) {
return *a == *b; return *a == *b;
}); },
is_layout_sensitive);
if (!identical_ignoring_operands) { if (!identical_ignoring_operands) {
return false; return false;
} }

View File

@ -270,7 +270,12 @@ class HloComputation {
ProgramShape ComputeProgramShape() const; ProgramShape ComputeProgramShape() const;
// Return whether `*this` and `other` are functionally equivalent. // Return whether `*this` and `other` are functionally equivalent.
bool operator==(const HloComputation& other) const; bool Equal(const HloComputation& other, bool is_layout_sensitive) const;
// Return whether `*this` and `other` are functionally equivalent.
bool operator==(const HloComputation& other) const {
return Equal(other, true);
}
// Replaces old instruction with newly created instruction. Removes old // Replaces old instruction with newly created instruction. Removes old
// instruction from computation. Updates uses and root instruction. // instruction from computation. Updates uses and root instruction.

View File

@ -1804,8 +1804,9 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) {
<< "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString(); << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
// Out of convenience the literal may have been produced with a different // Out of convenience the literal may have been produced with a different
// layout. Relayout as indicated by the HLO instruction. // layout. Relayout as indicated by the HLO instruction.
if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(), if (!Layout::Equal().MinorToMajorOnly()(
hlo->shape())) { GetEvaluatedLiteralFor(hlo).shape().layout(),
hlo->shape().layout())) {
evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape()); evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
} }
return Status::OK(); return Status::OK();

View File

@ -303,6 +303,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_ASSIGN_OR_RETURN(auto literal, TF_ASSIGN_OR_RETURN(auto literal,
Literal::CreateFromProto(proto.literal())); Literal::CreateFromProto(proto.literal()));
instruction = CreateConstant(std::move(literal)); instruction = CreateConstant(std::move(literal));
// Literal's shape may have no/different tiling info.
CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(instruction->shape(),
shape));
*instruction->mutable_shape() = shape;
} else { } else {
instruction = absl::make_unique<HloConstantInstruction>(shape); instruction = absl::make_unique<HloConstantInstruction>(shape);
} }

View File

@ -1018,6 +1018,11 @@ HloConstantInstruction::HloConstantInstruction(Literal literal)
: HloInstruction(HloOpcode::kConstant, literal.shape()), : HloInstruction(HloOpcode::kConstant, literal.shape()),
literal_(std::move(literal)) {} literal_(std::move(literal)) {}
HloConstantInstruction::HloConstantInstruction(Literal literal,
const Shape& shape)
: HloInstruction(HloOpcode::kConstant, shape),
literal_(std::move(literal)) {}
HloConstantInstruction::HloConstantInstruction(const Shape& shape) HloConstantInstruction::HloConstantInstruction(const Shape& shape)
: HloInstruction(HloOpcode::kConstant, shape) {} : HloInstruction(HloOpcode::kConstant, shape) {}
@ -1063,7 +1068,12 @@ HloConstantInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands, const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const { HloCloneContext* context) const {
CHECK(literal_.has_value()); CHECK(literal_.has_value());
return absl::make_unique<HloConstantInstruction>(literal_->Clone()); // Literal's shape may have no/different tiling info. Use this instruction's
// shape instead.
CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(literal_->shape(),
this->shape()));
return absl::make_unique<HloConstantInstruction>(literal_->Clone(),
this->shape());
} }
string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(

View File

@ -650,6 +650,7 @@ class HloSliceInstruction : public HloInstruction {
class HloConstantInstruction : public HloInstruction { class HloConstantInstruction : public HloInstruction {
public: public:
explicit HloConstantInstruction(Literal literal); explicit HloConstantInstruction(Literal literal);
explicit HloConstantInstruction(Literal literal, const Shape& shape);
// Used when the literal is too large and dropped. // Used when the literal is too large and dropped.
explicit HloConstantInstruction(const Shape& shape); explicit HloConstantInstruction(const Shape& shape);
// Returns the literal associated with this instruction. // Returns the literal associated with this instruction.

View File

@ -56,6 +56,14 @@ class HloModulePass : public HloPassInterface {
} }
return changed; return changed;
}; };
// Update the layout of a Shape to one that is supported by a given backend.
// One can call this function after modifying the Shape in case that modifying
// the Shape requires changes to the layout for the given Backend.
//
// TODO(b/129084868): Make this Backend dependent instead of requiring
// deriving from the pass the and overriding this function.
virtual void UpdateLayout(Shape* shape) {}
}; };
// Base class for passes which are module-group scoped. These passes cannot run // Base class for passes which are module-group scoped. These passes cannot run

View File

@ -343,7 +343,7 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) {
// Check that the 'compare' computation returns a PRED. // Check that the 'compare' computation returns a PRED.
Shape compare_shape = compare->root_instruction()->shape(); Shape compare_shape = compare->root_instruction()->shape();
if (!ShapesSame(compare_shape, ShapeUtil::MakeShape(PRED, {}))) { if (!ShapeUtil::Compatible(compare_shape, ShapeUtil::MakeShape(PRED, {}))) {
return InternalError( return InternalError(
"The Sort compare computation shape does not lead to a scalar " "The Sort compare computation shape does not lead to a scalar "
"predicate shape: %s", "predicate shape: %s",
@ -393,7 +393,8 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
return InternalError("Constant is required to have a valid literal: %s", return InternalError("Constant is required to have a valid literal: %s",
constant->ToString()); constant->ToString());
} }
return CheckShape(constant, constant->literal().shape()); return CheckShape(constant, constant->literal().shape(),
/*only_compare_minor_to_major_in_layout=*/true);
} }
Status ShapeVerifier::HandleIota(HloInstruction* instruction) { Status ShapeVerifier::HandleIota(HloInstruction* instruction) {
@ -654,7 +655,8 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0)); CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
const Shape& conditional_shape = const Shape& conditional_shape =
xla_while->while_condition()->root_instruction()->shape(); xla_while->while_condition()->root_instruction()->shape();
if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) { if (!ShapeUtil::Compatible(conditional_shape,
ShapeUtil::MakeShape(PRED, {}))) {
return InternalError( return InternalError(
"Conditional computation shape does not lead to a scalar predicate " "Conditional computation shape does not lead to a scalar predicate "
"shape: %s", "shape: %s",
@ -696,7 +698,8 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) {
return CheckShape(send, return CheckShape(send,
ShapeUtil::MakeTupleShape({send->operand(0)->shape(), ShapeUtil::MakeTupleShape({send->operand(0)->shape(),
ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {}),
ShapeUtil::MakeTokenShape()})); ShapeUtil::MakeTokenShape()}),
/*only_compare_minor_to_major_in_layout=*/true);
} }
Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
@ -705,9 +708,11 @@ Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
Status ShapeVerifier::HandleRecv(HloInstruction* recv) { Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
return CheckShape( return CheckShape(
recv, ShapeUtil::MakeTupleShape( recv,
ShapeUtil::MakeTupleShape(
{ShapeUtil::GetTupleElementShape(recv->shape(), 0), {ShapeUtil::GetTupleElementShape(recv->shape(), 0),
ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})); ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}),
/*only_compare_minor_to_major_in_layout=*/true);
} }
Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
@ -844,7 +849,8 @@ Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) {
} }
Status ShapeVerifier::CheckShape(const HloInstruction* instruction, Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
const Shape& inferred_shape) { const Shape& inferred_shape,
bool only_compare_minor_to_major_in_layout) {
// If allow_mixed_precision_ is false, check if there are operands with // If allow_mixed_precision_ is false, check if there are operands with
// different precisions. We need this check because ShapeInference allows // different precisions. We need this check because ShapeInference allows
// mixed precision inputs. // mixed precision inputs.
@ -878,7 +884,8 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
case HloOpcode::kTuple: case HloOpcode::kTuple:
case HloOpcode::kTupleSelect: case HloOpcode::kTupleSelect:
case HloOpcode::kWhile: case HloOpcode::kWhile:
return ShapesSame(instruction->shape(), inferred_shape); return ShapesSame(instruction->shape(), inferred_shape,
only_compare_minor_to_major_in_layout);
// We allow arbitrary layout and f32->bf16 transformations on all other // We allow arbitrary layout and f32->bf16 transformations on all other
// instructions, although this may be made more strict pending discussion // instructions, although this may be made more strict pending discussion

View File

@ -106,7 +106,8 @@ class ShapeVerifier : public DfsHloVisitor {
// Check the instruction's shape against the shape given by ShapeInference // Check the instruction's shape against the shape given by ShapeInference
// and return an appropriate error if there is a mismatch. // and return an appropriate error if there is a mismatch.
Status CheckShape(const HloInstruction* instruction, Status CheckShape(const HloInstruction* instruction,
const Shape& inferred_shape); const Shape& inferred_shape,
bool only_compare_minor_to_major_in_layout = false);
// Overload which takes a StatusOr to reduce boilerplate in the caller. // Overload which takes a StatusOr to reduce boilerplate in the caller.
Status CheckShape(const HloInstruction* instruction, Status CheckShape(const HloInstruction* instruction,
@ -120,14 +121,31 @@ class ShapeVerifier : public DfsHloVisitor {
private: private:
// Helpers that switch on layout_sensitive_. // Helpers that switch on layout_sensitive_.
bool ShapesSame(const Shape& a, const Shape& b) { bool ShapesSame(const Shape& a, const Shape& b,
return layout_sensitive_ ? ShapeUtil::Equal(a, b) bool minor_to_major_only = false) {
: ShapeUtil::Compatible(a, b); if (!layout_sensitive_) {
return ShapeUtil::Compatible(a, b);
} }
bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) { Shape::Equal equal;
return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b) if (minor_to_major_only) {
: ShapeUtil::CompatibleIgnoringFpPrecision(a, b); equal.MinorToMajorOnlyInLayout();
} }
return equal(a, b);
}
bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b,
bool minor_to_major_only = false) {
if (!layout_sensitive_) {
return ShapeUtil::CompatibleIgnoringFpPrecision(a, b);
}
Shape::Equal equal;
if (minor_to_major_only) {
equal.MinorToMajorOnlyInLayout();
}
equal.IgnoreFpPrecision();
return equal(a, b);
}
string StringifyShape(const Shape& s) { string StringifyShape(const Shape& s) {
return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s) return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s)
: ShapeUtil::HumanString(s); : ShapeUtil::HumanString(s);

View File

@ -173,7 +173,7 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
auto iter = buffer_constraints_.find(&buffer); auto iter = buffer_constraints_.find(&buffer);
if (iter != buffer_constraints_.end()) { if (iter != buffer_constraints_.end()) {
const BufferLayoutConstraint& curr_constraint = iter->second; const BufferLayoutConstraint& curr_constraint = iter->second;
if (LayoutUtil::Equal(curr_constraint.layout(), layout)) { if (Layout::Equal().MinorToMajorOnly()(curr_constraint.layout(), layout)) {
// New constraint matches existing constraint. Nothing to do. // New constraint matches existing constraint. Nothing to do.
return Status::OK(); return Status::OK();
} }
@ -210,7 +210,7 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
GetOperandLayoutConstraint(instruction, operand_no); GetOperandLayoutConstraint(instruction, operand_no);
if (curr_shape_layout != nullptr) { if (curr_shape_layout != nullptr) {
if (curr_shape_layout->shape_layout().MatchesLayoutInShape( if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
shape_with_layout)) { shape_with_layout, /*minor_to_major_only=*/true)) {
// New constraint matches existing constraint. Nothing to do. // New constraint matches existing constraint. Nothing to do.
return Status::OK(); return Status::OK();
} }
@ -269,7 +269,8 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
const ShapeLayout* curr_shape_layout = ResultLayout(); const ShapeLayout* curr_shape_layout = ResultLayout();
if (curr_shape_layout != nullptr) { if (curr_shape_layout != nullptr) {
if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) { if (!curr_shape_layout->MatchesLayoutInShape(
shape_with_layout, /*minor_to_major_only=*/true)) {
return FailedPrecondition( return FailedPrecondition(
"Result of computation %s already has the layout constraint %s, " "Result of computation %s already has the layout constraint %s, "
"cannot add incompatible constraint %s", "cannot add incompatible constraint %s",
@ -647,6 +648,10 @@ Status LayoutAssignment::AddMandatoryConstraints(
namespace { namespace {
bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) {
return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout());
}
// The operands of a call must match the layouts of parameters in the // The operands of a call must match the layouts of parameters in the
// ComputationLayout, and the call instruction itself must match the result // ComputationLayout, and the call instruction itself must match the result
// layout in the ComputationLayout. // layout in the ComputationLayout.
@ -656,10 +661,10 @@ Status CheckCallLayout(HloInstruction* call,
TF_RET_CHECK(computation->num_parameters() == call->operand_count()); TF_RET_CHECK(computation->num_parameters() == call->operand_count());
for (int64 i = 0; i < computation->num_parameters(); ++i) { for (int64 i = 0; i < computation->num_parameters(); ++i) {
TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape( TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
call->operand(i)->shape())); call->operand(i)->shape(), /*minor_to_major_only=*/true));
} }
TF_RET_CHECK( TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape(
computation_layout.result_layout().MatchesLayoutInShape(call->shape())); call->shape(), /*minor_to_major_only=*/true));
return Status::OK(); return Status::OK();
} }
@ -670,8 +675,8 @@ Status CheckCustomCallLayout(HloInstruction* instruction) {
const HloCustomCallInstruction* custom_call = const HloCustomCallInstruction* custom_call =
DynCast<HloCustomCallInstruction>(instruction); DynCast<HloCustomCallInstruction>(instruction);
for (int64 i = 0; i < custom_call->operand_count(); ++i) { for (int64 i = 0; i < custom_call->operand_count(); ++i) {
TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( TF_RET_CHECK(
custom_call->operand(i)->shape(), LayoutsInShapesEqual(custom_call->operand(i)->shape(),
custom_call->operand_shapes_with_layout()[i])); custom_call->operand_shapes_with_layout()[i]));
} }
} }
@ -690,13 +695,12 @@ Status CheckWhileLayout(HloInstruction* while_inst,
auto init_shape = while_inst->operand(0)->shape(); auto init_shape = while_inst->operand(0)->shape();
TF_RET_CHECK( TF_RET_CHECK(
condition_computation_layout.parameter_layout(0).MatchesLayoutInShape( condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
init_shape)); init_shape, /*minor_to_major_only=*/true));
TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape( TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
init_shape)); init_shape, /*minor_to_major_only=*/true));
TF_RET_CHECK( TF_RET_CHECK(body_computation_layout.result_layout().MatchesLayoutInShape(
body_computation_layout.result_layout().MatchesLayoutInShape(init_shape)); init_shape, /*minor_to_major_only=*/true));
TF_RET_CHECK( TF_RET_CHECK(LayoutsInShapesEqual(init_shape, while_inst->shape()));
LayoutUtil::LayoutsInShapesEqual(init_shape, while_inst->shape()));
return Status::OK(); return Status::OK();
} }
@ -709,13 +713,14 @@ Status CheckConditionalLayout(
branch_computation_layouts[j].result_layout()); branch_computation_layouts[j].result_layout());
TF_RET_CHECK( TF_RET_CHECK(
branch_computation_layouts[j].result_layout().MatchesLayoutInShape( branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
instruction->shape())); instruction->shape(), /*minor_to_major_only=*/true));
TF_RET_CHECK( TF_RET_CHECK(
branch_computation_layouts[j].result_layout().MatchesLayoutInShape( branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
instruction->branch_computation(j)->root_instruction()->shape())); instruction->branch_computation(j)->root_instruction()->shape(),
/*minor_to_major_only=*/true));
TF_RET_CHECK( TF_RET_CHECK(
branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape( branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape(
branch_operand->shape())); branch_operand->shape(), /*minor_to_major_only=*/true));
} }
return Status::OK(); return Status::OK();
} }
@ -726,11 +731,11 @@ Status CheckConditionalLayout(
Status CheckFusionLayout(HloInstruction* fusion) { Status CheckFusionLayout(HloInstruction* fusion) {
TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode()); TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( TF_RET_CHECK(LayoutsInShapesEqual(fusion->shape(),
fusion->shape(), fusion->fused_expression_root()->shape())); fusion->fused_expression_root()->shape()));
for (int64 i = 0; i < fusion->operand_count(); ++i) { for (int64 i = 0; i < fusion->operand_count(); ++i) {
TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( TF_RET_CHECK(LayoutsInShapesEqual(fusion->fused_parameter(i)->shape(),
fusion->fused_parameter(i)->shape(), fusion->operand(i)->shape())); fusion->operand(i)->shape()));
} }
return Status::OK(); return Status::OK();
} }
@ -742,7 +747,8 @@ Status CheckParameterLayout(HloInstruction* parameter,
const ShapeLayout& parameter_layout = const ShapeLayout& parameter_layout =
computation_layout.parameter_layout(parameter->parameter_number()); computation_layout.parameter_layout(parameter->parameter_number());
if (parameter_layout.LayoutIsSet() && if (parameter_layout.LayoutIsSet() &&
!parameter_layout.MatchesLayoutInShape(parameter->shape())) { !parameter_layout.MatchesLayoutInShape(parameter->shape(),
/*minor_to_major_only=*/true)) {
return InternalError( return InternalError(
"parameter instruction %s does not match layout of computation " "parameter instruction %s does not match layout of computation "
"shape: %s", "shape: %s",
@ -753,8 +759,7 @@ Status CheckParameterLayout(HloInstruction* parameter,
// The layout of a constant instruction must match the layout of its literal. // The layout of a constant instruction must match the layout of its literal.
Status CheckConstantLayout(HloInstruction* constant) { Status CheckConstantLayout(HloInstruction* constant) {
if (!LayoutUtil::LayoutsInShapesEqual(constant->literal().shape(), if (!LayoutsInShapesEqual(constant->literal().shape(), constant->shape())) {
constant->shape())) {
return InternalError( return InternalError(
"constant instruction %s does not match the layout of its literal %s", "constant instruction %s does not match the layout of its literal %s",
constant->ToString(), constant->ToString(),
@ -785,7 +790,8 @@ StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
HloInstruction* gte = instruction->parent()->AddInstruction( HloInstruction* gte = instruction->parent()->AddInstruction(
HloInstruction::CreateGetTupleElement(instr_shape, instruction, i)); HloInstruction::CreateGetTupleElement(instr_shape, instruction, i));
if (ShapeUtil::Equal(target_shape, instr_shape)) { if (Shape::Equal().MinorToMajorOnlyInLayout()(target_shape,
instr_shape)) {
// Shapes and layouts are equal, no need to copy. // Shapes and layouts are equal, no need to copy.
element_copies.push_back(gte); element_copies.push_back(gte);
} else { } else {
@ -831,7 +837,8 @@ Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
TF_RET_CHECK(operand_layout.LayoutIsSet()); TF_RET_CHECK(operand_layout.LayoutIsSet());
TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { if (Shape::Equal().MinorToMajorOnlyInLayout()(operand_layout.shape(),
operand->shape())) {
VLOG(5) << "Operand " << operand->ToString() << " layout matches in " VLOG(5) << "Operand " << operand->ToString() << " layout matches in "
<< instruction->ToString(); << instruction->ToString();
// Operand layout already matches our constraint. Nothing to do. // Operand layout already matches our constraint. Nothing to do.
@ -892,7 +899,8 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
const Shape& instruction_subshape = const Shape& instruction_subshape =
ShapeUtil::GetSubshape(instruction->shape(), index); ShapeUtil::GetSubshape(instruction->shape(), index);
for (const LogicalBuffer* buffer : buffers) { for (const LogicalBuffer* buffer : buffers) {
if (!ShapeUtil::Equal(instruction_subshape, buffer->shape())) { if (!Shape::Equal().MinorToMajorOnlyInLayout()(
instruction_subshape, buffer->shape())) {
return InternalError( return InternalError(
"Layout of instruction %s at index {%s} does not match " "Layout of instruction %s at index {%s} does not match "
"source LogicalBuffer %s: %s vs %s", "source LogicalBuffer %s: %s vs %s",
@ -954,8 +962,8 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
FindOrDie(computation_layouts_, module->entry_computation()) FindOrDie(computation_layouts_, module->entry_computation())
.result_layout(); .result_layout();
if (result_layout.LayoutIsSet()) { if (result_layout.LayoutIsSet()) {
TF_RET_CHECK( TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
ShapeUtil::Equal(module->result_shape(), result_layout.shape())); module->result_shape(), result_layout.shape()));
} }
return Status::OK(); return Status::OK();
} }
@ -1510,8 +1518,8 @@ StatusOr<Layout> InferArrayLayout(
if (first_buffer_layout == nullptr) { if (first_buffer_layout == nullptr) {
first_buffer_layout = &source_buffer->shape().layout(); first_buffer_layout = &source_buffer->shape().layout();
} else if (!LayoutUtil::Equal(source_buffer->shape().layout(), } else if (!Layout::Equal().MinorToMajorOnly()(
*first_buffer_layout)) { source_buffer->shape().layout(), *first_buffer_layout)) {
// The points-to set is ambiguous for this index and the different source // The points-to set is ambiguous for this index and the different source
// buffers have different layouts. This case is possible in valid XLA // buffers have different layouts. This case is possible in valid XLA
// computations because we do not propagate BufferLayoutConstraints to all // computations because we do not propagate BufferLayoutConstraints to all
@ -1789,7 +1797,8 @@ Status LayoutAssignment::RunOnComputation(
// layout constraint. // layout constraint.
if (constraints.ResultLayout() != nullptr && if (constraints.ResultLayout() != nullptr &&
!constraints.ResultLayout()->MatchesLayoutInShape( !constraints.ResultLayout()->MatchesLayoutInShape(
computation->root_instruction()->shape())) { computation->root_instruction()->shape(),
/*minor_to_major_only=*/true)) {
if (conditional_mismatch_.count(computation) > 0) { if (conditional_mismatch_.count(computation) > 0) {
*FindOrDie(computation_layouts_, computation).mutable_result_layout() = *FindOrDie(computation_layouts_, computation).mutable_result_layout() =
FindOrDie(conditional_mismatch_, computation).result_layout(); FindOrDie(conditional_mismatch_, computation).result_layout();
@ -1907,7 +1916,9 @@ Status LayoutAssignment::PropagateComputationLayouts(
<< ": " << computed_computation_layout.result_layout().ToString(); << ": " << computed_computation_layout.result_layout().ToString();
*result_layout = computed_computation_layout.result_layout(); *result_layout = computed_computation_layout.result_layout();
} else { } else {
TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout); TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
computed_computation_layout.result_layout().shape(),
result_layout->shape()));
} }
return Status::OK(); return Status::OK();
} }

View File

@ -141,6 +141,11 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
} }
} }
if (!ShapeUtil::SameDimensions(lhs, rhs)) {
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
return false;
}
if (!ignore_layout_) { if (!ignore_layout_) {
if (lhs.layout().format() != rhs.layout().format()) { if (lhs.layout().format() != rhs.layout().format()) {
VLOG(3) << "CompareShapes: lhs layout format != rhs layout format"; VLOG(3) << "CompareShapes: lhs layout format != rhs layout format";
@ -161,11 +166,6 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
} }
} }
if (!ShapeUtil::SameDimensions(lhs, rhs)) {
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
return false;
}
if (!ignore_dynamic_dimension_) { if (!ignore_dynamic_dimension_) {
for (int i = 0; i < lhs.rank(); ++i) { for (int i = 0; i < lhs.rank(); ++i) {
if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) { if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) {

View File

@ -169,6 +169,11 @@ class Shape {
ignore_element_size_in_layout_ = true; ignore_element_size_in_layout_ = true;
return *this; return *this;
} }
Equal& MinorToMajorOnlyInLayout() {
ignore_tiles_in_layout_ = true;
ignore_element_size_in_layout_ = true;
return *this;
}
Equal& IgnoreElementType() { Equal& IgnoreElementType() {
ignore_element_type_ = true; ignore_element_type_ = true;
return *this; return *this;

View File

@ -46,8 +46,13 @@ void ShapeLayout::SetToDefaultLayout() {
LayoutUtil::SetToDefaultLayout(&shape_); LayoutUtil::SetToDefaultLayout(&shape_);
} }
bool ShapeLayout::MatchesLayoutInShape(const Shape& shape) const { bool ShapeLayout::MatchesLayoutInShape(const Shape& shape,
return ShapeUtil::Equal(shape, shape_); bool minor_to_major_only) const {
auto equal = Shape::Equal();
if (minor_to_major_only) {
equal.MinorToMajorOnlyInLayout();
}
return equal(shape, shape_);
} }
const Layout& ShapeLayout::layout() const { const Layout& ShapeLayout::layout() const {

View File

@ -45,7 +45,8 @@ class ShapeLayout {
// Returns true if the Layouts in this ShapeLayout match the layouts in the // Returns true if the Layouts in this ShapeLayout match the layouts in the
// given shape. Returns false otherwise. If the given shape is not compatible // given shape. Returns false otherwise. If the given shape is not compatible
// with the ShapeLayout's shape, then false is returned. // with the ShapeLayout's shape, then false is returned.
bool MatchesLayoutInShape(const Shape& shape) const; bool MatchesLayoutInShape(const Shape& shape,
bool minor_to_major_only = false) const;
// Copies the layout from the given shape into this ShapeLayout. 'other_shape' // Copies the layout from the given shape into this ShapeLayout. 'other_shape'
// must be compatible with the ShapeLayout's shape. // must be compatible with the ShapeLayout's shape.

View File

@ -102,6 +102,11 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
} }
TF_ASSIGN_OR_RETURN(Shape shape, TF_ASSIGN_OR_RETURN(Shape shape,
ShapeUtil::MakeValidatedShape(element_type, dimensions)); ShapeUtil::MakeValidatedShape(element_type, dimensions));
if (element_size_in_bits ==
ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) {
// Only set element_size_in_bits if it's different from the default value.
element_size_in_bits = 0;
}
*shape.mutable_layout() = *shape.mutable_layout() =
LayoutUtil::MakeLayout(minor_to_major, tiles, element_size_in_bits); LayoutUtil::MakeLayout(minor_to_major, tiles, element_size_in_bits);
if (!shape.has_layout()) { if (!shape.has_layout()) {
@ -219,7 +224,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
for (int i = 0; i < shape.dimensions_size(); ++i) { for (int i = 0; i < shape.dimensions_size(); ++i) {
dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i)); dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
} }
return MakeShapeWithDescendingLayout(shape.element_type(), dims); Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims);
// Since the physical layout is kept the same, the tiles and element size are
// the same also.
*new_shape.mutable_layout()->mutable_tiles() = shape.layout().tiles();
new_shape.mutable_layout()->set_element_size_in_bits(
shape.layout().element_size_in_bits());
return new_shape;
} }
/* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type, /* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type,

View File

@ -395,6 +395,8 @@ class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest {
ParametricDotTestWithoutLayoutAssignment() { ParametricDotTestWithoutLayoutAssignment() {
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
"layout-assignment"); "layout-assignment");
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
"tiling-assignment");
// Disable algebraic simplification because the pass may replace a dot // Disable algebraic simplification because the pass may replace a dot
// instruction with a layout-changing multiplication instruction. // instruction with a layout-changing multiplication instruction.
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(

View File

@ -846,16 +846,16 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) {
xla::ProgramShape xla_program_shape = xla::ProgramShape xla_program_shape =
XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes)); XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes));
EXPECT_TRUE(xla::LayoutUtil::Equal( EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(), xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0}) xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
.layout())); .layout()));
EXPECT_TRUE(xla::LayoutUtil::Equal( EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(), xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(),
xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1}) xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1})
.layout())); .layout()));
EXPECT_TRUE(xla::LayoutUtil::Equal(program_shape.result().layout(), EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
xla_program_shape.result().layout())); program_shape.result().layout(), xla_program_shape.result().layout()));
} }
TEST(RawApiTest, DotGeneralWithLayoutTest) { TEST(RawApiTest, DotGeneralWithLayoutTest) {