From 74770749840e1c823a50b743a50637afc3529e3c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Sat, 29 Apr 2017 12:03:51 -0800
Subject: [PATCH] [XLA] Make ReshapeMover account for broadcast operands, add
 VLOGging for debug.

Change: 154637127
---
 tensorflow/compiler/xla/service/BUILD              |   2 +
 .../compiler/xla/service/hlo_instruction.cc        |  12 +-
 .../compiler/xla/service/hlo_instruction.h         |   5 +-
 .../compiler/xla/service/reshape_mover.cc          | 274 +++++++++++-------
 .../xla/service/reshape_mover_test.cc              |  52 ++++
 5 files changed, 228 insertions(+), 117 deletions(-)

diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 21378887266..05fc480936f 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -859,7 +859,9 @@ cc_library(
         ":hlo_pass",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:lib",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index d8e01f88b9f..179e1832654 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -410,7 +410,9 @@ HloInstruction::CreateSelectAndScatter(
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
     const Shape& shape, HloInstruction* operand) {
   CHECK_EQ(ShapeUtil::ElementsIn(shape),
-           ShapeUtil::ElementsIn(operand->shape()));
+           ShapeUtil::ElementsIn(operand->shape()))
+      << "shape: " << ShapeUtil::HumanString(shape)
+      << " operand: " << ShapeUtil::HumanString(operand->shape());
   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
   instruction->AppendOperand(operand);
   return instruction;
@@ -1428,7 +1430,8 @@ string HloInstruction::ExtendedOpcodeStr() const {
   return opc_name;
 }
 
-string HloInstruction::ToString(bool compact_operands) const {
+string HloInstruction::ToString(bool compact_operands,
+                                bool include_metadata) const {
   string operands;
   if (opcode() == HloOpcode::kConstant) {
     // For constants, show the actual value in place of an empty operand list.
@@ -1509,8 +1512,9 @@ string HloInstruction::ToString(bool compact_operands) const {
   if (opcode() == HloOpcode::kGetTupleElement) {
     StrAppend(&extra, ", index=", tuple_index());
   }
-  if (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
-      !metadata_.source_file().empty()) {
+  if (include_metadata &&
+      (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
+       !metadata_.source_file().empty())) {
     StrAppend(&extra, " # metadata=", metadata_.ShortDebugString());
   }
 
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 1c7d4c19b97..5ec17c80048 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -489,7 +489,10 @@ class HloInstruction {
   string SignatureString() const;
 
   // Returns a debugging string that represents this instruction.
-  string ToString(bool compact_operands = false) const;
+  string ToString(bool compact_operands = false,
+                  bool include_metadata = true) const;
+
+  string ToStringNoMetadata() const { return ToString(false, false); }
 
   // As ToString, but returns a shorter string.
   string ToShortString() const;
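A quick illustration of the hlo_instruction.h change above, since the rest of
the patch leans on it: ToString() keeps its old behavior by default, while
ToStringNoMetadata() produces the same text minus the trailing
"# metadata=..." suffix, which keeps the VLOG output added below readable.
The helper function here is a hypothetical sketch, not part of the patch:

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/platform/logging.h"

// Hypothetical helper (illustration only): contrasts the two variants.
void LogInstruction(const xla::HloInstruction* instr) {
  // Default call; appends " # metadata=..." when op metadata is present.
  VLOG(2) << instr->ToString();
  // Shorthand for ToString(/*compact_operands=*/false,
  // /*include_metadata=*/false): same text without the metadata suffix.
  VLOG(5) << instr->ToStringNoMetadata();
}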
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc
index b72ef95a6a7..768977ba6bb 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover.cc
@@ -13,17 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/reshape_mover.h"
-
-#include <algorithm>
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/util.h"
-
-namespace xla {
-
-namespace {
-
+// Implementation note:
+//
 // The general idea behind this pass is that we're converting from this:
 //   %param.A = OldShape
 //   %param.B = OldShape
@@ -44,6 +35,19 @@
 // only implicit scalar broadcast is on Pred, not on A or B. Since reshapes or
 // transposes to a scalar should be cheap, we simply never move them.
 
+#include "tensorflow/compiler/xla/service/reshape_mover.h"
+
+#include <algorithm>
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+
+namespace {
+
 // Finds the first non-scalar operand of an instruction that is a reshape or
 // transpose and returns the operand if it is found or nullptr if not found.
 HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) {
@@ -51,6 +55,9 @@ HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) {
     if (!ShapeUtil::IsScalar(operand->shape()) &&
         (operand->opcode() == HloOpcode::kReshape ||
          operand->opcode() == HloOpcode::kTranspose)) {
+      VLOG(5) << "Found first non-scalar reshape operand of "
+              << hlo->ToStringNoMetadata() << ":\n\t"
+              << operand->ToStringNoMetadata();
       return operand;
     }
   }
@@ -70,6 +77,9 @@ bool OperandCanTrivallyChangeShape(const HloInstruction* instruction,
   // A constant can trivially reshape the literal it holds.
   if (operand->opcode() == HloOpcode::kConstant &&
       ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) {
+    VLOG(5) << "Constant had same dimensions as instruction:\n\toperand: "
+            << operand->ToStringNoMetadata()
+            << "\n\tinstruction: " << instruction->ToStringNoMetadata();
     return true;
   }
 
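A note on surfacing the logging added in the two hunks above (and the larger
ones below): VLOG(3) and VLOG(5) messages are compiled in but silent by
default. In TensorFlow builds they are typically enabled at runtime via the
pre-existing TF_CPP_MIN_VLOG_LEVEL environment variable (setting it to 5
shows everything this patch logs); that is existing logging infrastructure,
not something introduced here.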
@@ -116,119 +126,159 @@ bool IsElementwiseOfEquivalentReshapesOrTransposes(
   if (!first_reshape_operand) {
     return false;
   }
-  return (instruction->user_count() > 0 ||
-          instruction == instruction->parent()->root_instruction()) &&
-         instruction->IsElementwise() && !operands.empty() &&
-         // Check whether all operands:
-         //    1. are all reshapes or transposes that have the same input and
-         //       output shapes as all other reshaped or transposed operands.
-         //    or
-         //    2. can be any shape like kConstant, kRng, and scalars.
-         std::all_of(
-             operands.begin(), operands.end(),
-             [instruction,
-              first_reshape_operand](const HloInstruction* operand) {
-               return AreEquivalentReshapes(first_reshape_operand, operand) ||
-                      OperandCanTrivallyChangeShape(instruction, operand);
-             });
+  VLOG(3) << "** Checking whether instruction is an elementwise operation of "
+             "equivalent reshapes/transposes: "
+          << instruction->ToStringNoMetadata();
+  bool result =
+      (instruction->user_count() > 0 ||
+       instruction == instruction->parent()->root_instruction()) &&
+      instruction->IsElementwise() && !operands.empty() &&
+      // Check whether all operands:
+      //    0. Have the same dimensions as the output -- if not, it may be
+      //       implicitly broadcast, which can confound the movement's
+      //       correctness.
+      //    1. Are all reshapes or transposes that have the same input and
+      //       output shapes as all other reshaped or transposed operands.
+      //    or
+      //    2. Can be any shape like kConstant, kRng, and scalars.
+      std::all_of(
+          operands.begin(), operands.end(),
+          [instruction, first_reshape_operand](const HloInstruction* operand) {
+            if (!ShapeUtil::SameDimensions(operand->shape(),
+                                           instruction->shape())) {
+              VLOG(5) << "Operand shape differs from output shape; may be "
+                         "implicitly broadcast, so preventing "
+                         "movement\n\toperand: "
+                      << operand->ToStringNoMetadata() << "\n\tinstruction: "
+                      << instruction->ToStringNoMetadata();
+              return false;
+            }
+            if (AreEquivalentReshapes(first_reshape_operand, operand)) {
+              VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: "
+                      << first_reshape_operand->ToStringNoMetadata()
+                      << "\n\toperand: " << operand->ToStringNoMetadata();
+              return true;
+            }
+            if (OperandCanTrivallyChangeShape(instruction, operand)) {
+              VLOG(5) << "Operand can trivially change shape: "
+                      << operand->ToStringNoMetadata();
+              return true;
+            }
+            return false;
+          });
+  VLOG(3) << "ElementwiseOfEquivalentReshapesOrTransposes result for "
+          << instruction->ToStringNoMetadata() << ": " << result;
+  return result;
 }
 
 // Try to sink any reshape or transpose operands of `instruction` across it. We
 // do so if `instruction` is elementwise and all operands are equivalent
 // reshapes or transposes.
-bool TrySinkReshapeOrTranspose(HloComputation* computation,
-                               HloInstruction* instruction) {
-  if (IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) {
-    std::vector<HloInstruction*> operands = instruction->operands();
-    HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction);
-    CHECK(old_reshape != nullptr);
-    Shape new_elementwise_shape = old_reshape->operand(0)->shape();
-    for (size_t i = 0; i < operands.size(); ++i) {
-      // All scalar operands remain as-is, even if they're reshape or transpose,
-      // to simplify handling wrt special scalar broadcast rules for ops like
-      // Select. Scalar reshapes should be cheap anyways.
-      if (ShapeUtil::IsScalar(operands[i]->shape())) {
-        continue;
-      }
-      auto element_type = operands[i]->shape().element_type();
-      switch (operands[i]->opcode()) {
-        case HloOpcode::kConstant: {
-          if (old_reshape->opcode() == HloOpcode::kReshape) {
-            operands[i] = instruction->parent()->AddInstruction(
-                HloInstruction::CreateReshape(
-                    ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                                 element_type),
-                    operands[i]));
-          } else {
-            CHECK_EQ(old_reshape->opcode(), HloOpcode::kTranspose);
-            std::vector<int64> inverse_permutation =
-                InversePermutation(old_reshape->dimensions());
-            operands[i] = instruction->parent()->AddInstruction(
-                HloInstruction::CreateTranspose(
-                    ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                                 element_type),
-                    operands[i], inverse_permutation));
-          }
-          break;
-        }
-        case HloOpcode::kRng: {
-          CHECK_EQ(operands[i]->user_count(), 1);
-          operands[i] = instruction->parent()->AddInstruction(
-              operands[i]->CloneWithNewOperands(
-                  ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                               element_type),
-                  operands[i]->operands()));
-          break;
-        }
-        case HloOpcode::kReshape:
-        case HloOpcode::kTranspose:
-          operands[i] = operands[i]->mutable_operand(0);
-          break;
-        default:
-          LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or "
-                        "transposes.";
-      }
-    }
-    if (HloOpcode::kFusion == instruction->opcode()) {
-      // Here we already know `instruction` is elementwise, and no operand is
-      // implicit broadcast as if it were the operands would not be equivalent
-      // reshapes, so all the fused instructions have the same dimensions.
-      for (const auto& fused_instruction : instruction->fused_instructions()) {
-        Shape* shape = fused_instruction->mutable_shape();
-        *shape->mutable_dimensions() = new_elementwise_shape.dimensions();
-        *shape->mutable_layout() = new_elementwise_shape.layout();
-      }
-    }
-    auto new_elementwise =
-        computation->AddInstruction(instruction->CloneWithNewOperands(
-            // `instruction` may change the element type, e.g., from
-            //   operands[0] -> reshape -> convert (`instruction`)
-            // to
-            //   operands[0] -> convert' -> reshape'
-            //
-            // In this case, convert' should have the same element type as
-            // `convert` and the same dimensions as operands[0].
-            ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                         instruction->shape().element_type()),
-            operands));
-    std::unique_ptr<HloInstruction> new_reshape;
-    switch (old_reshape->opcode()) {
-      case HloOpcode::kReshape:
-        new_reshape = HloInstruction::CreateReshape(instruction->shape(),
-                                                    new_elementwise);
-        break;
-      case HloOpcode::kTranspose:
-        new_reshape = HloInstruction::CreateTranspose(
-            instruction->shape(), new_elementwise, old_reshape->dimensions());
-        break;
-      default:
-        LOG(FATAL) << "Bad opcode";
-    }
-    TF_CHECK_OK(computation->ReplaceWithNewInstruction(instruction,
-                                                       std::move(new_reshape)));
-    return true;
-  }
-  return false;
-}
+StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
+                                         HloInstruction* instruction) {
+  if (!IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) {
+    return false;
+  }
+
+  std::vector<HloInstruction*> operands = instruction->operands();
+  HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction);
+  TF_RET_CHECK(old_reshape != nullptr);
+  Shape new_elementwise_shape = old_reshape->operand(0)->shape();
+
+  VLOG(3) << "** Trying to sink reshape or transpose: "
+          << instruction->ToStringNoMetadata()
+          << "\n\told reshape: " << old_reshape->ToStringNoMetadata()
+          << "\n\tnew elementwise shape: "
+          << ShapeUtil::HumanString(new_elementwise_shape);
+  for (size_t i = 0; i < operands.size(); ++i) {
+    // All scalar operands remain as-is, even if they're reshape or transpose,
+    // to simplify handling wrt special scalar broadcast rules for ops like
+    // Select. Scalar reshapes should be cheap anyways.
+    if (ShapeUtil::IsScalar(operands[i]->shape())) {
+      continue;
+    }
+    PrimitiveType element_type = operands[i]->shape().element_type();
+    switch (operands[i]->opcode()) {
+      case HloOpcode::kConstant: {
+        if (old_reshape->opcode() == HloOpcode::kReshape) {
+          VLOG(3) << "Creating reshape for kConstant operand " << i << ": "
+                  << operands[i]->ToStringNoMetadata();
+          operands[i] = instruction->parent()->AddInstruction(
+              HloInstruction::CreateReshape(
+                  ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                               element_type),
+                  operands[i]));
+        } else {
+          TF_RET_CHECK(old_reshape->opcode() == HloOpcode::kTranspose);
+          std::vector<int64> inverse_permutation =
+              InversePermutation(old_reshape->dimensions());
+          operands[i] = instruction->parent()->AddInstruction(
+              HloInstruction::CreateTranspose(
+                  ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                               element_type),
+                  operands[i], inverse_permutation));
+        }
+        break;
+      }
+      case HloOpcode::kRng: {
+        CHECK_EQ(operands[i]->user_count(), 1);
+        operands[i] = instruction->parent()->AddInstruction(
+            operands[i]->CloneWithNewOperands(
+                ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                             element_type),
+                operands[i]->operands()));
+        break;
+      }
+      case HloOpcode::kReshape:
+      case HloOpcode::kTranspose:
+        operands[i] = operands[i]->mutable_operand(0);
+        break;
+      default:
+        LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or "
+                      "transposes.";
+    }
+  }
+  if (HloOpcode::kFusion == instruction->opcode()) {
+    // Here we already know `instruction` is elementwise, and no operand is
+    // implicit broadcast as if it were the operands would not be equivalent
+    // reshapes, so all the fused instructions have the same dimensions.
+    for (const auto& fused_instruction : instruction->fused_instructions()) {
+      Shape* shape = fused_instruction->mutable_shape();
+      *shape->mutable_dimensions() = new_elementwise_shape.dimensions();
+      *shape->mutable_layout() = new_elementwise_shape.layout();
+    }
+  }
+  HloInstruction* new_elementwise =
+      computation->AddInstruction(instruction->CloneWithNewOperands(
+          // `instruction` may change the element type, e.g., from
+          //   operands[0] -> reshape -> convert (`instruction`)
+          // to
+          //   operands[0] -> convert' -> reshape'
+          //
+          // In this case, convert' should have the same element type as
+          // `convert` and the same dimensions as operands[0].
+          ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                       instruction->shape().element_type()),
+          operands));
+
+  std::unique_ptr<HloInstruction> new_reshape;
+  switch (old_reshape->opcode()) {
+    case HloOpcode::kReshape:
+      VLOG(3) << "Creating new reshape for new elementwise op: "
+              << new_elementwise->ToStringNoMetadata();
+      new_reshape =
+          HloInstruction::CreateReshape(instruction->shape(), new_elementwise);
+      break;
+    case HloOpcode::kTranspose:
+      new_reshape = HloInstruction::CreateTranspose(
+          instruction->shape(), new_elementwise, old_reshape->dimensions());
+      break;
+    default:
+      LOG(FATAL) << "Bad opcode";
+  }
+  TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+      instruction, std::move(new_reshape)));
+  return true;
+}
 
 }  // namespace
 
@@ -237,9 +287,9 @@ StatusOr<bool> ReshapeMover::Run(HloModule* module) {
   bool changed = false;
   for (const auto& comp : module->computations()) {
     for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
-      if (TrySinkReshapeOrTranspose(comp.get(), instruction)) {
-        changed = true;
-      }
+      TF_ASSIGN_OR_RETURN(bool did_change,
+                          TrySinkReshapeOrTranspose(comp.get(), instruction));
+      changed |= did_change;
     }
   }
   return changed;
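To make the rewritten TrySinkReshapeOrTranspose concrete before the
regression test below, here is a small before/after sketch in the notation of
the implementation note at the top of reshape_mover.cc; the shapes and names
are illustrative assumptions, not taken from the patch:

// Before the pass (both operands arrive through equivalent reshapes):
//   %param.A   = f32[8,7] parameter(0)
//   %param.B   = f32[8,7] parameter(1)
//   %reshape.A = f32[14,4] reshape(%param.A)
//   %reshape.B = f32[14,4] reshape(%param.B)
//   %add       = f32[14,4] add(%reshape.A, %reshape.B)
//
// After the pass (the elementwise op runs in the operands' original shape,
// and a single reshape is applied to the result):
//   %param.A = f32[8,7] parameter(0)
//   %param.B = f32[8,7] parameter(1)
//   %add     = f32[8,7] add(%param.A, %param.B)
//   %reshape = f32[14,4] reshape(%add)

The new SameDimensions guard refuses this rewrite whenever an operand's
dimensions differ from the output's (the implicit-broadcast case), since
reshaping such an operand would change its element count -- exactly the
unsoundness the test below pins down.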
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index 1831d775d4a..5217e85d4fc 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -234,6 +234,58 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) {
   EXPECT_EQ(select, computation->root_instruction());
 }
 
+// Tree looks like:
+//
+// param0 [1,128,1]
+//    |
+// reshape [128,1]      constant [128,1024]
+//     \                /
+//      multiply w/implicit broadcast [128,1024]
+//
+// The reshape mover would like to sink the reshape below the multiply.
+//
+// Previously we would attempt to insert a reshape of the constant to [1,128,1]
+// (which is unsound, because it has a different number of elements) as
+// preparation for sinking the reshape.
+//
+// To eliminate the unsoundness, we outlaw reshape sinking when one of the
+// operands is implicitly broadcast in the elementwise consumer.
+//
+// TODO(b/37799338) However, it would be possible in this case to do a more
+// in-depth analysis to get reshape movement to occur:
+//
+// 1. Note that the broadcast dimension (logical dimension 1) in the operands
+//    would map back to logical dimension 2 in the param0 node.
+// 2. Match rank of the constant to the param0 node (by prepending a trivial 1
+//    dimension).
+// 3. Reshape to [128,1024] at the root.
+//
+// But this is not currently done.
+TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) {
+  HloComputation::Builder builder(TestName());
+  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+      0, ShapeUtil::MakeShape(F32, {1, 128, 1}), "param0"));
+  auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
+      ShapeUtil::MakeShape(F32, {128, 1}), param0));
+  Array2D<float> a(128, 1024);
+  auto literal = LiteralUtil::CreateR2FromArray2D<float>(a);
+  auto constant = builder.AddInstruction(
+      HloInstruction::CreateConstant(std::move(literal)));
+  auto multiply = builder.AddInstruction(HloInstruction::CreateBinary(
+      constant->shape(), HloOpcode::kMultiply, constant, reshape));
+
+  auto module = MakeUnique<HloModule>(TestName());
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_THAT(computation->root_instruction(),
+              op::Multiply(op::Constant(), op::Reshape(param0)));
+
+  EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie());
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Multiply(op::Constant(), op::Reshape(param0)));
+  EXPECT_EQ(multiply, computation->root_instruction());
+}
+
 // Tree looks like this:
 //
 //         add1
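One step in the constant handling above that deserves spelling out: when the
sunk operation is a transpose with permutation p, the constant is
pre-transposed with InversePermutation(p), so that the transpose re-applied
after the elementwise op restores the original constant. A self-contained
sketch of the identity involved (plain C++, independent of XLA; the function
name is ours, not the library's):

#include <cstddef>
#include <iostream>
#include <vector>

// Inverse of a permutation p, defined by inv[p[i]] = i. Mirrors the property
// the patch relies on from xla::InversePermutation (sketch only).
std::vector<int> Inverse(const std::vector<int>& p) {
  std::vector<int> inv(p.size());
  for (std::size_t i = 0; i < p.size(); ++i) {
    inv[p[i]] = static_cast<int>(i);
  }
  return inv;
}

int main() {
  // A transpose with permutation {2, 0, 1} takes output dimension d from
  // input dimension p[d]. Composing with the inverse permutation yields the
  // identity (inv[p[d]] == d), which is why pre-transposing a constant with
  // the inverse cancels against the transpose sunk below the elementwise op.
  const std::vector<int> p = {2, 0, 1};
  const std::vector<int> inv = Inverse(p);  // {1, 2, 0}
  for (std::size_t d = 0; d < p.size(); ++d) {
    std::cout << inv[p[d]] << " ";  // prints: 0 1 2
  }
  std::cout << "\n";
  return 0;
}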