Populate tiling info in Layout data.

PiperOrigin-RevId: 240596416
This commit is contained in:
A. Unique TensorFlower 2019-03-27 10:51:52 -07:00 committed by TensorFlower Gardener
parent c510b79d5c
commit f366225ec6
24 changed files with 336 additions and 167 deletions

View File

@ -96,8 +96,13 @@ string Layout::ToString() const {
}
bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) {
if (lhs.format() != rhs.format() ||
lhs.minor_to_major() != rhs.minor_to_major() ||
if (lhs.format() != rhs.format()) {
return false;
}
if (lhs.format() == DENSE && lhs.minor_to_major() != rhs.minor_to_major()) {
return false;
}
if (lhs.format() == SPARSE &&
lhs.max_sparse_elements() != rhs.max_sparse_elements()) {
return false;
}

View File

@ -127,6 +127,12 @@ class Layout {
return *this;
}
// Configures this comparison to ignore tiles and element size by setting both
// ignore flags, and returns *this so that option calls can be chained.
Equal& MinorToMajorOnly() {
ignore_tiles_ = true;
ignore_element_size_ = true;
return *this;
}
private:
bool ignore_tiles_ = false;
bool ignore_element_size_ = false;

View File

@ -250,12 +250,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Runs the visitor on a computation.
static bool Run(HloComputation* computation,
const AlgebraicSimplifierOptions& options);
const AlgebraicSimplifierOptions& options,
AlgebraicSimplifier* simplifier);
private:
explicit AlgebraicSimplifierVisitor(HloComputation* computation,
const AlgebraicSimplifierOptions& options)
: computation_(computation), options_(options) {}
const AlgebraicSimplifierOptions& options,
AlgebraicSimplifier* simplifier)
: computation_(computation), options_(options), simplifier_(simplifier) {}
// Transforms Dots where at least one input is a vector or has a degenerate
// dimension and converts it into a multiply and reduce. This should enable
@ -274,10 +276,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
if (hlo->shape().rank() == 1) {
return hlo;
}
return computation_->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(hlo->shape().element_type(),
{ShapeUtil::ElementsIn(hlo->shape())}),
hlo));
auto hlo_instruction =
computation_->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(hlo->shape().element_type(),
{ShapeUtil::ElementsIn(hlo->shape())}),
hlo));
simplifier_->UpdateLayout(hlo_instruction->mutable_shape());
return hlo_instruction;
}
// Converts to primitive type if the input hlo is not that type, otherwise
@ -287,8 +292,11 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
if (hlo->shape().element_type() == element_type) {
return hlo;
}
return computation_->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo));
Shape changed_shape =
ShapeUtil::ChangeElementType(hlo->shape(), element_type);
simplifier_->UpdateLayout(&changed_shape);
return computation_->AddInstruction(
HloInstruction::CreateConvert(changed_shape, hlo));
}
// Transposes a dot operand such that the batch dimensions are the most major,
@ -312,13 +320,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Helper method to perform an add reduction of `hlo` over the dimensions in
// `dims`. The init value is a zero constant of hlo's element type, and the
// shape passed to the reduce is hlo's shape filtered with a predicate that is
// true for dimensions not listed in `dims`.
HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64> dims) {
HloInstruction* zero =
computation_->AddInstruction(HloInstruction::CreateConstant(
HloInstruction* zero = computation_->AddInstruction(
simplifier_->CreateConstantWithLayoutUpdated(
LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::FilterDimensions(
[&](int64 dim) { return !absl::c_linear_search(dims, dim); },
hlo->shape());
simplifier_->UpdateLayout(&shape);
return computation_->AddInstruction(HloInstruction::CreateReduce(
shape, hlo, zero, dims, AddReduce_computation));
}
@ -403,6 +412,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloComputation::Builder b("scalar_add_computation");
Shape shape = ShapeUtil::MakeShape(F32, {});
simplifier_->UpdateLayout(&shape);
auto scalar_lhs = b.AddInstruction(
HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
auto scalar_rhs = b.AddInstruction(
@ -440,13 +450,16 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Cached computation for adding two scalar F32.
HloComputation* scalar_add_computation_ = nullptr;
AlgebraicSimplifier* simplifier_ = nullptr;
};
} // namespace
// Constructs a visitor over `computation`, traverses the computation with it,
// and returns the visitor's changed_ flag. The traversal status is checked
// with TF_CHECK_OK rather than being returned to the caller.
bool AlgebraicSimplifierVisitor::Run(
HloComputation* computation, const AlgebraicSimplifierOptions& options) {
AlgebraicSimplifierVisitor visitor(computation, options);
bool AlgebraicSimplifierVisitor::Run(HloComputation* computation,
const AlgebraicSimplifierOptions& options,
AlgebraicSimplifier* simplifier) {
AlgebraicSimplifierVisitor visitor(computation, options, simplifier);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
@ -713,6 +726,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
new_slice_shape.set_dimensions(
concatenate_dimension,
slice_end - operands[i]->slice_starts(concatenate_dimension));
simplifier_->UpdateLayout(&new_slice_shape);
auto new_limit_indices = operands[i]->slice_limits();
new_limit_indices[concatenate_dimension] = slice_end;
auto new_slice_op =
@ -775,18 +789,19 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
}
// Recursively rebuilds `literal` as instructions added to `computation`. A
// tuple-shaped literal becomes a Tuple instruction whose operands are built
// from each element slice of the literal; any other literal becomes a single
// constant instruction created from a clone of the literal.
static HloInstruction* BuildTupleConstant(HloComputation* computation,
const LiteralSlice& literal) {
const LiteralSlice& literal,
AlgebraicSimplifier* simplifier) {
if (literal.shape().IsTuple()) {
std::vector<HloInstruction*> elems;
elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
elems.push_back(
BuildTupleConstant(computation, LiteralSlice(literal, {i})));
elems.push_back(BuildTupleConstant(
computation, LiteralSlice(literal, {i}), simplifier));
}
return computation->AddInstruction(HloInstruction::CreateTuple(elems));
} else {
return computation->AddInstruction(
HloInstruction::CreateConstant(literal.Clone()));
simplifier->CreateConstantWithLayoutUpdated(literal.Clone()));
}
}
@ -795,7 +810,8 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
// explicit Tuple instructions.
if (constant->shape().IsTuple()) {
return ReplaceInstruction(
constant, BuildTupleConstant(computation_, constant->literal()));
constant,
BuildTupleConstant(computation_, constant->literal(), simplifier_));
}
if (constant->shape().element_type() == TOKEN) {
@ -808,7 +824,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
Literal unique_scalar(
LiteralUtil::GetFirstScalarLiteral(constant->literal()));
HloInstruction* scalar = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(unique_scalar)));
simplifier_->CreateConstantWithLayoutUpdated(std::move(unique_scalar)));
return ReplaceWithNewInstruction(
constant,
HloInstruction::CreateBroadcast(constant->shape(), scalar, {}));
@ -854,8 +870,9 @@ Status InvertConstant(const HloInstruction& constant, Literal* result) {
}
template <typename T>
std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
HloComputation* computation) {
std::unique_ptr<HloInstruction> TryDivideToShift(
HloInstruction* divide, HloComputation* computation,
AlgebraicSimplifier* simplifier) {
HloInstruction *a, *b, *c;
CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
@ -872,10 +889,11 @@ std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
HloInstruction* zero_like_a = BroadcastZeros(
computation, a->shape().element_type(), a->shape().dimensions());
Shape changed_shape = ShapeUtil::ChangeElementType(a->shape(), PRED);
simplifier->UpdateLayout(&changed_shape);
auto* dividend_is_negative =
computation->AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a,
ComparisonDirection::kLt));
changed_shape, a, zero_like_a, ComparisonDirection::kLt));
auto* negated_dividend = computation->AddInstruction(
HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
@ -887,8 +905,8 @@ std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
int log2_abs_b_value = tensorflow::Log2Floor64(b_value);
auto* shift_amount =
computation->AddInstruction(HloInstruction::CreateConstant(
auto* shift_amount = computation->AddInstruction(
simplifier->CreateConstantWithLayoutUpdated(
LiteralUtil::CreateR0<T>(log2_abs_b_value)));
if (!ShapeUtil::IsScalar(b->shape())) {
shift_amount = computation->AddInstruction(
@ -911,8 +929,8 @@ std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
uint64 b_value = c->literal().GetFirstElement<T>();
if (IsPowerOfTwo(b_value)) {
int log2_abs_b_value = tensorflow::Log2Floor64(b_value);
HloInstruction* shift_amount =
computation->AddInstruction(HloInstruction::CreateConstant(
HloInstruction* shift_amount = computation->AddInstruction(
simplifier->CreateConstantWithLayoutUpdated(
LiteralUtil::CreateR0<T>(log2_abs_b_value)));
if (!ShapeUtil::IsScalar(b->shape())) {
shift_amount = computation->AddInstruction(
@ -940,49 +958,49 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
switch (divide->shape().element_type()) {
case S8:
if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<int8>(divide, computation_)) {
TryDivideToShift<int8>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift));
}
break;
case S16:
if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<int16>(divide, computation_)) {
TryDivideToShift<int16>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift));
}
break;
case S32:
if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<int32>(divide, computation_)) {
TryDivideToShift<int32>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift));
}
break;
case S64:
if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<int64>(divide, computation_)) {
TryDivideToShift<int64>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift));
}
break;
case U8:
if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<uint8>(divide, computation_)) {
TryDivideToShift<uint8>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift));
}
break;
case U16:
if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<uint16>(divide, computation_)) {
TryDivideToShift<uint16>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift));
}
break;
case U32:
if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<uint32>(divide, computation_)) {
TryDivideToShift<uint32>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift));
}
break;
case U64:
if (std::unique_ptr<HloInstruction> shift =
TryDivideToShift<uint64>(divide, computation_)) {
TryDivideToShift<uint64>(divide, computation_, simplifier_)) {
return ReplaceWithNewInstruction(divide, std::move(shift));
}
break;
@ -1084,7 +1102,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
return Status::OK();
}
auto inverse = computation_->AddInstruction(
HloInstruction::CreateConstant((new_literal.Clone())));
simplifier_->CreateConstantWithLayoutUpdated((new_literal.Clone())));
TF_ASSIGN_OR_RETURN(auto new_divide,
MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
return ReplaceInstruction(divide, new_divide);
@ -1456,6 +1474,7 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim);
Shape rhs_slice_shape(rhs->shape());
rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k);
simplifier_->UpdateLayout(&rhs_slice_shape);
std::array<int64, 2> start_indices;
start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset;
@ -1591,6 +1610,7 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
right_operand->shape().dimensions(1 - rhs_contracting_dimension);
auto memoized_shape =
ShapeUtil::MakeShape(dot->shape().element_type(), {m, n});
simplifier_->UpdateLayout(&memoized_shape);
auto* memoized_inst = computation_->AddInstruction(
HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
dnums, dot->precision_config()));
@ -1605,7 +1625,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
// Slice out start and 0 components and reorder if necessary.
auto indices_type = dynamic_slice->operand(1)->shape().element_type();
Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
simplifier_->UpdateLayout(&s_shape);
Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
simplifier_->UpdateLayout(&d_shape);
HloInstruction* non_zero_start =
dynamic_slice->mutable_operand(1 + index_of_non_zero_start);
HloInstruction* zero_start =
@ -1638,8 +1660,9 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
if (ShapeUtil::IsZeroElementArray(dot->shape()) ||
ShapeUtil::IsZeroElementArray(lhs->shape()) ||
ShapeUtil::IsZeroElementArray(rhs->shape())) {
auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(dot->shape().element_type())));
auto zero = computation_->AddInstruction(
simplifier_->CreateConstantWithLayoutUpdated(
LiteralUtil::Zero(dot->shape().element_type())));
return ReplaceWithNewInstruction(
dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
}
@ -2183,8 +2206,9 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
// zero.
auto* iota = Cast<HloIotaInstruction>(instruction);
if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(iota->shape().element_type()).Clone()));
auto zero = computation_->AddInstruction(
simplifier_->CreateConstantWithLayoutUpdated(
LiteralUtil::Zero(iota->shape().element_type()).Clone()));
return ReplaceWithNewInstruction(
iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
}
@ -2307,7 +2331,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
HloInstruction *lhs, *rhs;
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
if (IsAll(rhs, 0)) {
auto one = HloInstruction::CreateConstant(
auto one = simplifier_->CreateConstantWithLayoutUpdated(
LiteralUtil::One(power->shape().element_type()).Clone());
std::unique_ptr<HloInstruction> ones;
if (ShapeUtil::IsScalar(power->shape())) {
@ -2342,8 +2366,9 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) {
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::One(rhs->shape().element_type()).Clone()));
auto* one = computation_->AddInstruction(
simplifier_->CreateConstantWithLayoutUpdated(
LiteralUtil::One(rhs->shape().element_type()).Clone()));
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
// broadcast in divide HLO as we are trying to eliminate implicit
@ -2422,14 +2447,16 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
std::vector<HloInstruction*> new_operands;
new_operands.reserve(user->operand_count());
Shape changed_shape;
for (HloInstruction* user_operand : user->operands()) {
if (user_operand->opcode() == HloOpcode::kBroadcast &&
ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
changed_shape = ShapeUtil::ChangeElementType(
operand->shape(), user_operand->shape().element_type());
simplifier_->UpdateLayout(&changed_shape);
new_operands.push_back(
computation_->AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::ChangeElementType(
operand->shape(), user_operand->shape().element_type()),
user_operand->mutable_operand(0), {})));
changed_shape, user_operand->mutable_operand(0), {})));
} else {
CHECK_EQ(broadcast, user_operand);
new_operands.push_back(operand);
@ -2438,11 +2465,11 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
VLOG(4) << "Sinking broadcast after user:";
VLOG(4) << " old broadcast: " << broadcast->ToString();
VLOG(4) << " old user: " << user->ToString();
HloInstruction* new_user =
computation_->AddInstruction(user->CloneWithNewOperands(
ShapeUtil::ChangeElementType(operand->shape(),
user->shape().element_type()),
new_operands));
changed_shape = ShapeUtil::ChangeElementType(operand->shape(),
user->shape().element_type());
simplifier_->UpdateLayout(&changed_shape);
HloInstruction* new_user = computation_->AddInstruction(
user->CloneWithNewOperands(changed_shape, new_operands));
VLOG(4) << " new user: " << new_user->ToString();
HloInstruction* new_broadcast =
computation_->AddInstruction(HloInstruction::CreateBroadcast(
@ -2456,8 +2483,9 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
namespace {
template <typename T>
std::unique_ptr<HloInstruction> TryRemainderToAnd(HloInstruction* remainder,
HloComputation* computation) {
std::unique_ptr<HloInstruction> TryRemainderToAnd(
HloInstruction* remainder, HloComputation* computation,
AlgebraicSimplifier* simplifier) {
HloInstruction *a, *b, *c;
CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
@ -2487,8 +2515,8 @@ std::unique_ptr<HloInstruction> TryRemainderToAnd(HloInstruction* remainder,
a->shape(), HloOpcode::kSelect, dividend_is_negative,
negated_dividend, a));
auto* mask_amount =
computation->AddInstruction(HloInstruction::CreateConstant(
auto* mask_amount = computation->AddInstruction(
simplifier->CreateConstantWithLayoutUpdated(
LiteralUtil::CreateR0<T>(b_value - 1)));
if (!ShapeUtil::IsScalar(b->shape())) {
mask_amount = computation->AddInstruction(
@ -2509,8 +2537,8 @@ std::unique_ptr<HloInstruction> TryRemainderToAnd(HloInstruction* remainder,
} else {
uint64 b_value = c->literal().GetFirstElement<T>();
if (IsPowerOfTwo(b_value)) {
HloInstruction* mask_amount =
computation->AddInstruction(HloInstruction::CreateConstant(
HloInstruction* mask_amount = computation->AddInstruction(
simplifier->CreateConstantWithLayoutUpdated(
LiteralUtil::CreateR0<T>(b_value - 1)));
if (!ShapeUtil::IsScalar(b->shape())) {
mask_amount = computation->AddInstruction(
@ -2532,49 +2560,49 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
switch (remainder->shape().element_type()) {
case S8:
if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<int8>(remainder, computation_)) {
TryRemainderToAnd<int8>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift));
}
break;
case S16:
if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<int16>(remainder, computation_)) {
TryRemainderToAnd<int16>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift));
}
break;
case S32:
if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<int32>(remainder, computation_)) {
TryRemainderToAnd<int32>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift));
}
break;
case S64:
if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<int64>(remainder, computation_)) {
TryRemainderToAnd<int64>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift));
}
break;
case U8:
if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<uint8>(remainder, computation_)) {
TryRemainderToAnd<uint8>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift));
}
break;
case U16:
if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<uint16>(remainder, computation_)) {
TryRemainderToAnd<uint16>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift));
}
break;
case U32:
if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<uint32>(remainder, computation_)) {
TryRemainderToAnd<uint32>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift));
}
break;
case U64:
if (std::unique_ptr<HloInstruction> shift =
TryRemainderToAnd<uint64>(remainder, computation_)) {
TryRemainderToAnd<uint64>(remainder, computation_, simplifier_)) {
return ReplaceWithNewInstruction(remainder, std::move(shift));
}
break;
@ -2597,7 +2625,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
if (!LayoutUtil::HasLayout(reshaped_shape)) {
LayoutUtil::SetToDefaultLayout(&reshaped_shape);
}
auto empty_constant = HloInstruction::CreateConstant(
auto empty_constant = simplifier_->CreateConstantWithLayoutUpdated(
Literal::CreateFromShape(reshaped_shape));
return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
@ -2810,6 +2838,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
new_slice_limits),
new_slice_operand, new_slice_starts, new_slice_limits,
new_slice_stides));
simplifier_->UpdateLayout(new_slice->mutable_shape());
TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
slice, HloInstruction::CreateReshape(slice->shape(), new_slice)));
return true;
@ -2914,8 +2943,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
// representative.
auto arg = reduce->inputs()[0];
auto init_value = reduce->init_values()[0];
const Shape& reduce_result_shape =
multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape();
Shape& reduce_result_shape = const_cast<Shape&>(
multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape());
absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
@ -3136,6 +3165,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
return !absl::c_linear_search(effective_reduce_dims, dim);
},
reduce_window->shape());
simplifier_->UpdateLayout(&reduce_shape);
HloInstruction* reduce =
computation_->AddInstruction(HloInstruction::CreateReduce(
/*shape=*/reduce_shape,
@ -3261,11 +3291,11 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
HloInstruction* new_reduce_window_operand;
if (convert != nullptr) {
new_reduce_window_operand =
computation_->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(pad_operand->shape(),
convert->shape().element_type()),
pad_operand));
Shape changed_shape = ShapeUtil::ChangeElementType(
pad_operand->shape(), convert->shape().element_type());
simplifier_->UpdateLayout(&changed_shape);
new_reduce_window_operand = computation_->AddInstruction(
HloInstruction::CreateConvert(changed_shape, pad_operand));
} else {
new_reduce_window_operand = pad_operand;
}
@ -3614,15 +3644,18 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
// We already checked feature_dimension is most minor, so data in input_shape
// and row-major {conv_width,input_channels} are bitwise identical.
const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
input_shape.element_type(), {conv_width, input_channels});
simplifier_->UpdateLayout(&new_input_shape);
// We already checked input_feature_dimension is more major than
// output_feature_dimension, so data in filter_shape and row-major
// {input_channels,output_channels} are bitwise identical.
const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
filter_shape.element_type(), {input_channels, output_channels});
const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
simplifier_->UpdateLayout(&new_filter_shape);
Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
convolution_shape.element_type(), {conv_width, output_channels});
simplifier_->UpdateLayout(&dot_output_shape);
auto new_lhs = add_bitcast(new_input_shape, lhs);
auto new_rhs = add_bitcast(new_filter_shape, rhs);
@ -3647,8 +3680,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
convolution,
HloInstruction::CreateBroadcast(
convolution->shape(),
computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(convolution->shape().element_type()))),
computation_->AddInstruction(
simplifier_->CreateConstantWithLayoutUpdated(
LiteralUtil::Zero(convolution->shape().element_type()))),
{}));
}
@ -3731,7 +3765,7 @@ StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
"AlgebraicSimplifier::Run(), before:\n" + module->ToString());
bool changed = false;
for (auto* comp : module->MakeNonfusionComputations()) {
if (AlgebraicSimplifierVisitor::Run(comp, options_)) {
if (AlgebraicSimplifierVisitor::Run(comp, options_, this)) {
changed = true;
}
}

View File

@ -105,6 +105,15 @@ class AlgebraicSimplifier : public HloModulePass {
// computation was changed.
StatusOr<bool> Run(HloModule* module) override;
// Creates a constant instruction from `literal`, then calls UpdateLayout on
// the new instruction's shape so that tiles and element size are updated in
// the constant's layout. `literal` is taken by value and moved into the
// constant; the caller receives ownership of the returned instruction.
std::unique_ptr<HloInstruction> CreateConstantWithLayoutUpdated(
Literal literal) {
auto constant = HloInstruction::CreateConstant(std::move(literal));
UpdateLayout(constant->mutable_shape());
return constant;
}
private:
AlgebraicSimplifierOptions options_;
};

View File

@ -29,8 +29,11 @@ namespace xla {
class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
public:
explicit BFloat16ConversionFoldingVisitor(
HloComputation* computation, const BFloat16Support* bfloat16_support)
: computation_(computation), bfloat16_support_(bfloat16_support) {}
HloComputation* computation, const BFloat16Support* bfloat16_support,
BFloat16ConversionFolding* bfloat16_conversion_folding)
: computation_(computation),
bfloat16_support_(bfloat16_support),
bfloat16_conversion_folding_(bfloat16_conversion_folding) {}
Status DefaultAction(HloInstruction* hlo) override;
@ -38,8 +41,10 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
Status HandleAllReduce(HloInstruction* crs) override;
// Creates a visitor for `computation`, runs it over the computation (the
// traversal status is checked with TF_CHECK_OK), and returns the visitor's
// changed_ flag.
static bool Run(HloComputation* computation,
const BFloat16Support* bfloat16_support) {
BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support);
const BFloat16Support* bfloat16_support,
BFloat16ConversionFolding* bfloat16_conversion_folding) {
BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support,
bfloat16_conversion_folding);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
@ -61,6 +66,7 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
HloComputation* computation_;
const BFloat16Support* bfloat16_support_;
BFloat16ConversionFolding* bfloat16_conversion_folding_;
bool changed_ = false;
};
@ -68,6 +74,7 @@ Status BFloat16ConversionFoldingVisitor::FoldOutputConversions(
HloInstruction* hlo) {
std::vector<HloInstruction*> materialized_users = hlo->users();
hlo->mutable_shape()->set_element_type(BF16);
bfloat16_conversion_folding_->UpdateLayout(hlo->mutable_shape());
for (auto user : materialized_users) {
CHECK_EQ(user->opcode(), HloOpcode::kConvert);
TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
@ -228,6 +235,8 @@ Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) {
ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i})
->set_element_type(BF16);
bfloat16_conversion_folding_->UpdateLayout(
ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}));
for (auto gte : per_tuple_element_gtes[i]) {
TF_RETURN_IF_ERROR(FoldOutputConversions(gte));
}
@ -241,7 +250,7 @@ StatusOr<bool> BFloat16ConversionFolding::Run(HloModule* module) {
2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString());
bool changed = false;
for (auto* comp : module->MakeNonfusionComputations()) {
if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) {
if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_, this)) {
changed = true;
}
}

