diff --git a/tensorflow/compiler/xla/layout.cc b/tensorflow/compiler/xla/layout.cc index 000c4fdc405..7052ec09f35 100644 --- a/tensorflow/compiler/xla/layout.cc +++ b/tensorflow/compiler/xla/layout.cc @@ -96,8 +96,13 @@ string Layout::ToString() const { } bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) { - if (lhs.format() != rhs.format() || - lhs.minor_to_major() != rhs.minor_to_major() || + if (lhs.format() != rhs.format()) { + return false; + } + if (lhs.format() == DENSE && lhs.minor_to_major() != rhs.minor_to_major()) { + return false; + } + if (lhs.format() == SPARSE && lhs.max_sparse_elements() != rhs.max_sparse_elements()) { return false; } diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h index acc449b781b..63b2a566535 100644 --- a/tensorflow/compiler/xla/layout.h +++ b/tensorflow/compiler/xla/layout.h @@ -127,6 +127,12 @@ class Layout { return *this; } + Equal& MinorToMajorOnly() { + ignore_tiles_ = true; + ignore_element_size_ = true; + return *this; + } + private: bool ignore_tiles_ = false; bool ignore_element_size_ = false; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index b223fc8b1b5..b3044504312 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -250,12 +250,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Runs the visitor on a computation. static bool Run(HloComputation* computation, - const AlgebraicSimplifierOptions& options); + const AlgebraicSimplifierOptions& options, + AlgebraicSimplifier* simplifier); private: explicit AlgebraicSimplifierVisitor(HloComputation* computation, - const AlgebraicSimplifierOptions& options) - : computation_(computation), options_(options) {} + const AlgebraicSimplifierOptions& options, + AlgebraicSimplifier* simplifier) + : computation_(computation), options_(options), simplifier_(simplifier) {} // Transforms Dots where at least one input is a vector or has a degenerate // dimension and converts it into a multiply and reduce. 
This should enable @@ -274,10 +276,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { if (hlo->shape().rank() == 1) { return hlo; } - return computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(hlo->shape().element_type(), - {ShapeUtil::ElementsIn(hlo->shape())}), - hlo)); + auto hlo_instruction = + computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(hlo->shape().element_type(), + {ShapeUtil::ElementsIn(hlo->shape())}), + hlo)); + simplifier_->UpdateLayout(hlo_instruction->mutable_shape()); + return hlo_instruction; } // Converts to primitive type if the input hlo is not that type, otherwise @@ -287,8 +292,11 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { if (hlo->shape().element_type() == element_type) { return hlo; } - return computation_->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); + Shape changed_shape = + ShapeUtil::ChangeElementType(hlo->shape(), element_type); + simplifier_->UpdateLayout(&changed_shape); + return computation_->AddInstruction( + HloInstruction::CreateConvert(changed_shape, hlo)); } // Transposes a dot operand such that the batch dimensions are the msot major, @@ -312,13 +320,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Helper method to perform and add reduction on a list of dimensions. HloInstruction* AddReduce(HloInstruction* hlo, absl::Span dims) { - HloInstruction* zero = - computation_->AddInstruction(HloInstruction::CreateConstant( + HloInstruction* zero = computation_->AddInstruction( + simplifier_->CreateConstantWithLayoutUpdated( LiteralUtil::Zero(hlo->shape().element_type()).Clone())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::FilterDimensions( [&](int64 dim) { return !absl::c_linear_search(dims, dim); }, hlo->shape()); + simplifier_->UpdateLayout(&shape); return computation_->AddInstruction(HloInstruction::CreateReduce( shape, hlo, zero, dims, AddReduce_computation)); } @@ -403,6 +412,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloComputation::Builder b("scalar_add_computation"); Shape shape = ShapeUtil::MakeShape(F32, {}); + simplifier_->UpdateLayout(&shape); auto scalar_lhs = b.AddInstruction( HloInstruction::CreateParameter(0, shape, "scalar_lhs")); auto scalar_rhs = b.AddInstruction( @@ -440,13 +450,16 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Cached computation for adding two scalar F32. 
HloComputation* scalar_add_computation_ = nullptr; + + AlgebraicSimplifier* simplifier_ = nullptr; }; } // namespace -bool AlgebraicSimplifierVisitor::Run( - HloComputation* computation, const AlgebraicSimplifierOptions& options) { - AlgebraicSimplifierVisitor visitor(computation, options); +bool AlgebraicSimplifierVisitor::Run(HloComputation* computation, + const AlgebraicSimplifierOptions& options, + AlgebraicSimplifier* simplifier) { + AlgebraicSimplifierVisitor visitor(computation, options, simplifier); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -713,6 +726,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( new_slice_shape.set_dimensions( concatenate_dimension, slice_end - operands[i]->slice_starts(concatenate_dimension)); + simplifier_->UpdateLayout(&new_slice_shape); auto new_limit_indices = operands[i]->slice_limits(); new_limit_indices[concatenate_dimension] = slice_end; auto new_slice_op = @@ -775,18 +789,19 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( } static HloInstruction* BuildTupleConstant(HloComputation* computation, - const LiteralSlice& literal) { + const LiteralSlice& literal, + AlgebraicSimplifier* simplifier) { if (literal.shape().IsTuple()) { std::vector elems; elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { - elems.push_back( - BuildTupleConstant(computation, LiteralSlice(literal, {i}))); + elems.push_back(BuildTupleConstant( + computation, LiteralSlice(literal, {i}), simplifier)); } return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { return computation->AddInstruction( - HloInstruction::CreateConstant(literal.Clone())); + simplifier->CreateConstantWithLayoutUpdated(literal.Clone())); } } @@ -795,7 +810,8 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // explicit Tuple instructions. 
if (constant->shape().IsTuple()) { return ReplaceInstruction( - constant, BuildTupleConstant(computation_, constant->literal())); + constant, + BuildTupleConstant(computation_, constant->literal(), simplifier_)); } if (constant->shape().element_type() == TOKEN) { @@ -808,7 +824,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { Literal unique_scalar( LiteralUtil::GetFirstScalarLiteral(constant->literal())); HloInstruction* scalar = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(unique_scalar))); + simplifier_->CreateConstantWithLayoutUpdated(std::move(unique_scalar))); return ReplaceWithNewInstruction( constant, HloInstruction::CreateBroadcast(constant->shape(), scalar, {})); @@ -854,8 +870,9 @@ Status InvertConstant(const HloInstruction& constant, Literal* result) { } template -std::unique_ptr TryDivideToShift(HloInstruction* divide, - HloComputation* computation) { +std::unique_ptr TryDivideToShift( + HloInstruction* divide, HloComputation* computation, + AlgebraicSimplifier* simplifier) { HloInstruction *a, *b, *c; CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); @@ -872,10 +889,11 @@ std::unique_ptr TryDivideToShift(HloInstruction* divide, HloInstruction* zero_like_a = BroadcastZeros( computation, a->shape().element_type(), a->shape().dimensions()); + Shape changed_shape = ShapeUtil::ChangeElementType(a->shape(), PRED); + simplifier->UpdateLayout(&changed_shape); auto* dividend_is_negative = computation->AddInstruction(HloInstruction::CreateCompare( - ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a, - ComparisonDirection::kLt)); + changed_shape, a, zero_like_a, ComparisonDirection::kLt)); auto* negated_dividend = computation->AddInstruction( HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); @@ -887,8 +905,8 @@ std::unique_ptr TryDivideToShift(HloInstruction* divide, int log2_abs_b_value = tensorflow::Log2Floor64(b_value); - auto* shift_amount = - computation->AddInstruction(HloInstruction::CreateConstant( + auto* shift_amount = computation->AddInstruction( + simplifier->CreateConstantWithLayoutUpdated( LiteralUtil::CreateR0(log2_abs_b_value))); if (!ShapeUtil::IsScalar(b->shape())) { shift_amount = computation->AddInstruction( @@ -911,8 +929,8 @@ std::unique_ptr TryDivideToShift(HloInstruction* divide, uint64 b_value = c->literal().GetFirstElement(); if (IsPowerOfTwo(b_value)) { int log2_abs_b_value = tensorflow::Log2Floor64(b_value); - HloInstruction* shift_amount = - computation->AddInstruction(HloInstruction::CreateConstant( + HloInstruction* shift_amount = computation->AddInstruction( + simplifier->CreateConstantWithLayoutUpdated( LiteralUtil::CreateR0(log2_abs_b_value))); if (!ShapeUtil::IsScalar(b->shape())) { shift_amount = computation->AddInstruction( @@ -940,49 +958,49 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { switch (divide->shape().element_type()) { case S8: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, std::move(shift)); } break; case S16: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, std::move(shift)); } break; case S32: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, 
std::move(shift)); } break; case S64: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, std::move(shift)); } break; case U8: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, std::move(shift)); } break; case U16: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, std::move(shift)); } break; case U32: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, std::move(shift)); } break; case U64: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, std::move(shift)); } break; @@ -1084,7 +1102,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } auto inverse = computation_->AddInstruction( - HloInstruction::CreateConstant((new_literal.Clone()))); + simplifier_->CreateConstantWithLayoutUpdated((new_literal.Clone()))); TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); return ReplaceInstruction(divide, new_divide); @@ -1456,6 +1474,7 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim); Shape rhs_slice_shape(rhs->shape()); rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k); + simplifier_->UpdateLayout(&rhs_slice_shape); std::array start_indices; start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset; @@ -1591,6 +1610,7 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( right_operand->shape().dimensions(1 - rhs_contracting_dimension); auto memoized_shape = ShapeUtil::MakeShape(dot->shape().element_type(), {m, n}); + simplifier_->UpdateLayout(&memoized_shape); auto* memoized_inst = computation_->AddInstruction( HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, dnums, dot->precision_config())); @@ -1605,7 +1625,9 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( // Slice out start and 0 components and reorder if necessary. auto indices_type = dynamic_slice->operand(1)->shape().element_type(); Shape s_shape = ShapeUtil::MakeShape(indices_type, {1}); + simplifier_->UpdateLayout(&s_shape); Shape d_shape = ShapeUtil::MakeShape(indices_type, {2}); + simplifier_->UpdateLayout(&d_shape); HloInstruction* non_zero_start = dynamic_slice->mutable_operand(1 + index_of_non_zero_start); HloInstruction* zero_start = @@ -1638,8 +1660,9 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { if (ShapeUtil::IsZeroElementArray(dot->shape()) || ShapeUtil::IsZeroElementArray(lhs->shape()) || ShapeUtil::IsZeroElementArray(rhs->shape())) { - auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(dot->shape().element_type()))); + auto zero = computation_->AddInstruction( + simplifier_->CreateConstantWithLayoutUpdated( + LiteralUtil::Zero(dot->shape().element_type()))); return ReplaceWithNewInstruction( dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } @@ -2183,8 +2206,9 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) { // zero. 
auto* iota = Cast(instruction); if (iota->shape().dimensions(iota->iota_dimension()) <= 1) { - auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(iota->shape().element_type()).Clone())); + auto zero = computation_->AddInstruction( + simplifier_->CreateConstantWithLayoutUpdated( + LiteralUtil::Zero(iota->shape().element_type()).Clone())); return ReplaceWithNewInstruction( iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {})); } @@ -2307,7 +2331,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { HloInstruction *lhs, *rhs; CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); if (IsAll(rhs, 0)) { - auto one = HloInstruction::CreateConstant( + auto one = simplifier_->CreateConstantWithLayoutUpdated( LiteralUtil::One(power->shape().element_type()).Clone()); std::unique_ptr ones; if (ShapeUtil::IsScalar(power->shape())) { @@ -2342,8 +2366,9 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); if (IsAll(rhs, -1)) { - auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::One(rhs->shape().element_type()).Clone())); + auto* one = computation_->AddInstruction( + simplifier_->CreateConstantWithLayoutUpdated( + LiteralUtil::One(rhs->shape().element_type()).Clone())); // Explicitly broadcast scalar 1 to the output shape, to avoid implicit // broadcast in divide HLO as we are trying to eliminate implicit @@ -2422,14 +2447,16 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( std::vector new_operands; new_operands.reserve(user->operand_count()); + Shape changed_shape; for (HloInstruction* user_operand : user->operands()) { if (user_operand->opcode() == HloOpcode::kBroadcast && ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + changed_shape = ShapeUtil::ChangeElementType( + operand->shape(), user_operand->shape().element_type()); + simplifier_->UpdateLayout(&changed_shape); new_operands.push_back( computation_->AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::ChangeElementType( - operand->shape(), user_operand->shape().element_type()), - user_operand->mutable_operand(0), {}))); + changed_shape, user_operand->mutable_operand(0), {}))); } else { CHECK_EQ(broadcast, user_operand); new_operands.push_back(operand); @@ -2438,11 +2465,11 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( VLOG(4) << "Sinking broadcast after user:"; VLOG(4) << " old broadcast: " << broadcast->ToString(); VLOG(4) << " old user: " << user->ToString(); - HloInstruction* new_user = - computation_->AddInstruction(user->CloneWithNewOperands( - ShapeUtil::ChangeElementType(operand->shape(), - user->shape().element_type()), - new_operands)); + changed_shape = ShapeUtil::ChangeElementType(operand->shape(), + user->shape().element_type()); + simplifier_->UpdateLayout(&changed_shape); + HloInstruction* new_user = computation_->AddInstruction( + user->CloneWithNewOperands(changed_shape, new_operands)); VLOG(4) << " new user: " << new_user->ToString(); HloInstruction* new_broadcast = computation_->AddInstruction(HloInstruction::CreateBroadcast( @@ -2456,8 +2483,9 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( namespace { template -std::unique_ptr TryRemainderToAnd(HloInstruction* remainder, - HloComputation* computation) { +std::unique_ptr TryRemainderToAnd( + HloInstruction* remainder, HloComputation* 
computation, + AlgebraicSimplifier* simplifier) { HloInstruction *a, *b, *c; CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); @@ -2487,8 +2515,8 @@ std::unique_ptr TryRemainderToAnd(HloInstruction* remainder, a->shape(), HloOpcode::kSelect, dividend_is_negative, negated_dividend, a)); - auto* mask_amount = - computation->AddInstruction(HloInstruction::CreateConstant( + auto* mask_amount = computation->AddInstruction( + simplifier->CreateConstantWithLayoutUpdated( LiteralUtil::CreateR0(b_value - 1))); if (!ShapeUtil::IsScalar(b->shape())) { mask_amount = computation->AddInstruction( @@ -2509,8 +2537,8 @@ std::unique_ptr TryRemainderToAnd(HloInstruction* remainder, } else { uint64 b_value = c->literal().GetFirstElement(); if (IsPowerOfTwo(b_value)) { - HloInstruction* mask_amount = - computation->AddInstruction(HloInstruction::CreateConstant( + HloInstruction* mask_amount = computation->AddInstruction( + simplifier->CreateConstantWithLayoutUpdated( LiteralUtil::CreateR0(b_value - 1))); if (!ShapeUtil::IsScalar(b->shape())) { mask_amount = computation->AddInstruction( @@ -2532,49 +2560,49 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) { switch (remainder->shape().element_type()) { case S8: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return ReplaceWithNewInstruction(remainder, std::move(shift)); } break; case S16: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return ReplaceWithNewInstruction(remainder, std::move(shift)); } break; case S32: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return ReplaceWithNewInstruction(remainder, std::move(shift)); } break; case S64: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return ReplaceWithNewInstruction(remainder, std::move(shift)); } break; case U8: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return ReplaceWithNewInstruction(remainder, std::move(shift)); } break; case U16: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return ReplaceWithNewInstruction(remainder, std::move(shift)); } break; case U32: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return ReplaceWithNewInstruction(remainder, std::move(shift)); } break; case U64: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return ReplaceWithNewInstruction(remainder, std::move(shift)); } break; @@ -2597,7 +2625,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { if (!LayoutUtil::HasLayout(reshaped_shape)) { LayoutUtil::SetToDefaultLayout(&reshaped_shape); } - auto empty_constant = HloInstruction::CreateConstant( + auto empty_constant = simplifier_->CreateConstantWithLayoutUpdated( Literal::CreateFromShape(reshaped_shape)); return ReplaceWithNewInstruction(reshape, std::move(empty_constant)); @@ -2810,6 +2838,7 @@ StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( new_slice_limits), new_slice_operand, 
new_slice_starts, new_slice_limits, new_slice_stides)); + simplifier_->UpdateLayout(new_slice->mutable_shape()); TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( slice, HloInstruction::CreateReshape(slice->shape(), new_slice))); return true; @@ -2914,8 +2943,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { // representative. auto arg = reduce->inputs()[0]; auto init_value = reduce->init_values()[0]; - const Shape& reduce_result_shape = - multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape(); + Shape& reduce_result_shape = const_cast( + multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape()); absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); @@ -3136,6 +3165,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return !absl::c_linear_search(effective_reduce_dims, dim); }, reduce_window->shape()); + simplifier_->UpdateLayout(&reduce_shape); HloInstruction* reduce = computation_->AddInstruction(HloInstruction::CreateReduce( /*shape=*/reduce_shape, @@ -3261,11 +3291,11 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( HloInstruction* new_reduce_window_operand; if (convert != nullptr) { - new_reduce_window_operand = - computation_->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(pad_operand->shape(), - convert->shape().element_type()), - pad_operand)); + Shape changed_shape = ShapeUtil::ChangeElementType( + pad_operand->shape(), convert->shape().element_type()); + simplifier_->UpdateLayout(&changed_shape); + new_reduce_window_operand = computation_->AddInstruction( + HloInstruction::CreateConvert(changed_shape, pad_operand)); } else { new_reduce_window_operand = pad_operand; } @@ -3614,15 +3644,18 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( // We already checked feature_dimension is most minor, so data in input_shape // and row-major {conv_width,input_channels} are bitwise identical. - const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout( + Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout( input_shape.element_type(), {conv_width, input_channels}); + simplifier_->UpdateLayout(&new_input_shape); // We already checked input_feature_dimension is more major than // output_feature_dimension, so data in filter_shape and row-major // {input_channels,output_channels} are bitwise identical. 
- const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout( + Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout( filter_shape.element_type(), {input_channels, output_channels}); - const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout( + simplifier_->UpdateLayout(&new_filter_shape); + Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout( convolution_shape.element_type(), {conv_width, output_channels}); + simplifier_->UpdateLayout(&dot_output_shape); auto new_lhs = add_bitcast(new_input_shape, lhs); auto new_rhs = add_bitcast(new_filter_shape, rhs); @@ -3647,8 +3680,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( convolution, HloInstruction::CreateBroadcast( convolution->shape(), - computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(convolution->shape().element_type()))), + computation_->AddInstruction( + simplifier_->CreateConstantWithLayoutUpdated( + LiteralUtil::Zero(convolution->shape().element_type()))), {})); } @@ -3731,7 +3765,7 @@ StatusOr AlgebraicSimplifier::Run(HloModule* module) { "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (AlgebraicSimplifierVisitor::Run(comp, options_)) { + if (AlgebraicSimplifierVisitor::Run(comp, options_, this)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index df5a8c2ec14..1768f725b20 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -105,6 +105,15 @@ class AlgebraicSimplifier : public HloModulePass { // computation was changed. StatusOr Run(HloModule* module) override; + // Create constant from literal with tiles and element size updated in the + // constant's layout. 
+ std::unique_ptr CreateConstantWithLayoutUpdated( + Literal literal) { + auto constant = HloInstruction::CreateConstant(std::move(literal)); + UpdateLayout(constant->mutable_shape()); + return constant; + } + private: AlgebraicSimplifierOptions options_; }; diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index e62d72b323b..0a8e8dc2a8b 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -29,8 +29,11 @@ namespace xla { class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { public: explicit BFloat16ConversionFoldingVisitor( - HloComputation* computation, const BFloat16Support* bfloat16_support) - : computation_(computation), bfloat16_support_(bfloat16_support) {} + HloComputation* computation, const BFloat16Support* bfloat16_support, + BFloat16ConversionFolding* bfloat16_conversion_folding) + : computation_(computation), + bfloat16_support_(bfloat16_support), + bfloat16_conversion_folding_(bfloat16_conversion_folding) {} Status DefaultAction(HloInstruction* hlo) override; @@ -38,8 +41,10 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { Status HandleAllReduce(HloInstruction* crs) override; static bool Run(HloComputation* computation, - const BFloat16Support* bfloat16_support) { - BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support); + const BFloat16Support* bfloat16_support, + BFloat16ConversionFolding* bfloat16_conversion_folding) { + BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support, + bfloat16_conversion_folding); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -61,6 +66,7 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { HloComputation* computation_; const BFloat16Support* bfloat16_support_; + BFloat16ConversionFolding* bfloat16_conversion_folding_; bool changed_ = false; }; @@ -68,6 +74,7 @@ Status BFloat16ConversionFoldingVisitor::FoldOutputConversions( HloInstruction* hlo) { std::vector materialized_users = hlo->users(); hlo->mutable_shape()->set_element_type(BF16); + bfloat16_conversion_folding_->UpdateLayout(hlo->mutable_shape()); for (auto user : materialized_users) { CHECK_EQ(user->opcode(), HloOpcode::kConvert); TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo)); @@ -228,6 +235,8 @@ Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) { ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}) ->set_element_type(BF16); + bfloat16_conversion_folding_->UpdateLayout( + ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i})); for (auto gte : per_tuple_element_gtes[i]) { TF_RETURN_IF_ERROR(FoldOutputConversions(gte)); } @@ -241,7 +250,7 @@ StatusOr BFloat16ConversionFolding::Run(HloModule* module) { 2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) { + if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_, this)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index 72459961485..e7b4b6ae100 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -29,15 +29,20 @@ namespace 
xla { class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { public: - explicit BFloat16NormalizationVisitor(HloComputation* computation, - const BFloat16Support* bfloat16_support) - : computation_(computation), bfloat16_support_(bfloat16_support) {} + explicit BFloat16NormalizationVisitor( + HloComputation* computation, const BFloat16Support* bfloat16_support, + BFloat16Normalization* bfloat16_normalization) + : computation_(computation), + bfloat16_support_(bfloat16_support), + bfloat16_normalization_(bfloat16_normalization) {} Status DefaultAction(HloInstruction* hlo) override; static bool Run(HloComputation* computation, - const BFloat16Support* bfloat16_support) { - BFloat16NormalizationVisitor visitor(computation, bfloat16_support); + const BFloat16Support* bfloat16_support, + BFloat16Normalization* bfloat16_normalization) { + BFloat16NormalizationVisitor visitor(computation, bfloat16_support, + bfloat16_normalization); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -73,6 +78,7 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { HloComputation* computation_; const BFloat16Support* bfloat16_support_; + BFloat16Normalization* bfloat16_normalization_; bool changed_ = false; }; @@ -95,6 +101,7 @@ Status BFloat16NormalizationVisitor::InsertConvertAfterOutput( computation->set_root_instruction(convert); } convert->mutable_shape()->set_element_type(to); + bfloat16_normalization_->UpdateLayout(convert->mutable_shape()); changed_ = true; return Status::OK(); } @@ -103,6 +110,7 @@ Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( HloInstruction* hlo, PrimitiveType to, HloComputation* computation) { auto original_type = hlo->shape().element_type(); hlo->mutable_shape()->set_element_type(to); + bfloat16_normalization_->UpdateLayout(hlo->mutable_shape()); return InsertConvertAfterOutput(hlo, original_type, computation); } @@ -110,8 +118,10 @@ Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand( HloInstruction* hlo, int64 operand_idx, PrimitiveType to, HloComputation* computation) { auto operand = hlo->mutable_operand(operand_idx); - auto convert = computation->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(operand->shape(), to), operand)); + auto shape = ShapeUtil::ChangeElementType(operand->shape(), to); + bfloat16_normalization_->UpdateLayout(&shape); + auto convert = computation->AddInstruction( + HloInstruction::CreateConvert(shape, operand)); TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert)); changed_ = true; return Status::OK(); @@ -243,11 +253,13 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs( continue; } subshape->set_element_type(F32); + bfloat16_normalization_->UpdateLayout(subshape); auto gte = computation_->AddInstruction( HloInstruction::CreateGetTupleElement(*subshape, hlo, i)); + auto shape = ShapeUtil::ChangeElementType(*subshape, BF16); + bfloat16_normalization_->UpdateLayout(&shape); output_elements[i] = - computation_->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(*subshape, BF16), gte)); + computation_->AddInstruction(HloInstruction::CreateConvert(shape, gte)); } auto tuple = computation_->AddInstruction( HloInstruction::CreateTuple(output_elements)); @@ -401,7 +413,7 @@ StatusOr BFloat16Normalization::Run(HloModule* module) { 2, "BFloat16Normalization::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeComputationPostOrder()) { - if 
(BFloat16NormalizationVisitor::Run(comp, bfloat16_support_)) { + if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_, this)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index bab63f66d83..d314065c752 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -674,11 +674,13 @@ Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) { if (hlo->opcode() != HloOpcode::kConstant) { continue; } - if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) { + if (!Shape::Equal().MinorToMajorOnlyInLayout()(hlo->literal().shape(), + hlo->shape())) { TF_ASSIGN_OR_RETURN(auto converted_literal, hlo->literal().ConvertToShape(hlo->shape())); auto new_constant = computation->AddInstruction( HloInstruction::CreateConstant(std::move(converted_literal))); + UpdateLayout(new_constant->mutable_shape()); TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); } } @@ -797,6 +799,7 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { auto subshape = entry.first; CHECK_EQ(subshape->element_type(), F32); subshape->set_element_type(BF16); + UpdateLayout(subshape); changed_ = true; } } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 48a51d302bb..195c84b034f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -716,7 +716,8 @@ ProgramShape HloComputation::ComputeProgramShape() const { return program_shape; } -bool HloComputation::operator==(const HloComputation& other) const { +bool HloComputation::Equal(const HloComputation& other, + bool is_layout_sensitive) const { if (this == &other) { return true; } @@ -741,7 +742,8 @@ bool HloComputation::operator==(const HloComputation& other) const { [](const HloInstruction*, const HloInstruction*) { return true; }, [](const HloComputation* a, const HloComputation* b) { return *a == *b; - }); + }, + is_layout_sensitive); if (!identical_ignoring_operands) { return false; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index a48cfa1f1b2..e42808be773 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -270,7 +270,12 @@ class HloComputation { ProgramShape ComputeProgramShape() const; // Return whether `*this` and `other` are functionally equivalent. - bool operator==(const HloComputation& other) const; + bool Equal(const HloComputation& other, bool is_layout_sensitive) const; + + // Return whether `*this` and `other` are functionally equivalent. + bool operator==(const HloComputation& other) const { + return Equal(other, true); + } // Replaces old instruction with newly created instruction. Removes old // instruction from computation. Updates uses and root instruction. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index b3983fc696c..61b1d0012f3 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1804,8 +1804,9 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) { << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString(); // Out of convenience the literal may have been produced with a different // layout. 
Relayout as indicated by the HLO instruction. - if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(), - hlo->shape())) { + if (!Layout::Equal().MinorToMajorOnly()( + GetEvaluatedLiteralFor(hlo).shape().layout(), + hlo->shape().layout())) { evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape()); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 9f28ea3255d..cdd02d2c48c 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -303,6 +303,10 @@ StatusOr> HloInstruction::CreateFromProto( TF_ASSIGN_OR_RETURN(auto literal, Literal::CreateFromProto(proto.literal())); instruction = CreateConstant(std::move(literal)); + // Literal's shape may have no/different tiling info. + CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(instruction->shape(), + shape)); + *instruction->mutable_shape() = shape; } else { instruction = absl::make_unique(shape); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 35b36acbcc4..9f9dde934dd 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1018,6 +1018,11 @@ HloConstantInstruction::HloConstantInstruction(Literal literal) : HloInstruction(HloOpcode::kConstant, literal.shape()), literal_(std::move(literal)) {} +HloConstantInstruction::HloConstantInstruction(Literal literal, + const Shape& shape) + : HloInstruction(HloOpcode::kConstant, shape), + literal_(std::move(literal)) {} + HloConstantInstruction::HloConstantInstruction(const Shape& shape) : HloInstruction(HloOpcode::kConstant, shape) {} @@ -1063,7 +1068,12 @@ HloConstantInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK(literal_.has_value()); - return absl::make_unique(literal_->Clone()); + // Literal's shape may have no/different tiling info. Use this instruction's + // shape instead. + CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(literal_->shape(), + this->shape())); + return absl::make_unique(literal_->Clone(), + this->shape()); } string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index aa0652f21e3..e6576923ff7 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -650,6 +650,7 @@ class HloSliceInstruction : public HloInstruction { class HloConstantInstruction : public HloInstruction { public: explicit HloConstantInstruction(Literal literal); + explicit HloConstantInstruction(Literal literal, const Shape& shape); // Used when the literal is too large and dropped. explicit HloConstantInstruction(const Shape& shape); // Returns the literal associated with this instruction. diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h index fdaac34386c..26f386ce3c6 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_interface.h +++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h @@ -56,6 +56,14 @@ class HloModulePass : public HloPassInterface { } return changed; }; + + // Update the layout of a Shape to one that is supported by a given backend. 
+ // One can call this function after modifying the Shape in case that modifying + // the Shape requires changes to the layout for the given Backend. + // + // TODO(b/129084868): Make this Backend dependent instead of requiring + // deriving from the pass the and overriding this function. + virtual void UpdateLayout(Shape* shape) {} }; // Base class for passes which are module-group scoped. These passes cannot run diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 375ae2c477d..bba3be4c2e6 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -343,7 +343,7 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) { // Check that the 'compare' computation returns a PRED. Shape compare_shape = compare->root_instruction()->shape(); - if (!ShapesSame(compare_shape, ShapeUtil::MakeShape(PRED, {}))) { + if (!ShapeUtil::Compatible(compare_shape, ShapeUtil::MakeShape(PRED, {}))) { return InternalError( "The Sort compare computation shape does not lead to a scalar " "predicate shape: %s", @@ -393,7 +393,8 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) { return InternalError("Constant is required to have a valid literal: %s", constant->ToString()); } - return CheckShape(constant, constant->literal().shape()); + return CheckShape(constant, constant->literal().shape(), + /*only_compare_minor_to_major_in_layout=*/true); } Status ShapeVerifier::HandleIota(HloInstruction* instruction) { @@ -654,7 +655,8 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0)); const Shape& conditional_shape = xla_while->while_condition()->root_instruction()->shape(); - if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) { + if (!ShapeUtil::Compatible(conditional_shape, + ShapeUtil::MakeShape(PRED, {}))) { return InternalError( "Conditional computation shape does not lead to a scalar predicate " "shape: %s", @@ -696,7 +698,8 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) { return CheckShape(send, ShapeUtil::MakeTupleShape({send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {}), - ShapeUtil::MakeTokenShape()})); + ShapeUtil::MakeTokenShape()}), + /*only_compare_minor_to_major_in_layout=*/true); } Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { @@ -705,9 +708,11 @@ Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { Status ShapeVerifier::HandleRecv(HloInstruction* recv) { return CheckShape( - recv, ShapeUtil::MakeTupleShape( - {ShapeUtil::GetTupleElementShape(recv->shape(), 0), - ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})); + recv, + ShapeUtil::MakeTupleShape( + {ShapeUtil::GetTupleElementShape(recv->shape(), 0), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}), + /*only_compare_minor_to_major_in_layout=*/true); } Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { @@ -844,7 +849,8 @@ Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) { } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, - const Shape& inferred_shape) { + const Shape& inferred_shape, + bool only_compare_minor_to_major_in_layout) { // If allow_mixed_precision_ is false, check if there are operands with // different precisions. We need this check because ShapeInference allows // mixed precision inputs. 
@@ -878,7 +884,8 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, case HloOpcode::kTuple: case HloOpcode::kTupleSelect: case HloOpcode::kWhile: - return ShapesSame(instruction->shape(), inferred_shape); + return ShapesSame(instruction->shape(), inferred_shape, + only_compare_minor_to_major_in_layout); // We allow arbitrary layout and f32->bf16 transformations on all other // instructions, although this may be made more strict pending discussion diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index d427a1586c3..a38ec5a05d4 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -106,7 +106,8 @@ class ShapeVerifier : public DfsHloVisitor { // Check the instruction's shape against the shape given by ShapeInference // and return an appropriate error if there is a mismatch. Status CheckShape(const HloInstruction* instruction, - const Shape& inferred_shape); + const Shape& inferred_shape, + bool only_compare_minor_to_major_in_layout = false); // Overload which takes a StatusOr to reduce boilerplate in the caller. Status CheckShape(const HloInstruction* instruction, @@ -120,14 +121,31 @@ class ShapeVerifier : public DfsHloVisitor { private: // Helpers that switch on layout_sensitive_. - bool ShapesSame(const Shape& a, const Shape& b) { - return layout_sensitive_ ? ShapeUtil::Equal(a, b) - : ShapeUtil::Compatible(a, b); + bool ShapesSame(const Shape& a, const Shape& b, + bool minor_to_major_only = false) { + if (!layout_sensitive_) { + return ShapeUtil::Compatible(a, b); + } + Shape::Equal equal; + if (minor_to_major_only) { + equal.MinorToMajorOnlyInLayout(); + } + return equal(a, b); } - bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) { - return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b) - : ShapeUtil::CompatibleIgnoringFpPrecision(a, b); + + bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b, + bool minor_to_major_only = false) { + if (!layout_sensitive_) { + return ShapeUtil::CompatibleIgnoringFpPrecision(a, b); + } + Shape::Equal equal; + if (minor_to_major_only) { + equal.MinorToMajorOnlyInLayout(); + } + equal.IgnoreFpPrecision(); + return equal(a, b); } + string StringifyShape(const Shape& s) { return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s) : ShapeUtil::HumanString(s); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 1755767928d..704f12b5e87 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -173,7 +173,7 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, auto iter = buffer_constraints_.find(&buffer); if (iter != buffer_constraints_.end()) { const BufferLayoutConstraint& curr_constraint = iter->second; - if (LayoutUtil::Equal(curr_constraint.layout(), layout)) { + if (Layout::Equal().MinorToMajorOnly()(curr_constraint.layout(), layout)) { // New constraint matches existing constraint. Nothing to do. return Status::OK(); } @@ -210,7 +210,7 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, GetOperandLayoutConstraint(instruction, operand_no); if (curr_shape_layout != nullptr) { if (curr_shape_layout->shape_layout().MatchesLayoutInShape( - shape_with_layout)) { + shape_with_layout, /*minor_to_major_only=*/true)) { // New constraint matches existing constraint. Nothing to do. 
return Status::OK(); } @@ -269,7 +269,8 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout, const ShapeLayout* curr_shape_layout = ResultLayout(); if (curr_shape_layout != nullptr) { - if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) { + if (!curr_shape_layout->MatchesLayoutInShape( + shape_with_layout, /*minor_to_major_only=*/true)) { return FailedPrecondition( "Result of computation %s already has the layout constraint %s, " "cannot add incompatible constraint %s", @@ -647,6 +648,10 @@ Status LayoutAssignment::AddMandatoryConstraints( namespace { +bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) { + return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout()); +} + // The operands of a call must match the layouts of parameters in the // ComputationLayout, and the call instruction itself must match the result // layout in the ComputationLayout. @@ -656,10 +661,10 @@ Status CheckCallLayout(HloInstruction* call, TF_RET_CHECK(computation->num_parameters() == call->operand_count()); for (int64 i = 0; i < computation->num_parameters(); ++i) { TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape( - call->operand(i)->shape())); + call->operand(i)->shape(), /*minor_to_major_only=*/true)); } - TF_RET_CHECK( - computation_layout.result_layout().MatchesLayoutInShape(call->shape())); + TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape( + call->shape(), /*minor_to_major_only=*/true)); return Status::OK(); } @@ -670,9 +675,9 @@ Status CheckCustomCallLayout(HloInstruction* instruction) { const HloCustomCallInstruction* custom_call = DynCast(instruction); for (int64 i = 0; i < custom_call->operand_count(); ++i) { - TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( - custom_call->operand(i)->shape(), - custom_call->operand_shapes_with_layout()[i])); + TF_RET_CHECK( + LayoutsInShapesEqual(custom_call->operand(i)->shape(), + custom_call->operand_shapes_with_layout()[i])); } } return Status::OK(); @@ -690,13 +695,12 @@ Status CheckWhileLayout(HloInstruction* while_inst, auto init_shape = while_inst->operand(0)->shape(); TF_RET_CHECK( condition_computation_layout.parameter_layout(0).MatchesLayoutInShape( - init_shape)); + init_shape, /*minor_to_major_only=*/true)); TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape( - init_shape)); - TF_RET_CHECK( - body_computation_layout.result_layout().MatchesLayoutInShape(init_shape)); - TF_RET_CHECK( - LayoutUtil::LayoutsInShapesEqual(init_shape, while_inst->shape())); + init_shape, /*minor_to_major_only=*/true)); + TF_RET_CHECK(body_computation_layout.result_layout().MatchesLayoutInShape( + init_shape, /*minor_to_major_only=*/true)); + TF_RET_CHECK(LayoutsInShapesEqual(init_shape, while_inst->shape())); return Status::OK(); } @@ -709,13 +713,14 @@ Status CheckConditionalLayout( branch_computation_layouts[j].result_layout()); TF_RET_CHECK( branch_computation_layouts[j].result_layout().MatchesLayoutInShape( - instruction->shape())); + instruction->shape(), /*minor_to_major_only=*/true)); TF_RET_CHECK( branch_computation_layouts[j].result_layout().MatchesLayoutInShape( - instruction->branch_computation(j)->root_instruction()->shape())); + instruction->branch_computation(j)->root_instruction()->shape(), + /*minor_to_major_only=*/true)); TF_RET_CHECK( branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape( - branch_operand->shape())); + branch_operand->shape(), /*minor_to_major_only=*/true)); } return Status::OK(); } @@ -726,11 +731,11 
@@ Status CheckConditionalLayout( Status CheckFusionLayout(HloInstruction* fusion) { TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode()); - TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( - fusion->shape(), fusion->fused_expression_root()->shape())); + TF_RET_CHECK(LayoutsInShapesEqual(fusion->shape(), + fusion->fused_expression_root()->shape())); for (int64 i = 0; i < fusion->operand_count(); ++i) { - TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( - fusion->fused_parameter(i)->shape(), fusion->operand(i)->shape())); + TF_RET_CHECK(LayoutsInShapesEqual(fusion->fused_parameter(i)->shape(), + fusion->operand(i)->shape())); } return Status::OK(); } @@ -742,7 +747,8 @@ Status CheckParameterLayout(HloInstruction* parameter, const ShapeLayout& parameter_layout = computation_layout.parameter_layout(parameter->parameter_number()); if (parameter_layout.LayoutIsSet() && - !parameter_layout.MatchesLayoutInShape(parameter->shape())) { + !parameter_layout.MatchesLayoutInShape(parameter->shape(), + /*minor_to_major_only=*/true)) { return InternalError( "parameter instruction %s does not match layout of computation " "shape: %s", @@ -753,8 +759,7 @@ Status CheckParameterLayout(HloInstruction* parameter, // The layout of a constant instruction must match the layout of its literal. Status CheckConstantLayout(HloInstruction* constant) { - if (!LayoutUtil::LayoutsInShapesEqual(constant->literal().shape(), - constant->shape())) { + if (!LayoutsInShapesEqual(constant->literal().shape(), constant->shape())) { return InternalError( "constant instruction %s does not match the layout of its literal %s", constant->ToString(), @@ -785,7 +790,8 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( HloInstruction* gte = instruction->parent()->AddInstruction( HloInstruction::CreateGetTupleElement(instr_shape, instruction, i)); - if (ShapeUtil::Equal(target_shape, instr_shape)) { + if (Shape::Equal().MinorToMajorOnlyInLayout()(target_shape, + instr_shape)) { // Shapes and layouts are equal, no need to copy. element_copies.push_back(gte); } else { @@ -831,7 +837,8 @@ Status LayoutAssignment::CopyOperandIfLayoutsDiffer( TF_RET_CHECK(operand_layout.LayoutIsSet()); TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); - if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { + if (Shape::Equal().MinorToMajorOnlyInLayout()(operand_layout.shape(), + operand->shape())) { VLOG(5) << "Operand " << operand->ToString() << " layout matches in " << instruction->ToString(); // Operand layout already matches our constraint. Nothing to do. 
@@ -892,7 +899,8 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { const Shape& instruction_subshape = ShapeUtil::GetSubshape(instruction->shape(), index); for (const LogicalBuffer* buffer : buffers) { - if (!ShapeUtil::Equal(instruction_subshape, buffer->shape())) { + if (!Shape::Equal().MinorToMajorOnlyInLayout()( + instruction_subshape, buffer->shape())) { return InternalError( "Layout of instruction %s at index {%s} does not match " "source LogicalBuffer %s: %s vs %s", @@ -954,8 +962,8 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { FindOrDie(computation_layouts_, module->entry_computation()) .result_layout(); if (result_layout.LayoutIsSet()) { - TF_RET_CHECK( - ShapeUtil::Equal(module->result_shape(), result_layout.shape())); + TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()( + module->result_shape(), result_layout.shape())); } return Status::OK(); } @@ -1510,8 +1518,8 @@ StatusOr InferArrayLayout( if (first_buffer_layout == nullptr) { first_buffer_layout = &source_buffer->shape().layout(); - } else if (!LayoutUtil::Equal(source_buffer->shape().layout(), - *first_buffer_layout)) { + } else if (!Layout::Equal().MinorToMajorOnly()( + source_buffer->shape().layout(), *first_buffer_layout)) { // The points-to set is ambiguous for this index and the different source // buffers have different layouts. This case is possible in valid XLA // computations because we do not propagate BufferLayoutConstraints to all @@ -1789,7 +1797,8 @@ Status LayoutAssignment::RunOnComputation( // layout constraint. if (constraints.ResultLayout() != nullptr && !constraints.ResultLayout()->MatchesLayoutInShape( - computation->root_instruction()->shape())) { + computation->root_instruction()->shape(), + /*minor_to_major_only=*/true)) { if (conditional_mismatch_.count(computation) > 0) { *FindOrDie(computation_layouts_, computation).mutable_result_layout() = FindOrDie(conditional_mismatch_, computation).result_layout(); @@ -1907,7 +1916,9 @@ Status LayoutAssignment::PropagateComputationLayouts( << ": " << computed_computation_layout.result_layout().ToString(); *result_layout = computed_computation_layout.result_layout(); } else { - TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout); + TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()( + computed_computation_layout.result_layout().shape(), + result_layout->shape())); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc index 94854047e53..27d24514f8f 100644 --- a/tensorflow/compiler/xla/shape.cc +++ b/tensorflow/compiler/xla/shape.cc @@ -141,6 +141,11 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) { } } + if (!ShapeUtil::SameDimensions(lhs, rhs)) { + VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; + return false; + } + if (!ignore_layout_) { if (lhs.layout().format() != rhs.layout().format()) { VLOG(3) << "CompareShapes: lhs layout format != rhs layout format"; @@ -161,11 +166,6 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) { } } - if (!ShapeUtil::SameDimensions(lhs, rhs)) { - VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; - return false; - } - if (!ignore_dynamic_dimension_) { for (int i = 0; i < lhs.rank(); ++i) { if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) { diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 78cea83c6d7..0b8530dd929 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -169,6 
+169,11 @@ class Shape { ignore_element_size_in_layout_ = true; return *this; } + Equal& MinorToMajorOnlyInLayout() { + ignore_tiles_in_layout_ = true; + ignore_element_size_in_layout_ = true; + return *this; + } Equal& IgnoreElementType() { ignore_element_type_ = true; return *this; diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index a000886d60d..44ed3181162 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -46,8 +46,13 @@ void ShapeLayout::SetToDefaultLayout() { LayoutUtil::SetToDefaultLayout(&shape_); } -bool ShapeLayout::MatchesLayoutInShape(const Shape& shape) const { - return ShapeUtil::Equal(shape, shape_); +bool ShapeLayout::MatchesLayoutInShape(const Shape& shape, + bool minor_to_major_only) const { + auto equal = Shape::Equal(); + if (minor_to_major_only) { + equal.MinorToMajorOnlyInLayout(); + } + return equal(shape, shape_); } const Layout& ShapeLayout::layout() const { diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index 214cf988549..b4982f1d8e4 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -45,7 +45,8 @@ class ShapeLayout { // Returns true if the Layouts in this ShapeLayout match the layouts in the // given shape. Returns false otherwise. If the given shape is not compatible // with the ShapeLayout's shape, then false is returned. - bool MatchesLayoutInShape(const Shape& shape) const; + bool MatchesLayoutInShape(const Shape& shape, + bool minor_to_major_only = false) const; // Copies the layout from the given shape into this ShapeLayout. 'other_shape' // must be compatible with the ShapeLayout's shape. diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index acaa9cae7c2..de3b58ff46c 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -102,6 +102,11 @@ StatusOr MakeShapeWithLayoutInternal( } TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::MakeValidatedShape(element_type, dimensions)); + if (element_size_in_bits == + ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) { + // Only set element_size_in_bits if it's different from the default value. + element_size_in_bits = 0; + } *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major, tiles, element_size_in_bits); if (!shape.has_layout()) { @@ -219,7 +224,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( for (int i = 0; i < shape.dimensions_size(); ++i) { dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } - return MakeShapeWithDescendingLayout(shape.element_type(), dims); + Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims); + // Since the physical layout is kept the same, the tiles and element size are + // the same also. 
+ *new_shape.mutable_layout()->mutable_tiles() = shape.layout().tiles(); + new_shape.mutable_layout()->set_element_size_in_bits( + shape.layout().element_size_in_bits()); + return new_shape; } /* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type, diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 414d0b14a6b..c4cb5aaaeb1 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -395,6 +395,8 @@ class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest { ParametricDotTestWithoutLayoutAssignment() { execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "layout-assignment"); + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "tiling-assignment"); // Disable algebraic simplification because the pass may replace a dot // instruction with a layout-changing multiplication instruction. execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 1111f824051..f9e21d4db2e 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -846,16 +846,16 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) { xla::ProgramShape xla_program_shape = XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes)); - EXPECT_TRUE(xla::LayoutUtil::Equal( + EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()( xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(), xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0}) .layout())); - EXPECT_TRUE(xla::LayoutUtil::Equal( + EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()( xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(), xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1}) .layout())); - EXPECT_TRUE(xla::LayoutUtil::Equal(program_shape.result().layout(), - xla_program_shape.result().layout())); + EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()( + program_shape.result().layout(), xla_program_shape.result().layout())); } TEST(RawApiTest, DotGeneralWithLayoutTest) {
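Note on the new comparison API introduced above: Layout::Equal::MinorToMajorOnly() and Shape::Equal::MinorToMajorOnlyInLayout() both reduce a layout comparison to the minor-to-major ordering by ignoring tiles and element size, so existing layout checks do not start failing merely because tiling information has been attached to a layout. A minimal usage sketch follows; it is illustrative only and not part of the patch (the helper names SameShapeIgnoringTiling and SameDimOrder are invented, and the include list is an assumption):

// Illustrative sketch, not part of the patch: how the new minor-to-major-only
// comparisons are meant to be called. Helper names are invented.
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/shape.h"

namespace xla {

// True if a and b agree on everything except tiles and element size in their
// layouts (the relaxed equality this patch threads through the passes).
inline bool SameShapeIgnoringTiling(const Shape& a, const Shape& b) {
  return Shape::Equal().MinorToMajorOnlyInLayout()(a, b);
}

// Layout-level variant, mirroring the checks added in layout_assignment.cc.
inline bool SameDimOrder(const Layout& a, const Layout& b) {
  return Layout::Equal().MinorToMajorOnly()(a, b);
}

}  // namespace xla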
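The other recurring change is the HloModulePass::UpdateLayout(Shape*) hook: wherever a pass now materializes or rewrites a shape it routes that shape through UpdateLayout (directly, or via CreateConstantWithLayoutUpdated), so a backend can re-establish tiling on freshly created instructions by deriving from the pass and overriding the hook, as the TODO above describes. Below is a hypothetical sketch of such an override; the subclass name and the default-layout policy are invented for illustration, and only UpdateLayout and CreateConstantWithLayoutUpdated come from the patch itself:

// Hypothetical backend-specific subclass of the simplifier; not part of the
// patch. Shows where a backend would plug in its tiling policy.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"

namespace xla {

class MyBackendAlgebraicSimplifier : public AlgebraicSimplifier {
 public:
  using AlgebraicSimplifier::AlgebraicSimplifier;

  // Called for every shape the simplifier creates or modifies, so the
  // backend's layout/tiling choice is applied in one place.
  void UpdateLayout(Shape* shape) override {
    if (!shape->IsArray()) {
      return;
    }
    // Placeholder policy: fall back to the default descending layout.
    // A real backend would attach its tile sizes here instead.
    LayoutUtil::SetToDefaultLayout(shape);
  }
};

}  // namespace xla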