relaxing requirements for clonewithnewoperands to preserve sharding information

Reviewers: #tensorflow!, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved!

Differential Revision: https://phabricator.sourcevertex.net/D36993
This commit is contained in:
Alfie Edwards 2020-11-30 14:22:04 +00:00
parent 4d1142b04b
commit 9be9961059
6 changed files with 67 additions and 8 deletions

View File

@ -1497,11 +1497,12 @@ void HloInstruction::set_single_sharding(const HloSharding& sharding) {
void HloInstruction::SetupDerivedInstruction(
HloInstruction* derived_instruction) const {
if (sharding_ != nullptr && ShapeUtil::CompatibleIgnoringElementType(
shape_, derived_instruction->shape())) {
// Only copy sharding if the shape of the two instruction is compatible
// because copying it between differently shaped instructions can produce
// invalid shardings.
if (sharding_ != nullptr &&
ShapeUtil::CompatibleIgnoringElementTypeAndDimensions(
shape_, derived_instruction->shape())) {
// Only copy sharding if the tuple tree shape of the two instruction is
// compatible because copying it between differently shaped instructions
// can produce invalid shardings.
derived_instruction->set_sharding(*sharding_);
} else {
derived_instruction->clear_sharding();

View File

@ -749,6 +749,45 @@ TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) {
EXPECT_TRUE(ShapeUtil::Equal(tuple_clone->shape(), tuple->shape()));
}
TEST_F(HloInstructionTest, PreserveShardingThroughCompatibleClone) {
HloSharding sharding = HloSharding::AssignDevice(5);
HloComputation::Builder builder(TestName());
auto* constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
{1, 2},
{3, 4},
})));
auto* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
tuple->set_sharding(sharding);
// Compatible with original shape as tuple tree structure is identical
auto clone_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
clone_shape = ShapeUtil::MakeTupleShape({clone_shape, clone_shape});
auto tuple_clone = tuple->CloneWithNewOperands(clone_shape, {});
EXPECT_EQ(tuple_clone->sharding(), sharding);
}
TEST_F(HloInstructionTest, DoNotPreserveShardingThroughIncompatibleClone) {
HloSharding sharding = HloSharding::AssignDevice(5);
HloComputation::Builder builder(TestName());
auto* constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
{1, 2},
{3, 4},
})));
auto* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
tuple->set_sharding(sharding);
// Incompatible with original shape as tuple tree structure is different
auto clone_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
clone_shape = ShapeUtil::MakeTupleShape({clone_shape, clone_shape,
clone_shape});
auto tuple_clone = tuple->CloneWithNewOperands(clone_shape, {});
EXPECT_FALSE(tuple_clone->has_sharding());
}
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
// Create a fusion instruction containing a single unary operation.
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});

View File

@ -141,9 +141,11 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
}
}
if (!ShapeUtil::SameDimensions(lhs, rhs)) {
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
return false;
if (!ignore_dimensions_) {
if (!ShapeUtil::SameDimensions(lhs, rhs)) {
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
return false;
}
}
if (!ignore_layout_) {

View File

@ -220,6 +220,10 @@ class Shape {
ignore_dynamic_dimension_ = true;
return *this;
}
Equal& IgnoreDimensions() {
ignore_dimensions_ = true;
return *this;
}
private:
bool ignore_layout_ = false;
@ -229,6 +233,7 @@ class Shape {
bool ignore_element_type_ = false;
bool ignore_fp_precision_ = false;
bool ignore_dynamic_dimension_ = false;
bool ignore_dimensions_ = false;
};
// Test that all fields of the shape are the same, equivalent to Equal().

View File

@ -654,6 +654,12 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
.IgnoreLayout()(lhs, rhs);
}
/* static */ bool ShapeUtil::CompatibleIgnoringElementTypeAndDimensions(
const Shape& lhs, const Shape& rhs) {
return Shape::Equal().IgnoreElementType().IgnoreLayout().IgnoreDimensions()
.IgnoreDynamicDimension()(lhs, rhs);
}
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
const Shape& rhs) {
return Shape::Equal()

View File

@ -293,6 +293,12 @@ class ShapeUtil {
// compatibility.
static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
// Returns true if the tuple tree shapes are identical. Leaf dimensions,
// element type, and layout are ignored. Tuple elements are compared
// recursively for compatibility.
static bool CompatibleIgnoringElementTypeAndDimensions(const Shape& lhs,
const Shape& rhs);
// As Compatible, but allow one of lhs and rhs to be BF16 while the other
// being F32. Tuple elements are compared recursively for compatibility.
static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);