View File

@ -29,15 +29,20 @@ namespace xla {
class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
public:
explicit BFloat16NormalizationVisitor(HloComputation* computation,
const BFloat16Support* bfloat16_support)
: computation_(computation), bfloat16_support_(bfloat16_support) {}
explicit BFloat16NormalizationVisitor(
HloComputation* computation, const BFloat16Support* bfloat16_support,
BFloat16Normalization* bfloat16_normalization)
: computation_(computation),
bfloat16_support_(bfloat16_support),
bfloat16_normalization_(bfloat16_normalization) {}
Status DefaultAction(HloInstruction* hlo) override;
// Creates a visitor for `computation`, runs it over the computation (the
// traversal status is checked with TF_CHECK_OK), and returns the visitor's
// changed_ flag.
static bool Run(HloComputation* computation,
const BFloat16Support* bfloat16_support) {
BFloat16NormalizationVisitor visitor(computation, bfloat16_support);
const BFloat16Support* bfloat16_support,
BFloat16Normalization* bfloat16_normalization) {
BFloat16NormalizationVisitor visitor(computation, bfloat16_support,
bfloat16_normalization);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
@ -73,6 +78,7 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
HloComputation* computation_;
const BFloat16Support* bfloat16_support_;
BFloat16Normalization* bfloat16_normalization_;
bool changed_ = false;
};
@ -95,6 +101,7 @@ Status BFloat16NormalizationVisitor::InsertConvertAfterOutput(
computation->set_root_instruction(convert);
}
convert->mutable_shape()->set_element_type(to);
bfloat16_normalization_->UpdateLayout(convert->mutable_shape());
changed_ = true;
return Status::OK();
}
@ -103,6 +110,7 @@ Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack(
HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
auto original_type = hlo->shape().element_type();
hlo->mutable_shape()->set_element_type(to);
bfloat16_normalization_->UpdateLayout(hlo->mutable_shape());
return InsertConvertAfterOutput(hlo, original_type, computation);
}
@ -110,8 +118,10 @@ Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand(
HloInstruction* hlo, int64 operand_idx, PrimitiveType to,
HloComputation* computation) {
auto operand = hlo->mutable_operand(operand_idx);
auto convert = computation->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(operand->shape(), to), operand));
auto shape = ShapeUtil::ChangeElementType(operand->shape(), to);
bfloat16_normalization_->UpdateLayout(&shape);
auto convert = computation->AddInstruction(
HloInstruction::CreateConvert(shape, operand));
TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert));
changed_ = true;
return Status::OK();
@ -243,11 +253,13 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
continue;
}
subshape->set_element_type(F32);
bfloat16_normalization_->UpdateLayout(subshape);
auto gte = computation_->AddInstruction(
HloInstruction::CreateGetTupleElement(*subshape, hlo, i));
auto shape = ShapeUtil::ChangeElementType(*subshape, BF16);
bfloat16_normalization_->UpdateLayout(&shape);
output_elements[i] =
computation_->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(*subshape, BF16), gte));
computation_->AddInstruction(HloInstruction::CreateConvert(shape, gte));
}
auto tuple = computation_->AddInstruction(
HloInstruction::CreateTuple(output_elements));
@ -401,7 +413,7 @@ StatusOr<bool> BFloat16Normalization::Run(HloModule* module) {
2, "BFloat16Normalization::Run(), before:\n" + module->ToString());
bool changed = false;
for (auto* comp : module->MakeComputationPostOrder()) {
if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_)) {
if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_, this)) {
changed = true;
}
}

