[XLA] Make ReshapeMover account for broadcast operands, add VLOGging for debug.

Change: 154637127
A. Unique TensorFlower 2017-04-29 12:03:51 -08:00 committed by TensorFlower Gardener
parent a25509eda3
commit 7477074984
5 changed files with 228 additions and 117 deletions
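In brief: ReshapeMover sinks reshape and transpose operations below elementwise ops when every operand is an equivalent reshape/transpose (or can trivially change shape). This change adds a guard for implicitly broadcast operands: if an operand's dimensions differ from the elementwise op's output shape (checked with ShapeUtil::SameDimensions), the movement is skipped, since reshaping such an operand back to the pre-reshape shape can require a different element count. It also converts TrySinkReshapeOrTranspose to return StatusOr<bool>, adds HloInstruction::ToStringNoMetadata(), and adds VLOG(3)/VLOG(5) output for debugging. A rough sketch of the pattern that is now left alone, written in the same informal notation the pass's own comments use (shapes taken from the new regression test below; illustrative, not verbatim HLO):

    %param0   = f32[1,128,1]  parameter
    %reshape  = f32[128,1]    reshape(%param0)
    %constant = f32[128,1024] constant(...)
    %multiply = f32[128,1024] multiply(%constant, %reshape)

Here %reshape is implicitly broadcast from [128,1] to [128,1024], so sinking its reshape below the multiply would require reshaping %constant to [1,128,1], which has a different number of elements and is therefore unsound.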

tensorflow/compiler/xla/service/BUILD (View File)

@@ -859,7 +859,9 @@ cc_library(
         ":hlo_pass",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:lib",
     ],
 )

tensorflow/compiler/xla/service/hlo_instruction.cc (View File)

@@ -410,7 +410,9 @@ HloInstruction::CreateSelectAndScatter(
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
     const Shape& shape, HloInstruction* operand) {
   CHECK_EQ(ShapeUtil::ElementsIn(shape),
-           ShapeUtil::ElementsIn(operand->shape()));
+           ShapeUtil::ElementsIn(operand->shape()))
+      << "shape: " << ShapeUtil::HumanString(shape)
+      << " operand: " << ShapeUtil::HumanString(operand->shape());
   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
   instruction->AppendOperand(operand);
   return instruction;
@@ -1428,7 +1430,8 @@ string HloInstruction::ExtendedOpcodeStr() const {
   return opc_name;
 }
 
-string HloInstruction::ToString(bool compact_operands) const {
+string HloInstruction::ToString(bool compact_operands,
+                                bool include_metadata) const {
   string operands;
   if (opcode() == HloOpcode::kConstant) {
     // For constants, show the actual value in place of an empty operand list.
@@ -1509,8 +1512,9 @@ string HloInstruction::ToString(bool compact_operands) const {
   if (opcode() == HloOpcode::kGetTupleElement) {
     StrAppend(&extra, ", index=", tuple_index());
   }
-  if (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
-      !metadata_.source_file().empty()) {
+  if (include_metadata &&
+      (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
+       !metadata_.source_file().empty())) {
     StrAppend(&extra, " # metadata=", metadata_.ShortDebugString());
   }

tensorflow/compiler/xla/service/hlo_instruction.h (View File)

@@ -489,7 +489,10 @@ class HloInstruction {
   string SignatureString() const;
 
   // Returns a debugging string that represents this instruction.
-  string ToString(bool compact_operands = false) const;
+  string ToString(bool compact_operands = false,
+                  bool include_metadata = true) const;
+
+  string ToStringNoMetadata() const { return ToString(false, false); }
 
   // As ToString, but returns a shorter string.
   string ToShortString() const;

tensorflow/compiler/xla/service/reshape_mover.cc (View File)

@@ -13,17 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/reshape_mover.h"
-
-#include <algorithm>
-
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/util.h"
-
-namespace xla {
-
-namespace {
-
+// Implementation note:
+//
 // The general idea behind this pass is that we're converting from this:
 //   %param.A = OldShape
 //   %param.B = OldShape
@@ -44,6 +35,19 @@ namespace {
 // only implicit scalar broadcast is on Pred, not on A or B. Since reshapes or
 // transposes to a scalar should be cheap, we simply never move them.
 
+#include "tensorflow/compiler/xla/service/reshape_mover.h"
+
+#include <algorithm>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+
+namespace {
+
 // Finds the first non-scalar operand of an instruction that is a reshape or
 // transpose and returns the operand if it is found or nullptr if not found.
 HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) {
@@ -51,6 +55,9 @@ HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) {
     if (!ShapeUtil::IsScalar(operand->shape()) &&
         (operand->opcode() == HloOpcode::kReshape ||
         operand->opcode() == HloOpcode::kTranspose)) {
+      VLOG(5) << "Found first non-scalar reshape operand of "
+              << hlo->ToStringNoMetadata() << ":\n\t"
+              << operand->ToStringNoMetadata();
       return operand;
     }
   }
@@ -70,6 +77,9 @@ bool OperandCanTrivallyChangeShape(const HloInstruction* instruction,
   // A constant can trivially reshape the literal it holds.
   if (operand->opcode() == HloOpcode::kConstant &&
       ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) {
+    VLOG(5) << "Constant had same dimensions as instruction:\n\toperand: "
+            << operand->ToStringNoMetadata()
+            << "\n\tinstruction: " << instruction->ToStringNoMetadata();
     return true;
   }
@@ -116,119 +126,159 @@ bool IsElementwiseOfEquivalentReshapesOrTransposes(
   if (!first_reshape_operand) {
     return false;
   }
-  return (instruction->user_count() > 0 ||
-          instruction == instruction->parent()->root_instruction()) &&
-         instruction->IsElementwise() && !operands.empty() &&
-         // Check whether all operands:
-         //    1. are all reshapes or transposes that have the same input and
-         //       output shapes as all other reshaped or transposed operands.
-         //     or
-         //    2. can be any shape like kConstant, kRng, and scalars.
-         std::all_of(
-             operands.begin(), operands.end(),
-             [instruction,
-              first_reshape_operand](const HloInstruction* operand) {
-               return AreEquivalentReshapes(first_reshape_operand, operand) ||
-                      OperandCanTrivallyChangeShape(instruction, operand);
-             });
+  VLOG(3) << "** Checking whether instruction is an elementwise operation of "
+             "equivalent reshapes/transposes: "
+          << instruction->ToStringNoMetadata();
+  bool result =
+      (instruction->user_count() > 0 ||
+       instruction == instruction->parent()->root_instruction()) &&
+      instruction->IsElementwise() && !operands.empty() &&
+      // Check whether all operands:
+      //    0. Have the same dimensions as the output -- if not, it may be
+      //       implicitly broadcast, which can confound the movement's
+      //       correctness.
+      //    1. Are all reshapes or transposes that have the same input and
+      //       output shapes as all other reshaped or transposed operands.
+      //     or
+      //    2. Can be any shape like kConstant, kRng, and scalars.
+      std::all_of(
+          operands.begin(), operands.end(),
+          [instruction, first_reshape_operand](const HloInstruction* operand) {
+            if (!ShapeUtil::SameDimensions(operand->shape(),
+                                           instruction->shape())) {
+              VLOG(5) << "Operand shape differs from output shape; may be "
+                         "implicitly broadcast, so preventing "
+                         "movement\n\toperand: "
+                      << operand->ToStringNoMetadata() << "\n\tinstruction: "
+                      << instruction->ToStringNoMetadata();
+              return false;
+            }
+            if (AreEquivalentReshapes(first_reshape_operand, operand)) {
+              VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: "
+                      << first_reshape_operand->ToStringNoMetadata()
+                      << "\n\toperand: " << operand->ToStringNoMetadata();
+              return true;
+            }
+            if (OperandCanTrivallyChangeShape(instruction, operand)) {
+              VLOG(5) << "Operand can trivially change shape: "
+                      << operand->ToStringNoMetadata();
+              return true;
+            }
+            return false;
+          });
+  VLOG(3) << "ElementwiseOfEquivalentReshapesOrTransposes result for "
+          << instruction->ToStringNoMetadata() << ": " << result;
+  return result;
 }
 
 // Try to sink any reshape or transpose operands of `instruction` across it. We
 // do so if `instruction` is elementwise and all operands are equivalent
 // reshapes or transposes.
-bool TrySinkReshapeOrTranspose(HloComputation* computation,
-                               HloInstruction* instruction) {
-  if (IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) {
-    std::vector<HloInstruction*> operands = instruction->operands();
-    HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction);
-    CHECK(old_reshape != nullptr);
-    Shape new_elementwise_shape = old_reshape->operand(0)->shape();
-    for (size_t i = 0; i < operands.size(); ++i) {
-      // All scalar operands remain as-is, even if they're reshape or transpose,
-      // to simplify handling wrt special scalar broadcast rules for ops like
-      // Select. Scalar reshapes should be cheap anyways.
-      if (ShapeUtil::IsScalar(operands[i]->shape())) {
-        continue;
-      }
-      auto element_type = operands[i]->shape().element_type();
-      switch (operands[i]->opcode()) {
-        case HloOpcode::kConstant: {
-          if (old_reshape->opcode() == HloOpcode::kReshape) {
-            operands[i] = instruction->parent()->AddInstruction(
-                HloInstruction::CreateReshape(
-                    ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                                 element_type),
-                    operands[i]));
-          } else {
-            CHECK_EQ(old_reshape->opcode(), HloOpcode::kTranspose);
-            std::vector<int64> inverse_permutation =
-                InversePermutation(old_reshape->dimensions());
-            operands[i] = instruction->parent()->AddInstruction(
-                HloInstruction::CreateTranspose(
-                    ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                                 element_type),
-                    operands[i], inverse_permutation));
-          }
-          break;
-        }
-        case HloOpcode::kRng: {
-          CHECK_EQ(operands[i]->user_count(), 1);
-          operands[i] = instruction->parent()->AddInstruction(
-              operands[i]->CloneWithNewOperands(
-                  ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                               element_type),
-                  operands[i]->operands()));
-          break;
-        }
-        case HloOpcode::kReshape:
-        case HloOpcode::kTranspose:
-          operands[i] = operands[i]->mutable_operand(0);
-          break;
-        default:
-          LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or "
-                        "transposes.";
-      }
-    }
-    if (HloOpcode::kFusion == instruction->opcode()) {
-      // Here we already know `instruction` is elementwise, and no operand is
-      // implicit broadcast as if it were the operands would not be equivalent
-      // reshapes, so all the fused instructions have the same dimensions.
-      for (const auto& fused_instruction : instruction->fused_instructions()) {
-        Shape* shape = fused_instruction->mutable_shape();
-        *shape->mutable_dimensions() = new_elementwise_shape.dimensions();
-        *shape->mutable_layout() = new_elementwise_shape.layout();
-      }
-    }
-    auto new_elementwise =
-        computation->AddInstruction(instruction->CloneWithNewOperands(
-            // `instruction` may change the element type, e.g., from
-            //   operands[0] -> reshape -> convert (`instruction`)
-            // to
-            //   operands[0] -> convert' -> reshape'
-            //
-            // In this case, convert' should have the same element type as
-            // `convert` and the same dimensions as operands[0].
-            ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                         instruction->shape().element_type()),
-            operands));
-    std::unique_ptr<HloInstruction> new_reshape;
-    switch (old_reshape->opcode()) {
-      case HloOpcode::kReshape:
-        new_reshape = HloInstruction::CreateReshape(instruction->shape(),
-                                                    new_elementwise);
-        break;
-      case HloOpcode::kTranspose:
-        new_reshape = HloInstruction::CreateTranspose(
-            instruction->shape(), new_elementwise, old_reshape->dimensions());
-        break;
-      default:
-        LOG(FATAL) << "Bad opcode";
-    }
-    TF_CHECK_OK(computation->ReplaceWithNewInstruction(instruction,
-                                                       std::move(new_reshape)));
-    return true;
-  }
-  return false;
+StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
+                                         HloInstruction* instruction) {
+  if (!IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) {
+    return false;
+  }
+
+  std::vector<HloInstruction*> operands = instruction->operands();
+  HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction);
+  TF_RET_CHECK(old_reshape != nullptr);
+  Shape new_elementwise_shape = old_reshape->operand(0)->shape();
+
+  VLOG(3) << "** Trying to sink reshape or transpose: "
+          << instruction->ToStringNoMetadata()
+          << "\n\told reshape: " << old_reshape->ToStringNoMetadata()
+          << "\n\tnew elementwise shape: "
+          << ShapeUtil::HumanString(new_elementwise_shape);
+  for (size_t i = 0; i < operands.size(); ++i) {
+    // All scalar operands remain as-is, even if they're reshape or transpose,
+    // to simplify handling wrt special scalar broadcast rules for ops like
+    // Select. Scalar reshapes should be cheap anyways.
+    if (ShapeUtil::IsScalar(operands[i]->shape())) {
+      continue;
+    }
+    PrimitiveType element_type = operands[i]->shape().element_type();
+    switch (operands[i]->opcode()) {
+      case HloOpcode::kConstant: {
+        if (old_reshape->opcode() == HloOpcode::kReshape) {
+          VLOG(3) << "Creating reshape for kConstant operand " << i << ": "
+                  << operands[i]->ToStringNoMetadata();
+          operands[i] = instruction->parent()->AddInstruction(
+              HloInstruction::CreateReshape(
+                  ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                               element_type),
+                  operands[i]));
+        } else {
+          TF_RET_CHECK(old_reshape->opcode() == HloOpcode::kTranspose);
+          std::vector<int64> inverse_permutation =
+              InversePermutation(old_reshape->dimensions());
+          operands[i] = instruction->parent()->AddInstruction(
+              HloInstruction::CreateTranspose(
+                  ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                               element_type),
+                  operands[i], inverse_permutation));
+        }
+        break;
+      }
+      case HloOpcode::kRng: {
+        CHECK_EQ(operands[i]->user_count(), 1);
+        operands[i] = instruction->parent()->AddInstruction(
+            operands[i]->CloneWithNewOperands(
+                ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                             element_type),
+                operands[i]->operands()));
+        break;
+      }
+      case HloOpcode::kReshape:
+      case HloOpcode::kTranspose:
+        operands[i] = operands[i]->mutable_operand(0);
+        break;
+      default:
+        LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or "
+                      "transposes.";
+    }
+  }
+  if (HloOpcode::kFusion == instruction->opcode()) {
+    // Here we already know `instruction` is elementwise, and no operand is
+    // implicit broadcast as if it were the operands would not be equivalent
+    // reshapes, so all the fused instructions have the same dimensions.
+    for (const auto& fused_instruction : instruction->fused_instructions()) {
+      Shape* shape = fused_instruction->mutable_shape();
+      *shape->mutable_dimensions() = new_elementwise_shape.dimensions();
+      *shape->mutable_layout() = new_elementwise_shape.layout();
+    }
+  }
+  HloInstruction* new_elementwise =
+      computation->AddInstruction(instruction->CloneWithNewOperands(
+          // `instruction` may change the element type, e.g., from
+          //   operands[0] -> reshape -> convert (`instruction`)
+          // to
+          //   operands[0] -> convert' -> reshape'
+          //
+          // In this case, convert' should have the same element type as
+          // `convert` and the same dimensions as operands[0].
+          ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                       instruction->shape().element_type()),
+          operands));
+  std::unique_ptr<HloInstruction> new_reshape;
+  switch (old_reshape->opcode()) {
+    case HloOpcode::kReshape:
+      VLOG(3) << "Creating new reshape for new elementwise op: "
+              << new_elementwise->ToStringNoMetadata();
+      new_reshape =
+          HloInstruction::CreateReshape(instruction->shape(), new_elementwise);
+      break;
+    case HloOpcode::kTranspose:
+      new_reshape = HloInstruction::CreateTranspose(
+          instruction->shape(), new_elementwise, old_reshape->dimensions());
+      break;
+    default:
+      LOG(FATAL) << "Bad opcode";
+  }
+  TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+      instruction, std::move(new_reshape)));
+  return true;
 }
 
 }  // namespace
@@ -237,9 +287,9 @@ StatusOr<bool> ReshapeMover::Run(HloModule* module) {
   bool changed = false;
   for (const auto& comp : module->computations()) {
    for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
-      if (TrySinkReshapeOrTranspose(comp.get(), instruction)) {
-        changed = true;
-      }
+      TF_ASSIGN_OR_RETURN(bool did_change,
+                          TrySinkReshapeOrTranspose(comp.get(), instruction));
+      changed |= did_change;
     }
   }
   return changed;
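Because TrySinkReshapeOrTranspose now returns StatusOr<bool>, failures propagate out of ReshapeMover::Run instead of CHECK-failing inside the pass. Callers keep the same shape as before; a minimal sketch of driving the pass (mirroring how the unit tests below invoke it; `module` is assumed to be a std::unique_ptr<HloModule> you already hold):

    ReshapeMover pass;
    // In Status-returning code: TF_ASSIGN_OR_RETURN(bool changed, pass.Run(module.get()));
    bool changed = pass.Run(module.get()).ValueOrDie();  // as the tests do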

tensorflow/compiler/xla/service/reshape_mover_test.cc (View File)

@@ -234,6 +234,58 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) {
   EXPECT_EQ(select, computation->root_instruction());
 }
 
+// Tree looks like:
+//
+// param0 [1,128,1]
+//  |
+// reshape [128,1]      constant [128,1024]
+//    \                 /
+//     multiply w/implicit broadcast [128,1024]
+//
+// The reshape mover would like to sink the reshape below the multiply.
+//
+// Previously we would attempt to insert a reshape of the constant to [1,128,1]
+// (which is unsound, because it has a different number of elements) as
+// preparation for sinking the reshape.
+//
+// To eliminate the unsoundness, we outlaw reshape sinking when one of the
+// operands is implicitly broadcast in the elementwise consumer.
+//
+// TODO(b/37799338) However, it would be possible in this case to do a more
+// in-depth analysis to get reshape movement to occur:
+//
+// 1. Note that the broadcast dimension (logical dimension 1) in the operands
+//    would map back to logical dimension 2 in the param0 node.
+// 2. Match rank of the constant to the param0 node (by prepending a trivial 1
+//    dimension).
+// 3. Reshape to [128,1024] at the root.
+//
+// But this is not currently done.
+TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) {
+  HloComputation::Builder builder(TestName());
+  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+      0, ShapeUtil::MakeShape(F32, {1, 128, 1}), "param0"));
+  auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
+      ShapeUtil::MakeShape(F32, {128, 1}), param0));
+  Array2D<float> a(128, 1024);
+  auto literal = LiteralUtil::CreateR2FromArray2D<float>(a);
+  auto constant = builder.AddInstruction(
+      HloInstruction::CreateConstant(std::move(literal)));
+  auto multiply = builder.AddInstruction(HloInstruction::CreateBinary(
+      constant->shape(), HloOpcode::kMultiply, constant, reshape));
+
+  auto module = MakeUnique<HloModule>(TestName());
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_THAT(computation->root_instruction(),
+              op::Multiply(op::Constant(), op::Reshape(param0)));
+
+  EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie());
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Multiply(op::Constant(), op::Reshape(param0)));
+  EXPECT_EQ(multiply, computation->root_instruction());
+}
+
 // Tree looks like this:
 //
 //      add1