[XLA] Make ReshapeMover account for broadcast operands, add VLOGging for debug.
Change: 154637127
commit 7477074984
parent a25509eda3
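
This change prevents an unsound rewrite: ReshapeMover could previously sink a
reshape across an elementwise op even when another operand of that op was
implicitly broadcast, and would prepare for the sink by reshaping that operand
to a shape with a different element count (see the new
ImplicitlyBroadcastReshapeIsNotMovedBug37787999 regression test). Sinking is
now rejected whenever an operand's dimensions differ from the instruction's
output dimensions. The change also converts TrySinkReshapeOrTranspose to
return StatusOr<bool>, adds VLOG(3)/VLOG(5) tracing throughout the pass, and
introduces HloInstruction::ToStringNoMetadata() to keep trace output compact.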
tensorflow/compiler/xla/service/BUILD
@@ -859,7 +859,9 @@ cc_library(
         ":hlo_pass",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:lib",
     ],
 )
tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -410,7 +410,9 @@ HloInstruction::CreateSelectAndScatter(
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
     const Shape& shape, HloInstruction* operand) {
   CHECK_EQ(ShapeUtil::ElementsIn(shape),
-           ShapeUtil::ElementsIn(operand->shape()));
+           ShapeUtil::ElementsIn(operand->shape()))
+      << "shape: " << ShapeUtil::HumanString(shape)
+      << " operand: " << ShapeUtil::HumanString(operand->shape());
   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
   instruction->AppendOperand(operand);
   return instruction;
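
As an aside for readers of the hunk above: values streamed into CHECK_EQ are
printed only when the check fails. A minimal self-contained sketch of the
pattern (the values here are hypothetical, not from the commit):

    #include "tensorflow/core/platform/logging.h"

    void CheckEqSketch() {
      int expected_elements = 4;
      int actual_elements = 5;
      // On mismatch, aborts and prints the streamed text, mirroring how
      // CreateReshape above reports the two offending shapes.
      CHECK_EQ(expected_elements, actual_elements)
          << "expected: " << expected_elements
          << " actual: " << actual_elements;
    }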
@@ -1428,7 +1430,8 @@ string HloInstruction::ExtendedOpcodeStr() const {
   return opc_name;
 }
 
-string HloInstruction::ToString(bool compact_operands) const {
+string HloInstruction::ToString(bool compact_operands,
+                                bool include_metadata) const {
   string operands;
   if (opcode() == HloOpcode::kConstant) {
     // For constants, show the actual value in place of an empty operand list.
@@ -1509,8 +1512,9 @@ string HloInstruction::ToString(bool compact_operands) const {
   if (opcode() == HloOpcode::kGetTupleElement) {
     StrAppend(&extra, ", index=", tuple_index());
   }
-  if (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
-      !metadata_.source_file().empty()) {
+  if (include_metadata &&
+      (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
+       !metadata_.source_file().empty())) {
     StrAppend(&extra, " # metadata=", metadata_.ShortDebugString());
   }
 
tensorflow/compiler/xla/service/hlo_instruction.h
@@ -489,7 +489,10 @@ class HloInstruction {
   string SignatureString() const;
 
   // Returns a debugging string that represents this instruction.
-  string ToString(bool compact_operands = false) const;
+  string ToString(bool compact_operands = false,
+                  bool include_metadata = true) const;
+
+  string ToStringNoMetadata() const { return ToString(false, false); }
 
   // As ToString, but returns a shorter string.
   string ToShortString() const;
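
A brief usage sketch of the new overload and helper declared above
(DumpForDebug is a hypothetical caller; it assumes hlo_instruction.h is
included):

    void DumpForDebug(const xla::HloInstruction* instruction) {
      // Full debug string, including op_type/op_name/source_file metadata.
      VLOG(1) << instruction->ToString();
      // Shorthand for ToString(/*compact_operands=*/false,
      // /*include_metadata=*/false); used by the new VLOG statements in
      // reshape_mover.cc to keep trace lines compact.
      VLOG(1) << instruction->ToStringNoMetadata();
    }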
tensorflow/compiler/xla/service/reshape_mover.cc
@@ -13,17 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/reshape_mover.h"
-
-#include <algorithm>
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/util.h"
-
-namespace xla {
-
-namespace {
-
 // Implementation note:
 //
 // The general idea behind this pass is that we're converting from this:
 //   %param.A = OldShape
 //   %param.B = OldShape
@@ -44,6 +35,19 @@ namespace {
 // only implicit scalar broadcast is on Pred, not on A or B. Since reshapes or
 // transposes to a scalar should be cheap, we simply never move them.
 
+#include "tensorflow/compiler/xla/service/reshape_mover.h"
+
+#include <algorithm>
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+
+namespace {
+
 // Finds the first non-scalar operand of an instruction that is a reshape or
 // transpose and returns the operand if it is found or nullptr if not found.
 HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) {
@@ -51,6 +55,9 @@ HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) {
     if (!ShapeUtil::IsScalar(operand->shape()) &&
         (operand->opcode() == HloOpcode::kReshape ||
          operand->opcode() == HloOpcode::kTranspose)) {
+      VLOG(5) << "Found first non-scalar reshape operand of "
+              << hlo->ToStringNoMetadata() << ":\n\t"
+              << operand->ToStringNoMetadata();
       return operand;
     }
   }
@@ -70,6 +77,9 @@ bool OperandCanTrivallyChangeShape(const HloInstruction* instruction,
   // A constant can trivially reshape the literal it holds.
   if (operand->opcode() == HloOpcode::kConstant &&
       ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) {
+    VLOG(5) << "Constant had same dimensions as instruction:\n\toperand: "
+            << operand->ToStringNoMetadata()
+            << "\n\tinstruction: " << instruction->ToStringNoMetadata();
     return true;
   }
 
@@ -116,119 +126,159 @@ bool IsElementwiseOfEquivalentReshapesOrTransposes(
   if (!first_reshape_operand) {
     return false;
   }
-  return (instruction->user_count() > 0 ||
-          instruction == instruction->parent()->root_instruction()) &&
-         instruction->IsElementwise() && !operands.empty() &&
-         // Check whether all operands:
-         // 1. are all reshapes or transposes that have the same input and
-         //    output shapes as all other reshaped or transposed operands.
-         //     or
-         // 2. can be any shape like kConstant, kRng, and scalars.
-         std::all_of(
-             operands.begin(), operands.end(),
-             [instruction,
-              first_reshape_operand](const HloInstruction* operand) {
-               return AreEquivalentReshapes(first_reshape_operand, operand) ||
-                      OperandCanTrivallyChangeShape(instruction, operand);
-             });
+  VLOG(3) << "** Checking whether instruction is an elementwise operation of "
+             "equivalent reshapes/transposes: "
+          << instruction->ToStringNoMetadata();
+  bool result =
+      (instruction->user_count() > 0 ||
+       instruction == instruction->parent()->root_instruction()) &&
+      instruction->IsElementwise() && !operands.empty() &&
+      // Check whether all operands:
+      // 0. Have the same dimensions as the output -- if not, it may be
+      //    implicitly broadcast, which can confound the movement's
+      //    correctness.
+      // 1. Are all reshapes or transposes that have the same input and
+      //    output shapes as all other reshaped or transposed operands.
+      //     or
+      // 2. Can be any shape like kConstant, kRng, and scalars.
+      std::all_of(
+          operands.begin(), operands.end(),
+          [instruction, first_reshape_operand](const HloInstruction* operand) {
+            if (!ShapeUtil::SameDimensions(operand->shape(),
+                                           instruction->shape())) {
+              VLOG(5) << "Operand shape differs from output shape; may be "
+                         "implicitly broadcast, so preventing "
+                         "movement\n\toperand: "
+                      << operand->ToStringNoMetadata() << "\n\tinstruction: "
+                      << instruction->ToStringNoMetadata();
+              return false;
+            }
+            if (AreEquivalentReshapes(first_reshape_operand, operand)) {
+              VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: "
+                      << first_reshape_operand->ToStringNoMetadata()
+                      << "\n\toperand: " << operand->ToStringNoMetadata();
+              return true;
+            }
+            if (OperandCanTrivallyChangeShape(instruction, operand)) {
+              VLOG(5) << "Operand can trivially change shape: "
+                      << operand->ToStringNoMetadata();
+              return true;
+            }
+            return false;
+          });
+  VLOG(3) << "ElementwiseOfEquivalentReshapesOrTransposes result for "
+          << instruction->ToStringNoMetadata() << ": " << result;
+  return result;
 }
 
 // Try to sink any reshape or transpose operands of `instruction` across it. We
 // do so if `instruction` is elementwise and all operands are equivalent
 // reshapes or transposes.
-bool TrySinkReshapeOrTranspose(HloComputation* computation,
-                               HloInstruction* instruction) {
-  if (IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) {
-    std::vector<HloInstruction*> operands = instruction->operands();
-    HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction);
-    CHECK(old_reshape != nullptr);
-    Shape new_elementwise_shape = old_reshape->operand(0)->shape();
-    for (size_t i = 0; i < operands.size(); ++i) {
-      // All scalar operands remain as-is, even if they're reshape or transpose,
-      // to simplify handling wrt special scalar broadcast rules for ops like
-      // Select. Scalar reshapes should be cheap anyways.
-      if (ShapeUtil::IsScalar(operands[i]->shape())) {
-        continue;
-      }
-      auto element_type = operands[i]->shape().element_type();
-      switch (operands[i]->opcode()) {
-        case HloOpcode::kConstant: {
-          if (old_reshape->opcode() == HloOpcode::kReshape) {
-            operands[i] = instruction->parent()->AddInstruction(
-                HloInstruction::CreateReshape(
-                    ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                                 element_type),
-                    operands[i]));
-          } else {
-            CHECK_EQ(old_reshape->opcode(), HloOpcode::kTranspose);
-            std::vector<int64> inverse_permutation =
-                InversePermutation(old_reshape->dimensions());
-            operands[i] = instruction->parent()->AddInstruction(
-                HloInstruction::CreateTranspose(
-                    ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                                 element_type),
-                    operands[i], inverse_permutation));
-          }
-          break;
-        }
-        case HloOpcode::kRng: {
-          CHECK_EQ(operands[i]->user_count(), 1);
-          operands[i] = instruction->parent()->AddInstruction(
-              operands[i]->CloneWithNewOperands(
-                  ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                               element_type),
-                  operands[i]->operands()));
-          break;
-        }
-        case HloOpcode::kReshape:
-        case HloOpcode::kTranspose:
-          operands[i] = operands[i]->mutable_operand(0);
-          break;
-        default:
-          LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or "
-                        "transposes.";
-      }
-    }
-    if (HloOpcode::kFusion == instruction->opcode()) {
-      // Here we already know `instruction` is elementwise, and no operand is
-      // implicit broadcast as if it were the operands would not be equivalent
-      // reshapes, so all the fused instructions have the same dimensions.
-      for (const auto& fused_instruction : instruction->fused_instructions()) {
-        Shape* shape = fused_instruction->mutable_shape();
-        *shape->mutable_dimensions() = new_elementwise_shape.dimensions();
-        *shape->mutable_layout() = new_elementwise_shape.layout();
-      }
-    }
-    auto new_elementwise =
-        computation->AddInstruction(instruction->CloneWithNewOperands(
-            // `instruction` may change the element type, e.g., from
-            // operands[0] -> reshape -> convert (`instruction`)
-            // to
-            // operands[0] -> convert' -> reshape'
-            //
-            // In this case, convert' should have the same element type as
-            // `convert` and the same dimensions as operands[0].
-            ShapeUtil::ChangeElementType(new_elementwise_shape,
-                                         instruction->shape().element_type()),
-            operands));
-    std::unique_ptr<HloInstruction> new_reshape;
-    switch (old_reshape->opcode()) {
-      case HloOpcode::kReshape:
-        new_reshape = HloInstruction::CreateReshape(instruction->shape(),
-                                                    new_elementwise);
-        break;
-      case HloOpcode::kTranspose:
-        new_reshape = HloInstruction::CreateTranspose(
-            instruction->shape(), new_elementwise, old_reshape->dimensions());
-        break;
-      default:
-        LOG(FATAL) << "Bad opcode";
-    }
-    TF_CHECK_OK(computation->ReplaceWithNewInstruction(instruction,
-                                                       std::move(new_reshape)));
-    return true;
-  }
-  return false;
-}
+StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
+                                         HloInstruction* instruction) {
+  if (!IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) {
+    return false;
+  }
+
+  std::vector<HloInstruction*> operands = instruction->operands();
+  HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction);
+  TF_RET_CHECK(old_reshape != nullptr);
+  Shape new_elementwise_shape = old_reshape->operand(0)->shape();
+
+  VLOG(3) << "** Trying to sink reshape or transpose: "
+          << instruction->ToStringNoMetadata()
+          << "\n\told reshape: " << old_reshape->ToStringNoMetadata()
+          << "\n\tnew elementwise shape: "
+          << ShapeUtil::HumanString(new_elementwise_shape);
+  for (size_t i = 0; i < operands.size(); ++i) {
+    // All scalar operands remain as-is, even if they're reshape or transpose,
+    // to simplify handling wrt special scalar broadcast rules for ops like
+    // Select. Scalar reshapes should be cheap anyways.
+    if (ShapeUtil::IsScalar(operands[i]->shape())) {
+      continue;
+    }
+    PrimitiveType element_type = operands[i]->shape().element_type();
+    switch (operands[i]->opcode()) {
+      case HloOpcode::kConstant: {
+        if (old_reshape->opcode() == HloOpcode::kReshape) {
+          VLOG(3) << "Creating reshape for kConstant operand " << i << ": "
+                  << operands[i]->ToStringNoMetadata();
+          operands[i] = instruction->parent()->AddInstruction(
+              HloInstruction::CreateReshape(
+                  ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                               element_type),
+                  operands[i]));
+        } else {
+          TF_RET_CHECK(old_reshape->opcode() == HloOpcode::kTranspose);
+          std::vector<int64> inverse_permutation =
+              InversePermutation(old_reshape->dimensions());
+          operands[i] = instruction->parent()->AddInstruction(
+              HloInstruction::CreateTranspose(
+                  ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                               element_type),
+                  operands[i], inverse_permutation));
+        }
+        break;
+      }
+      case HloOpcode::kRng: {
+        CHECK_EQ(operands[i]->user_count(), 1);
+        operands[i] = instruction->parent()->AddInstruction(
+            operands[i]->CloneWithNewOperands(
+                ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                             element_type),
+                operands[i]->operands()));
+        break;
+      }
+      case HloOpcode::kReshape:
+      case HloOpcode::kTranspose:
+        operands[i] = operands[i]->mutable_operand(0);
+        break;
+      default:
+        LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or "
+                      "transposes.";
+    }
+  }
+  if (HloOpcode::kFusion == instruction->opcode()) {
+    // Here we already know `instruction` is elementwise, and no operand is
+    // implicit broadcast as if it were the operands would not be equivalent
+    // reshapes, so all the fused instructions have the same dimensions.
+    for (const auto& fused_instruction : instruction->fused_instructions()) {
+      Shape* shape = fused_instruction->mutable_shape();
+      *shape->mutable_dimensions() = new_elementwise_shape.dimensions();
+      *shape->mutable_layout() = new_elementwise_shape.layout();
+    }
+  }
+  HloInstruction* new_elementwise =
+      computation->AddInstruction(instruction->CloneWithNewOperands(
+          // `instruction` may change the element type, e.g., from
+          // operands[0] -> reshape -> convert (`instruction`)
+          // to
+          // operands[0] -> convert' -> reshape'
+          //
+          // In this case, convert' should have the same element type as
+          // `convert` and the same dimensions as operands[0].
+          ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                       instruction->shape().element_type()),
+          operands));
+
+  std::unique_ptr<HloInstruction> new_reshape;
+  switch (old_reshape->opcode()) {
+    case HloOpcode::kReshape:
+      VLOG(3) << "Creating new reshape for new elementwise op: "
+              << new_elementwise->ToStringNoMetadata();
+      new_reshape =
+          HloInstruction::CreateReshape(instruction->shape(), new_elementwise);
+      break;
+    case HloOpcode::kTranspose:
+      new_reshape = HloInstruction::CreateTranspose(
+          instruction->shape(), new_elementwise, old_reshape->dimensions());
+      break;
+    default:
+      LOG(FATAL) << "Bad opcode";
+  }
+  TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+      instruction, std::move(new_reshape)));
+  return true;
+}
 
 }  // namespace
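
For concreteness, a sketch (in the style of reshape_mover_test.cc, with
illustrative shapes; not part of the commit) of a case where sinking is sound
because both operands are equivalent reshapes and nothing is implicitly
broadcast:

    // Assumes the usual XLA test headers and namespace, as in
    // reshape_mover_test.cc.
    void BuildSinkableGraph() {
      HloComputation::Builder builder("sink_is_sound");
      auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 8, 1}), "param0"));
      auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {1, 8, 1}), "param1"));
      auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(F32, {8, 1}), param0));
      auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(F32, {8, 1}), param1));
      builder.AddInstruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeShape(F32, {8, 1}), HloOpcode::kAdd, reshape0,
          reshape1));
      // Before the pass: add(reshape(param0), reshape(param1)) in [8,1].
      // After ReshapeMover: reshape(add(param0, param1)) -- the add runs in
      // the operands' original shape [1,8,1] and a single reshape follows.
    }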
@@ -237,9 +287,9 @@ StatusOr<bool> ReshapeMover::Run(HloModule* module) {
   bool changed = false;
   for (const auto& comp : module->computations()) {
     for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
-      if (TrySinkReshapeOrTranspose(comp.get(), instruction)) {
-        changed = true;
-      }
+      TF_ASSIGN_OR_RETURN(bool did_change,
+                          TrySinkReshapeOrTranspose(comp.get(), instruction));
+      changed |= did_change;
     }
   }
   return changed;
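
For context, the return type changed from bool to StatusOr<bool> so that
TF_RET_CHECK/TF_RETURN_IF_ERROR failures inside TrySinkReshapeOrTranspose
propagate to the caller instead of aborting the process. A minimal sketch of
the pattern (MaybeRewrite is a hypothetical stand-in; the usual
status_macros.h and statusor.h headers are assumed):

    StatusOr<bool> MaybeRewrite(HloInstruction* instruction);  // hypothetical

    StatusOr<bool> RunOverComputation(HloComputation* computation) {
      bool changed = false;
      for (HloInstruction* instruction :
           computation->MakeInstructionPostOrder()) {
        // On success TF_ASSIGN_OR_RETURN unwraps the bool; on failure it
        // returns the non-OK Status to this function's caller.
        TF_ASSIGN_OR_RETURN(bool did_change, MaybeRewrite(instruction));
        changed |= did_change;
      }
      return changed;
    }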
tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -234,6 +234,58 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) {
   EXPECT_EQ(select, computation->root_instruction());
 }
 
+// Tree looks like:
+//
+//   param0 [1,128,1]
+//      |
+//   reshape [128,1]      constant [128,1024]
+//      \                 /
+//       multiply w/implicit broadcast [128,1024]
+//
+// The reshape mover would like to sink the reshape below the multiply.
+//
+// Previously we would attempt to insert a reshape of the constant to [1,128,1]
+// (which is unsound, because it has a different number of elements) as
+// preparation for sinking the reshape.
+//
+// To eliminate the unsoundness, we outlaw reshape sinking when one of the
+// operands is implicitly broadcast in the elementwise consumer.
+//
+// TODO(b/37799338) However, it would be possible in this case to do a more
+// in-depth analysis to get reshape movement to occur:
+//
+// 1. Note that the broadcast dimension (logical dimension 1) in the operands
+//    would map back to logical dimension 2 in the param0 node.
+// 2. Match rank of the constant to the param0 node (by prepending a trivial 1
+//    dimension).
+// 3. Reshape to [128,1024] at the root.
+//
+// But this is not currently done.
+TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) {
+  HloComputation::Builder builder(TestName());
+  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+      0, ShapeUtil::MakeShape(F32, {1, 128, 1}), "param0"));
+  auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
+      ShapeUtil::MakeShape(F32, {128, 1}), param0));
+  Array2D<float> a(128, 1024);
+  auto literal = LiteralUtil::CreateR2FromArray2D<float>(a);
+  auto constant = builder.AddInstruction(
+      HloInstruction::CreateConstant(std::move(literal)));
+  auto multiply = builder.AddInstruction(HloInstruction::CreateBinary(
+      constant->shape(), HloOpcode::kMultiply, constant, reshape));
+
+  auto module = MakeUnique<HloModule>(TestName());
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_THAT(computation->root_instruction(),
+              op::Multiply(op::Constant(), op::Reshape(param0)));
+
+  EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie());
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Multiply(op::Constant(), op::Reshape(param0)));
+  EXPECT_EQ(multiply, computation->root_instruction());
+}
+
 // Tree looks like this:
 //
 //   add1