View File

@ -674,11 +674,13 @@ Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) {
if (hlo->opcode() != HloOpcode::kConstant) {
continue;
}
if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) {
if (!Shape::Equal().MinorToMajorOnlyInLayout()(hlo->literal().shape(),
hlo->shape())) {
TF_ASSIGN_OR_RETURN(auto converted_literal,
hlo->literal().ConvertToShape(hlo->shape()));
auto new_constant = computation->AddInstruction(
HloInstruction::CreateConstant(std::move(converted_literal)));
UpdateLayout(new_constant->mutable_shape());
TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
}
}
@ -797,6 +799,7 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
auto subshape = entry.first;
CHECK_EQ(subshape->element_type(), F32);
subshape->set_element_type(BF16);
UpdateLayout(subshape);
changed_ = true;
}
}

View File

@ -716,7 +716,8 @@ ProgramShape HloComputation::ComputeProgramShape() const {
return program_shape;
}
bool HloComputation::operator==(const HloComputation& other) const {
bool HloComputation::Equal(const HloComputation& other,
bool is_layout_sensitive) const {
if (this == &other) {
return true;
}
@ -741,7 +742,8 @@ bool HloComputation::operator==(const HloComputation& other) const {
[](const HloInstruction*, const HloInstruction*) { return true; },
[](const HloComputation* a, const HloComputation* b) {
return *a == *b;
});
},
is_layout_sensitive);
if (!identical_ignoring_operands) {
return false;
}

