[XLA] Make ReshapeMover account for broadcast operands, add VLOGging for debug.

Change: 154637127
Authored by A. Unique TensorFlower on 2017-04-29 12:03:51 -08:00; committed by TensorFlower Gardener.
parent a25509eda3
commit 7477074984
5 changed files with 228 additions and 117 deletions
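Editor's note: at its core, the fix is one extra shape check. Before sinking a reshape or transpose, every operand of the elementwise instruction must have the same dimensions as the instruction's output; otherwise that operand may be implicitly broadcast, and reshaping it would be unsound (see the new test at the bottom of this page). A minimal sketch of that guard follows; the helper name is hypothetical — in the actual change the check is inlined in IsElementwiseOfEquivalentReshapesOrTransposes below.

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// Hypothetical standalone helper mirroring the new guard: an operand whose
// dimensions differ from the elementwise instruction's output may be
// implicitly broadcast, so reshape/transpose sinking bails out for that
// instruction.
bool OperandMayBeImplicitlyBroadcast(const HloInstruction* instruction,
                                     const HloInstruction* operand) {
  return !ShapeUtil::SameDimensions(operand->shape(), instruction->shape());
}

}  // namespace xla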

tensorflow/compiler/xla/service/BUILD

@@ -859,7 +859,9 @@ cc_library(
         ":hlo_pass",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:lib",
     ],
 )

tensorflow/compiler/xla/service/hlo_instruction.cc

@@ -410,7 +410,9 @@ HloInstruction::CreateSelectAndScatter(
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
     const Shape& shape, HloInstruction* operand) {
   CHECK_EQ(ShapeUtil::ElementsIn(shape),
-           ShapeUtil::ElementsIn(operand->shape()));
+           ShapeUtil::ElementsIn(operand->shape()))
+      << "shape: " << ShapeUtil::HumanString(shape)
+      << " operand: " << ShapeUtil::HumanString(operand->shape());
   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
   instruction->AppendOperand(operand);
   return instruction;
@@ -1428,7 +1430,8 @@ string HloInstruction::ExtendedOpcodeStr() const {
   return opc_name;
 }
 
-string HloInstruction::ToString(bool compact_operands) const {
+string HloInstruction::ToString(bool compact_operands,
+                                bool include_metadata) const {
   string operands;
   if (opcode() == HloOpcode::kConstant) {
     // For constants, show the actual value in place of an empty operand list.
@@ -1509,8 +1512,9 @@ string HloInstruction::ToString(bool compact_operands) const {
   if (opcode() == HloOpcode::kGetTupleElement) {
     StrAppend(&extra, ", index=", tuple_index());
   }
-  if (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
-      !metadata_.source_file().empty()) {
+  if (include_metadata &&
+      (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
+       !metadata_.source_file().empty())) {
     StrAppend(&extra, " # metadata=", metadata_.ShortDebugString());
   }

tensorflow/compiler/xla/service/hlo_instruction.h

@@ -489,7 +489,10 @@ class HloInstruction {
   string SignatureString() const;
 
   // Returns a debugging string that represents this instruction.
-  string ToString(bool compact_operands = false) const;
+  string ToString(bool compact_operands = false,
+                  bool include_metadata = true) const;
+
+  string ToStringNoMetadata() const { return ToString(false, false); }
 
   // As ToString, but returns a shorter string.
   string ToShortString() const;
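Editor's note: the new ToStringNoMetadata() wrapper is what the reshape-mover VLOG statements below call, so that verbose logs omit the trailing "# metadata=..." annotation. A typical call site would look roughly like this sketch (not taken verbatim from the diff):

// Sketch: log an instruction without its op_type/op_name/source_file
// metadata suffix. VLOG comes from tensorflow/core/platform/logging.h.
VLOG(5) << "Visiting instruction: " << instruction->ToStringNoMetadata();
// Equivalent to the explicit form:
VLOG(5) << instruction->ToString(/*compact_operands=*/false,
                                 /*include_metadata=*/false);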

tensorflow/compiler/xla/service/reshape_mover.cc

@@ -13,17 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/reshape_mover.h"
-
-#include <algorithm>
-
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/util.h"
-
-namespace xla {
-
-namespace {
-
 // Implementation note:
 //
 // The general idea behind this pass is that we're converting from this:
 //   %param.A = OldShape
 //   %param.B = OldShape
@@ -44,6 +35,19 @@ namespace {
 // only implicit scalar broadcast is on Pred, not on A or B. Since reshapes or
 // transposes to a scalar should be cheap, we simply never move them.
+
+#include "tensorflow/compiler/xla/service/reshape_mover.h"
+
+#include <algorithm>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+
+namespace {
+
 // Finds the first non-scalar operand of an instruction that is a reshape or
 // transpose and returns the operand if it is found or nullptr if not found.
 HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) {
@@ -51,6 +55,9 @@ HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) {
     if (!ShapeUtil::IsScalar(operand->shape()) &&
         (operand->opcode() == HloOpcode::kReshape ||
         operand->opcode() == HloOpcode::kTranspose)) {
+      VLOG(5) << "Found first non-scalar reshape operand of "
+              << hlo->ToStringNoMetadata() << ":\n\t"
+              << operand->ToStringNoMetadata();
      return operand;
    }
  }
@@ -70,6 +77,9 @@ bool OperandCanTrivallyChangeShape(const HloInstruction* instruction,
   // A constant can trivially reshape the literal it holds.
   if (operand->opcode() == HloOpcode::kConstant &&
       ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) {
+    VLOG(5) << "Constant had same dimensions as instruction:\n\toperand: "
+            << operand->ToStringNoMetadata()
+            << "\n\tinstruction: " << instruction->ToStringNoMetadata();
     return true;
   }
@@ -116,119 +126,159 @@ bool IsElementwiseOfEquivalentReshapesOrTransposes(
   if (!first_reshape_operand) {
     return false;
   }
-  return (instruction->user_count() > 0 ||
-          instruction == instruction->parent()->root_instruction()) &&
-         instruction->IsElementwise() && !operands.empty() &&
-         // Check whether all operands:
-         //    1. are all reshapes or transposes that have the same input and
-         //       output shapes as all other reshaped or transposed operands.
-         //     or
-         //    2. can be any shape like kConstant, kRng, and scalars.
-         std::all_of(
-             operands.begin(), operands.end(),
-             [instruction,
-              first_reshape_operand](const HloInstruction* operand) {
-               return AreEquivalentReshapes(first_reshape_operand, operand) ||
-                      OperandCanTrivallyChangeShape(instruction, operand);
-             });
+  VLOG(3) << "** Checking whether instruction is an elementwise operation of "
+             "equivalent reshapes/transposes: "
+          << instruction->ToStringNoMetadata();
+  bool result =
+      (instruction->user_count() > 0 ||
+       instruction == instruction->parent()->root_instruction()) &&
+      instruction->IsElementwise() && !operands.empty() &&
+      // Check whether all operands:
+      //    0. Have the same dimensions as the output -- if not, it may be
+      //       implicitly broadcast, which can confound the movement's
+      //       correctness.
+      //    1. Are all reshapes or transposes that have the same input and
+      //       output shapes as all other reshaped or transposed operands.
+      //     or
+      //    2. Can be any shape like kConstant, kRng, and scalars.
+      std::all_of(
+          operands.begin(), operands.end(),
+          [instruction, first_reshape_operand](const HloInstruction* operand) {
+            if (!ShapeUtil::SameDimensions(operand->shape(),
+                                           instruction->shape())) {
+              VLOG(5) << "Operand shape differs from output shape; may be "
+                         "implicitly broadcast, so preventing "
+                         "movement\n\toperand: "
+                      << operand->ToStringNoMetadata() << "\n\tinstruction: "
+                      << instruction->ToStringNoMetadata();
+              return false;
+            }
+            if (AreEquivalentReshapes(first_reshape_operand, operand)) {
+              VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: "
+                      << first_reshape_operand->ToStringNoMetadata()
+                      << "\n\toperand: " << operand->ToStringNoMetadata();
+              return true;
+            }
+            if (OperandCanTrivallyChangeShape(instruction, operand)) {
+              VLOG(5) << "Operand can trivially change shape: "
+                      << operand->ToStringNoMetadata();
+              return true;
+            }
+            return false;
+          });
+  VLOG(3) << "ElementwiseOfEquivalentReshapesOrTransposes result for "
+          << instruction->ToStringNoMetadata() << ": " << result;
+  return result;
 }
 
 // Try to sink any reshape or transpose operands of `instruction` across it. We
 // do so if `instruction` is elementwise and all operands are equivalent
 // reshapes or transposes.
-bool TrySinkReshapeOrTranspose(HloComputation* computation,
-                               HloInstruction* instruction) {
-  if (IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) {
-    std::vector<HloInstruction*> operands = instruction->operands();
-    HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction);
-    CHECK(old_reshape != nullptr);
-    Shape new_elementwise_shape = old_reshape->operand(0)->shape();
-    for (size_t i = 0; i < operands.size(); ++i) {
-      // All scalar operands remain as-is, even if they're reshape or transpose,
-      // to simplify handling wrt special scalar broadcast rules for ops like
-      // Select. Scalar reshapes should be cheap anyways.
-      if (ShapeUtil::IsScalar(operands[i]->shape())) {
-        continue;
-      }
-      auto element_type = operands[i]->shape().element_type();
-      switch (operands[i]->opcode()) {
-        case HloOpcode::kConstant: {
-          if (old_reshape->opcode() == HloOpcode::kReshape) {
-            operands[i] = instruction->parent()->AddInstruction(
-                HloInstruction::CreateReshape(
-                    ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                                 element_type),
-                    operands[i]));
-          } else {
-            CHECK_EQ(old_reshape->opcode(), HloOpcode::kTranspose);
-            std::vector<int64> inverse_permutation =
-                InversePermutation(old_reshape->dimensions());
-            operands[i] = instruction->parent()->AddInstruction(
-                HloInstruction::CreateTranspose(
-                    ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                                 element_type),
-                    operands[i], inverse_permutation));
-          }
-          break;
-        }
-        case HloOpcode::kRng: {
-          CHECK_EQ(operands[i]->user_count(), 1);
-          operands[i] = instruction->parent()->AddInstruction(
-              operands[i]->CloneWithNewOperands(
-                  ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                               element_type),
-                  operands[i]->operands()));
-          break;
-        }
-        case HloOpcode::kReshape:
-        case HloOpcode::kTranspose:
-          operands[i] = operands[i]->mutable_operand(0);
-          break;
-        default:
-          LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or "
-                        "transposes.";
-      }
-    }
-    if (HloOpcode::kFusion == instruction->opcode()) {
-      // Here we already know `instruction` is elementwise, and no operand is
-      // implicit broadcast as if it were the operands would not be equivalent
-      // reshapes, so all the fused instructions have the same dimensions.
-      for (const auto& fused_instruction : instruction->fused_instructions()) {
-        Shape* shape = fused_instruction->mutable_shape();
-        *shape->mutable_dimensions() = new_elementwise_shape.dimensions();
-        *shape->mutable_layout() = new_elementwise_shape.layout();
-      }
-    }
-    auto new_elementwise =
-        computation->AddInstruction(instruction->CloneWithNewOperands(
-            // `instruction` may change the element type, e.g., from
-            //   operands[0] -> reshape -> convert (`instruction`)
-            // to
-            //   operands[0] -> convert' -> reshape'
-            //
-            // In this case, convert' should have the same element type as
-            // `convert` and the same dimensions as operands[0].
-            ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                         instruction->shape().element_type()),
-            operands));
-    std::unique_ptr<HloInstruction> new_reshape;
-    switch (old_reshape->opcode()) {
-      case HloOpcode::kReshape:
-        new_reshape = HloInstruction::CreateReshape(instruction->shape(),
-                                                    new_elementwise);
-        break;
-      case HloOpcode::kTranspose:
-        new_reshape = HloInstruction::CreateTranspose(
-            instruction->shape(), new_elementwise, old_reshape->dimensions());
-        break;
-      default:
-        LOG(FATAL) << "Bad opcode";
-    }
-    TF_CHECK_OK(computation->ReplaceWithNewInstruction(instruction,
-                                                       std::move(new_reshape)));
-    return true;
-  }
-  return false;
+StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
+                                         HloInstruction* instruction) {
+  if (!IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) {
+    return false;
+  }
+  std::vector<HloInstruction*> operands = instruction->operands();
+  HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction);
+  TF_RET_CHECK(old_reshape != nullptr);
+  Shape new_elementwise_shape = old_reshape->operand(0)->shape();
+  VLOG(3) << "** Trying to sink reshape or transpose: "
+          << instruction->ToStringNoMetadata()
+          << "\n\told reshape: " << old_reshape->ToStringNoMetadata()
+          << "\n\tnew elementwise shape: "
+          << ShapeUtil::HumanString(new_elementwise_shape);
+  for (size_t i = 0; i < operands.size(); ++i) {
+    // All scalar operands remain as-is, even if they're reshape or transpose,
+    // to simplify handling wrt special scalar broadcast rules for ops like
+    // Select. Scalar reshapes should be cheap anyways.
+    if (ShapeUtil::IsScalar(operands[i]->shape())) {
+      continue;
+    }
+    PrimitiveType element_type = operands[i]->shape().element_type();
+    switch (operands[i]->opcode()) {
+      case HloOpcode::kConstant: {
+        if (old_reshape->opcode() == HloOpcode::kReshape) {
+          VLOG(3) << "Creating reshape for kConstant operand " << i << ": "
+                  << operands[i]->ToStringNoMetadata();
+          operands[i] = instruction->parent()->AddInstruction(
+              HloInstruction::CreateReshape(
+                  ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                               element_type),
+                  operands[i]));
+        } else {
+          TF_RET_CHECK(old_reshape->opcode() == HloOpcode::kTranspose);
+          std::vector<int64> inverse_permutation =
+              InversePermutation(old_reshape->dimensions());
+          operands[i] = instruction->parent()->AddInstruction(
+              HloInstruction::CreateTranspose(
+                  ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                               element_type),
+                  operands[i], inverse_permutation));
+        }
+        break;
+      }
+      case HloOpcode::kRng: {
+        CHECK_EQ(operands[i]->user_count(), 1);
+        operands[i] = instruction->parent()->AddInstruction(
+            operands[i]->CloneWithNewOperands(
+                ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                             element_type),
+                operands[i]->operands()));
+        break;
+      }
+      case HloOpcode::kReshape:
+      case HloOpcode::kTranspose:
+        operands[i] = operands[i]->mutable_operand(0);
+        break;
+      default:
+        LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or "
+                      "transposes.";
+    }
+  }
+  if (HloOpcode::kFusion == instruction->opcode()) {
+    // Here we already know `instruction` is elementwise, and no operand is
+    // implicit broadcast as if it were the operands would not be equivalent
+    // reshapes, so all the fused instructions have the same dimensions.
+    for (const auto& fused_instruction : instruction->fused_instructions()) {
+      Shape* shape = fused_instruction->mutable_shape();
+      *shape->mutable_dimensions() = new_elementwise_shape.dimensions();
+      *shape->mutable_layout() = new_elementwise_shape.layout();
+    }
+  }
+  HloInstruction* new_elementwise =
+      computation->AddInstruction(instruction->CloneWithNewOperands(
+          // `instruction` may change the element type, e.g., from
+          //   operands[0] -> reshape -> convert (`instruction`)
+          // to
+          //   operands[0] -> convert' -> reshape'
+          //
+          // In this case, convert' should have the same element type as
+          // `convert` and the same dimensions as operands[0].
+          ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                       instruction->shape().element_type()),
+          operands));
+  std::unique_ptr<HloInstruction> new_reshape;
+  switch (old_reshape->opcode()) {
+    case HloOpcode::kReshape:
+      VLOG(3) << "Creating new reshape for new elementwise op: "
+              << new_elementwise->ToStringNoMetadata();
+      new_reshape =
+          HloInstruction::CreateReshape(instruction->shape(), new_elementwise);
+      break;
+    case HloOpcode::kTranspose:
+      new_reshape = HloInstruction::CreateTranspose(
+          instruction->shape(), new_elementwise, old_reshape->dimensions());
+      break;
+    default:
+      LOG(FATAL) << "Bad opcode";
+  }
+  TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+      instruction, std::move(new_reshape)));
+  return true;
 }
 
 }  // namespace
@@ -237,9 +287,9 @@ StatusOr<bool> ReshapeMover::Run(HloModule* module) {
   bool changed = false;
   for (const auto& comp : module->computations()) {
     for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
-      if (TrySinkReshapeOrTranspose(comp.get(), instruction)) {
-        changed = true;
-      }
+      TF_ASSIGN_OR_RETURN(bool did_change,
+                          TrySinkReshapeOrTranspose(comp.get(), instruction));
+      changed |= did_change;
     }
   }
   return changed;
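Editor's note: since TrySinkReshapeOrTranspose now returns StatusOr<bool>, Run propagates errors with TF_ASSIGN_OR_RETURN instead of CHECK-failing. A hedged sketch of invoking the pass from error-returning code (the wrapper function and the unique_ptr<HloModule> named module are assumptions; the new test below instead unwraps the result with ValueOrDie()):

// Sketch: run the pass over an HloModule. Run() returns StatusOr<bool>,
// where the bool reports whether any reshape/transpose was sunk.
Status RunReshapeMover(const std::unique_ptr<HloModule>& module) {
  ReshapeMover reshape_mover;
  TF_ASSIGN_OR_RETURN(bool changed, reshape_mover.Run(module.get()));
  VLOG(2) << "ReshapeMover changed module: " << changed;
  return Status::OK();
}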

tensorflow/compiler/xla/service/reshape_mover_test.cc

@@ -234,6 +234,58 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) {
   EXPECT_EQ(select, computation->root_instruction());
 }
 
+// Tree looks like:
+//
+//   param0 [1,128,1]
+//      |
+//   reshape [128,1]   constant [128,1024]
+//       \              /
+//        multiply w/implicit broadcast [128,1024]
+//
+// The reshape mover would like to sink the reshape below the multiply.
+//
+// Previously we would attempt to insert a reshape of the constant to [1,128,1]
+// (which is unsound, because it has a different number of elements) as
+// preparation for sinking the reshape.
+//
+// To eliminate the unsoundness, we outlaw reshape sinking when one of the
+// operands is implicitly broadcast in the elementwise consumer.
+//
+// TODO(b/37799338) However, it would be possible in this case to do a more
+// in-depth analysis to get reshape movement to occur:
+//
+// 1. Note that the broadcast dimension (logical dimension 1) in the operands
+//    would map back to logical dimension 2 in the param0 node.
+// 2. Match rank of the constant to the param0 node (by prepending a trivial 1
+//    dimension).
+// 3. Reshape to [128,1024] at the root.
+//
+// But this is not currently done.
+TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) {
+  HloComputation::Builder builder(TestName());
+  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+      0, ShapeUtil::MakeShape(F32, {1, 128, 1}), "param0"));
+  auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
+      ShapeUtil::MakeShape(F32, {128, 1}), param0));
+  Array2D<float> a(128, 1024);
+  auto literal = LiteralUtil::CreateR2FromArray2D<float>(a);
+  auto constant = builder.AddInstruction(
+      HloInstruction::CreateConstant(std::move(literal)));
+  auto multiply = builder.AddInstruction(HloInstruction::CreateBinary(
+      constant->shape(), HloOpcode::kMultiply, constant, reshape));
+
+  auto module = MakeUnique<HloModule>(TestName());
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_THAT(computation->root_instruction(),
+              op::Multiply(op::Constant(), op::Reshape(param0)));
+
+  EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie());
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Multiply(op::Constant(), op::Reshape(param0)));
+  EXPECT_EQ(multiply, computation->root_instruction());
+}
+
 // Tree looks like this:
 //
 //   add1