View File

@ -270,7 +270,12 @@ class HloComputation {
ProgramShape ComputeProgramShape() const;
// Return whether `*this` and `other` are functionally equivalent.
bool operator==(const HloComputation& other) const;
bool Equal(const HloComputation& other, bool is_layout_sensitive) const;
// Return whether `*this` and `other` are functionally equivalent. This is the
// layout-sensitive form of the comparison: it forwards to Equal() with
// is_layout_sensitive set to true.
bool operator==(const HloComputation& other) const {
return Equal(other, /*is_layout_sensitive=*/true);
}
// Replaces old instruction with newly created instruction. Removes old
// instruction from computation. Updates uses and root instruction.

View File

@ -1804,8 +1804,9 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) {
<< "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
// Out of convenience the literal may have been produced with a different
// layout. Relayout as indicated by the HLO instruction.
if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(),
hlo->shape())) {
if (!Layout::Equal().MinorToMajorOnly()(
GetEvaluatedLiteralFor(hlo).shape().layout(),
hlo->shape().layout())) {
evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
}
return Status::OK();

View File

@ -303,6 +303,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_ASSIGN_OR_RETURN(auto literal,
Literal::CreateFromProto(proto.literal()));
instruction = CreateConstant(std::move(literal));
// Literal's shape may have no/different tiling info.
CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(instruction->shape(),
shape));
*instruction->mutable_shape() = shape;
} else {
instruction = absl::make_unique<HloConstantInstruction>(shape);
}

View File

@ -1018,6 +1018,11 @@ HloConstantInstruction::HloConstantInstruction(Literal literal)
: HloInstruction(HloOpcode::kConstant, literal.shape()),
literal_(std::move(literal)) {}
// Builds a constant instruction that takes ownership of `literal` (it is moved
// into literal_) but uses the explicitly supplied `shape` as the instruction's
// shape instead of literal.shape().
HloConstantInstruction::HloConstantInstruction(Literal literal,
const Shape& shape)
: HloInstruction(HloOpcode::kConstant, shape),
literal_(std::move(literal)) {}
HloConstantInstruction::HloConstantInstruction(const Shape& shape)
: HloInstruction(HloOpcode::kConstant, shape) {}
@ -1063,7 +1068,12 @@ HloConstantInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK(literal_.has_value());
return absl::make_unique<HloConstantInstruction>(literal_->Clone());
// Literal's shape may have no/different tiling info. Use this instruction's
// shape instead.
CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(literal_->shape(),
this->shape()));
return absl::make_unique<HloConstantInstruction>(literal_->Clone(),
this->shape());
}
string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(

View File

@ -650,6 +650,7 @@ class HloSliceInstruction : public HloInstruction {
class HloConstantInstruction : public HloInstruction {
public:
explicit HloConstantInstruction(Literal literal);
explicit HloConstantInstruction(Literal literal, const Shape& shape);
// Used when the literal is too large and dropped.
explicit HloConstantInstruction(const Shape& shape);
// Returns the literal associated with this instruction.

View File

@ -56,6 +56,14 @@ class HloModulePass : public HloPassInterface {
}
return changed;
};
// Update the layout of a Shape to one that is supported by a given backend.
// One can call this function after modifying the Shape in case that modifying
// the Shape requires changes to the layout for the given Backend.
//
// TODO(b/129084868): Make this Backend dependent instead of requiring
// deriving from the pass and overriding this function.
virtual void UpdateLayout(Shape* shape) {}
};
// Base class for passes which are module-group scoped. These passes cannot run

View File

@ -343,7 +343,7 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) {
// Check that the 'compare' computation returns a PRED.
Shape compare_shape = compare->root_instruction()->shape();
if (!ShapesSame(compare_shape, ShapeUtil::MakeShape(PRED, {}))) {
if (!ShapeUtil::Compatible(compare_shape, ShapeUtil::MakeShape(PRED, {}))) {
return InternalError(
"The Sort compare computation shape does not lead to a scalar "
"predicate shape: %s",
@ -393,7 +393,8 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
return InternalError("Constant is required to have a valid literal: %s",
constant->ToString());
}
return CheckShape(constant, constant->literal().shape());
return CheckShape(constant, constant->literal().shape(),
/*only_compare_minor_to_major_in_layout=*/true);
}
Status ShapeVerifier::HandleIota(HloInstruction* instruction) {
@ -654,7 +655,8 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
const Shape& conditional_shape =
xla_while->while_condition()->root_instruction()->shape();
if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) {
if (!ShapeUtil::Compatible(conditional_shape,
ShapeUtil::MakeShape(PRED, {}))) {
return InternalError(
"Conditional computation shape does not lead to a scalar predicate "
"shape: %s",
@ -696,7 +698,8 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) {
return CheckShape(send,
ShapeUtil::MakeTupleShape({send->operand(0)->shape(),
ShapeUtil::MakeShape(U32, {}),
ShapeUtil::MakeTokenShape()}));
ShapeUtil::MakeTokenShape()}),
/*only_compare_minor_to_major_in_layout=*/true);
}
Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
@ -705,9 +708,11 @@ Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
return CheckShape(
recv, ShapeUtil::MakeTupleShape(
{ShapeUtil::GetTupleElementShape(recv->shape(), 0),
ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}));
recv,
ShapeUtil::MakeTupleShape(
{ShapeUtil::GetTupleElementShape(recv->shape(), 0),
ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}),
/*only_compare_minor_to_major_in_layout=*/true);
}
Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
@ -844,7 +849,8 @@ Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) {
}
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
const Shape& inferred_shape) {
const Shape& inferred_shape,
bool only_compare_minor_to_major_in_layout) {
// If allow_mixed_precision_ is false, check if there are operands with
// different precisions. We need this check because ShapeInference allows
// mixed precision inputs.
@ -878,7 +884,8 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
case HloOpcode::kTuple:
case HloOpcode::kTupleSelect:
case HloOpcode::kWhile:
return ShapesSame(instruction->shape(), inferred_shape);
return ShapesSame(instruction->shape(), inferred_shape,
only_compare_minor_to_major_in_layout);
// We allow arbitrary layout and f32->bf16 transformations on all other
// instructions, although this may be made more strict pending discussion

View File

@ -106,7 +106,8 @@ class ShapeVerifier : public DfsHloVisitor {
// Check the instruction's shape against the shape given by ShapeInference
// and return an appropriate error if there is a mismatch.
Status CheckShape(const HloInstruction* instruction,
const Shape& inferred_shape);
const Shape& inferred_shape,
bool only_compare_minor_to_major_in_layout = false);
// Overload which takes a StatusOr to reduce boilerplate in the caller.
Status CheckShape(const HloInstruction* instruction,
@ -120,14 +121,31 @@ class ShapeVerifier : public DfsHloVisitor {
private:
// Helpers that switch on layout_sensitive_.
bool ShapesSame(const Shape& a, const Shape& b) {
return layout_sensitive_ ? ShapeUtil::Equal(a, b)
: ShapeUtil::Compatible(a, b);
bool ShapesSame(const Shape& a, const Shape& b,
bool minor_to_major_only = false) {
if (!layout_sensitive_) {
return ShapeUtil::Compatible(a, b);
}
Shape::Equal equal;
if (minor_to_major_only) {
equal.MinorToMajorOnlyInLayout();
}
return equal(a, b);
}
bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) {
return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b)
: ShapeUtil::CompatibleIgnoringFpPrecision(a, b);
bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b,
bool minor_to_major_only = false) {
if (!layout_sensitive_) {
return ShapeUtil::CompatibleIgnoringFpPrecision(a, b);
}
Shape::Equal equal;
if (minor_to_major_only) {
equal.MinorToMajorOnlyInLayout();
}
equal.IgnoreFpPrecision();
return equal(a, b);
}
string StringifyShape(const Shape& s) {
return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s)
: ShapeUtil::HumanString(s);

View File

@ -173,7 +173,7 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
auto iter = buffer_constraints_.find(&buffer);
if (iter != buffer_constraints_.end()) {
const BufferLayoutConstraint& curr_constraint = iter->second;
if (LayoutUtil::Equal(curr_constraint.layout(), layout)) {
if (Layout::Equal().MinorToMajorOnly()(curr_constraint.layout(), layout)) {
// New constraint matches existing constraint. Nothing to do.
return Status::OK();
}
@ -210,7 +210,7 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
GetOperandLayoutConstraint(instruction, operand_no);
if (curr_shape_layout != nullptr) {
if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
shape_with_layout)) {
shape_with_layout, /*minor_to_major_only=*/true)) {
// New constraint matches existing constraint. Nothing to do.
return Status::OK();
}
@ -269,7 +269,8 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
const ShapeLayout* curr_shape_layout = ResultLayout();
if (curr_shape_layout != nullptr) {
if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) {
if (!curr_shape_layout->MatchesLayoutInShape(
shape_with_layout, /*minor_to_major_only=*/true)) {
return FailedPrecondition(
"Result of computation %s already has the layout constraint %s, "
"cannot add incompatible constraint %s",
@ -647,6 +648,10 @@ Status LayoutAssignment::AddMandatoryConstraints(
namespace {
bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) {
return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout());
}
// The operands of a call must match the layouts of parameters in the
// ComputationLayout, and the call instruction itself must match the result
// layout in the ComputationLayout.
@ -656,10 +661,10 @@ Status CheckCallLayout(HloInstruction* call,
TF_RET_CHECK(computation->num_parameters() == call->operand_count());
for (int64 i = 0; i < computation->num_parameters(); ++i) {
TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
call->operand(i)->shape()));
call->operand(i)->shape(), /*minor_to_major_only=*/true));
}
TF_RET_CHECK(
computation_layout.result_layout().MatchesLayoutInShape(call->shape()));
TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape(
call->shape(), /*minor_to_major_only=*/true));
return Status::OK();
}
@ -670,9 +675,9 @@ Status CheckCustomCallLayout(HloInstruction* instruction) {
const HloCustomCallInstruction* custom_call =
DynCast<HloCustomCallInstruction>(instruction);
for (int64 i = 0; i < custom_call->operand_count(); ++i) {
TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
custom_call->operand(i)->shape(),
custom_call->operand_shapes_with_layout()[i]));
TF_RET_CHECK(
LayoutsInShapesEqual(custom_call->operand(i)->shape(),
custom_call->operand_shapes_with_layout()[i]));
}
}
return Status::OK();
@ -690,13 +695,12 @@ Status CheckWhileLayout(HloInstruction* while_inst,
auto init_shape = while_inst->operand(0)->shape();
TF_RET_CHECK(
condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
init_shape));
init_shape, /*minor_to_major_only=*/true));
TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
init_shape));
TF_RET_CHECK(
body_computation_layout.result_layout().MatchesLayoutInShape(init_shape));
TF_RET_CHECK(
LayoutUtil::LayoutsInShapesEqual(init_shape, while_inst->shape()));
init_shape, /*minor_to_major_only=*/true));
TF_RET_CHECK(body_computation_layout.result_layout().MatchesLayoutInShape(
init_shape, /*minor_to_major_only=*/true));
TF_RET_CHECK(LayoutsInShapesEqual(init_shape, while_inst->shape()));
return Status::OK();
}
@ -709,13 +713,14 @@ Status CheckConditionalLayout(
branch_computation_layouts[j].result_layout());
TF_RET_CHECK(
branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
instruction->shape()));
instruction->shape(), /*minor_to_major_only=*/true));
TF_RET_CHECK(
branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
instruction->branch_computation(j)->root_instruction()->shape()));
instruction->branch_computation(j)->root_instruction()->shape(),
/*minor_to_major_only=*/true));
TF_RET_CHECK(
branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape(
branch_operand->shape()));
branch_operand->shape(), /*minor_to_major_only=*/true));
}
return Status::OK();
}
@ -726,11 +731,11 @@ Status CheckConditionalLayout(
Status CheckFusionLayout(HloInstruction* fusion) {
TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
fusion->shape(), fusion->fused_expression_root()->shape()));
TF_RET_CHECK(LayoutsInShapesEqual(fusion->shape(),
fusion->fused_expression_root()->shape()));
for (int64 i = 0; i < fusion->operand_count(); ++i) {
TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
fusion->fused_parameter(i)->shape(), fusion->operand(i)->shape()));
TF_RET_CHECK(LayoutsInShapesEqual(fusion->fused_parameter(i)->shape(),
fusion->operand(i)->shape()));
}
return Status::OK();
}
@ -742,7 +747,8 @@ Status CheckParameterLayout(HloInstruction* parameter,
const ShapeLayout& parameter_layout =
computation_layout.parameter_layout(parameter->parameter_number());
if (parameter_layout.LayoutIsSet() &&
!parameter_layout.MatchesLayoutInShape(parameter->shape())) {
!parameter_layout.MatchesLayoutInShape(parameter->shape(),
/*minor_to_major_only=*/true)) {
return InternalError(
"parameter instruction %s does not match layout of computation "
"shape: %s",
@ -753,8 +759,7 @@ Status CheckParameterLayout(HloInstruction* parameter,
// The layout of a constant instruction must match the layout of its literal.
Status CheckConstantLayout(HloInstruction* constant) {
if (!LayoutUtil::LayoutsInShapesEqual(constant->literal().shape(),
constant->shape())) {
if (!LayoutsInShapesEqual(constant->literal().shape(), constant->shape())) {
return InternalError(
"constant instruction %s does not match the layout of its literal %s",
constant->ToString(),
@ -785,7 +790,8 @@ StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
HloInstruction* gte = instruction->parent()->AddInstruction(
HloInstruction::CreateGetTupleElement(instr_shape, instruction, i));
if (ShapeUtil::Equal(target_shape, instr_shape)) {
if (Shape::Equal().MinorToMajorOnlyInLayout()(target_shape,
instr_shape)) {
// Shapes and layouts are equal, no need to copy.
element_copies.push_back(gte);
} else {
@ -831,7 +837,8 @@ Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
TF_RET_CHECK(operand_layout.LayoutIsSet());
TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
if (Shape::Equal().MinorToMajorOnlyInLayout()(operand_layout.shape(),
operand->shape())) {
VLOG(5) << "Operand " << operand->ToString() << " layout matches in "
<< instruction->ToString();
// Operand layout already matches our constraint. Nothing to do.
@ -892,7 +899,8 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
const Shape& instruction_subshape =
ShapeUtil::GetSubshape(instruction->shape(), index);
for (const LogicalBuffer* buffer : buffers) {
if (!ShapeUtil::Equal(instruction_subshape, buffer->shape())) {
if (!Shape::Equal().MinorToMajorOnlyInLayout()(
instruction_subshape, buffer->shape())) {
return InternalError(
"Layout of instruction %s at index {%s} does not match "
"source LogicalBuffer %s: %s vs %s",
@ -954,8 +962,8 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
FindOrDie(computation_layouts_, module->entry_computation())
.result_layout();
if (result_layout.LayoutIsSet()) {
TF_RET_CHECK(
ShapeUtil::Equal(module->result_shape(), result_layout.shape()));
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
module->result_shape(), result_layout.shape()));
}
return Status::OK();
}
@ -1510,8 +1518,8 @@ StatusOr<Layout> InferArrayLayout(
if (first_buffer_layout == nullptr) {
first_buffer_layout = &source_buffer->shape().layout();
} else if (!LayoutUtil::Equal(source_buffer->shape().layout(),
*first_buffer_layout)) {
} else if (!Layout::Equal().MinorToMajorOnly()(
source_buffer->shape().layout(), *first_buffer_layout)) {
// The points-to set is ambiguous for this index and the different source
// buffers have different layouts. This case is possible in valid XLA
// computations because we do not propagate BufferLayoutConstraints to all
@ -1789,7 +1797,8 @@ Status LayoutAssignment::RunOnComputation(
// layout constraint.
if (constraints.ResultLayout() != nullptr &&
!constraints.ResultLayout()->MatchesLayoutInShape(
computation->root_instruction()->shape())) {
computation->root_instruction()->shape(),
/*minor_to_major_only=*/true)) {
if (conditional_mismatch_.count(computation) > 0) {
*FindOrDie(computation_layouts_, computation).mutable_result_layout() =
FindOrDie(conditional_mismatch_, computation).result_layout();
@ -1907,7 +1916,9 @@ Status LayoutAssignment::PropagateComputationLayouts(
<< ": " << computed_computation_layout.result_layout().ToString();
*result_layout = computed_computation_layout.result_layout();
} else {
TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout);
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
computed_computation_layout.result_layout().shape(),
result_layout->shape()));
}
return Status::OK();
}

View File

@ -141,6 +141,11 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
}
}
if (!ShapeUtil::SameDimensions(lhs, rhs)) {
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
return false;
}
if (!ignore_layout_) {
if (lhs.layout().format() != rhs.layout().format()) {
VLOG(3) << "CompareShapes: lhs layout format != rhs layout format";
@ -161,11 +166,6 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
}
}
if (!ShapeUtil::SameDimensions(lhs, rhs)) {
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
return false;
}
if (!ignore_dynamic_dimension_) {
for (int i = 0; i < lhs.rank(); ++i) {
if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) {

View File

@ -169,6 +169,11 @@ class Shape {
ignore_element_size_in_layout_ = true;
return *this;
}
// Sets both the ignore-tiles and ignore-element-size flags for the layout part
// of the comparison. Returns *this so that option calls can be chained.
Equal& MinorToMajorOnlyInLayout() {
ignore_tiles_in_layout_ = true;
ignore_element_size_in_layout_ = true;
return *this;
}
Equal& IgnoreElementType() {
ignore_element_type_ = true;
return *this;

View File

@ -46,8 +46,13 @@ void ShapeLayout::SetToDefaultLayout() {
LayoutUtil::SetToDefaultLayout(&shape_);
}
bool ShapeLayout::MatchesLayoutInShape(const Shape& shape) const {
return ShapeUtil::Equal(shape, shape_);
bool ShapeLayout::MatchesLayoutInShape(const Shape& shape,
bool minor_to_major_only) const {
auto equal = Shape::Equal();
if (minor_to_major_only) {
equal.MinorToMajorOnlyInLayout();
}
return equal(shape, shape_);
}
const Layout& ShapeLayout::layout() const {

View File

@ -45,7 +45,8 @@ class ShapeLayout {
// Returns true if the Layouts in this ShapeLayout match the layouts in the
// given shape. Returns false otherwise. If the given shape is not compatible
// with the ShapeLayout's shape, then false is returned.
bool MatchesLayoutInShape(const Shape& shape) const;
bool MatchesLayoutInShape(const Shape& shape,
bool minor_to_major_only = false) const;
// Copies the layout from the given shape into this ShapeLayout. 'other_shape'
// must be compatible with the ShapeLayout's shape.

View File

@ -102,6 +102,11 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
}
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeUtil::MakeValidatedShape(element_type, dimensions));
if (element_size_in_bits ==
ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) {
// Only set element_size_in_bits if it's different from the default value.
element_size_in_bits = 0;
}
*shape.mutable_layout() =
LayoutUtil::MakeLayout(minor_to_major, tiles, element_size_in_bits);
if (!shape.has_layout()) {
@ -219,7 +224,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
for (int i = 0; i < shape.dimensions_size(); ++i) {
dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
}
return MakeShapeWithDescendingLayout(shape.element_type(), dims);
Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims);
// Since the physical layout is kept the same, the tiles and element size are
// the same also.
*new_shape.mutable_layout()->mutable_tiles() = shape.layout().tiles();
new_shape.mutable_layout()->set_element_size_in_bits(
shape.layout().element_size_in_bits());
return new_shape;
}
/* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type,

View File

@ -395,6 +395,8 @@ class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest {
ParametricDotTestWithoutLayoutAssignment() {
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
"layout-assignment");
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
"tiling-assignment");
// Disable algebraic simplification because the pass may replace a dot
// instruction with a layout-changing multiplication instruction.
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(

View File

@ -846,16 +846,16 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) {
xla::ProgramShape xla_program_shape =
XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes));
EXPECT_TRUE(xla::LayoutUtil::Equal(
EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
.layout()));
EXPECT_TRUE(xla::LayoutUtil::Equal(
EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(),
xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1})
.layout()));
EXPECT_TRUE(xla::LayoutUtil::Equal(program_shape.result().layout(),
xla_program_shape.result().layout()));
EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
program_shape.result().layout(), xla_program_shape.result().layout()));
}
TEST(RawApiTest, DotGeneralWithLayoutTest) {