Relax requirements for CloneWithNewOperands to preserve sharding information
Reviewers: #tensorflow!, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved!
Differential Revision: https://phabricator.sourcevertex.net/D36993
parent 4d1142b04b
commit 9be9961059
tensorflow/compiler/xla
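
Before the diff, a sketch of the behavior change (illustration only, not part of the commit; the free function name CompatibilitySketch is made up): under the old check, a clone whose leaf dimensions differ from the original was considered incompatible and silently lost its sharding; under the new check, compatibility only requires the same tuple tree structure. CompatibleIgnoringElementTypeAndDimensions is the ShapeUtil helper added further down in this diff.

// Illustrative sketch only (not from the commit): contrasts the old and the
// new compatibility check on the shapes exercised by the new tests below.
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

void CompatibilitySketch() {
  // Original instruction shape: a tuple of two 2x2 f32 arrays.
  const Shape leaf = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape original = ShapeUtil::MakeTupleShape({leaf, leaf});

  // Clone shape: same tuple tree structure, different leaf dimensions.
  const Shape clone_leaf = ShapeUtil::MakeShape(F32, {1, 2, 3});
  const Shape clone = ShapeUtil::MakeTupleShape({clone_leaf, clone_leaf});

  // Old check: leaf dimensions must also match, so the sharding was dropped.
  bool old_check = ShapeUtil::CompatibleIgnoringElementType(original, clone);
  // -> false

  // New check: only the tuple tree structure has to match, so the sharding
  // is now copied to the derived (cloned) instruction.
  bool new_check =
      ShapeUtil::CompatibleIgnoringElementTypeAndDimensions(original, clone);
  // -> true

  // A clone with a different tuple arity is still treated as incompatible.
  const Shape mismatched =
      ShapeUtil::MakeTupleShape({clone_leaf, clone_leaf, clone_leaf});
  bool mismatched_check =
      ShapeUtil::CompatibleIgnoringElementTypeAndDimensions(original, mismatched);
  // -> false

  (void)old_check;  // Suppress unused-variable warnings in this sketch.
  (void)new_check;
  (void)mismatched_check;
}

}  // namespace xla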
tensorflow/compiler/xla/service/hlo_instruction.cc

@@ -1497,11 +1497,12 @@ void HloInstruction::set_single_sharding(const HloSharding& sharding) {
 
 void HloInstruction::SetupDerivedInstruction(
     HloInstruction* derived_instruction) const {
-  if (sharding_ != nullptr && ShapeUtil::CompatibleIgnoringElementType(
-                                  shape_, derived_instruction->shape())) {
-    // Only copy sharding if the shape of the two instruction is compatible
-    // because copying it between differently shaped instructions can produce
-    // invalid shardings.
+  if (sharding_ != nullptr &&
+      ShapeUtil::CompatibleIgnoringElementTypeAndDimensions(
+          shape_, derived_instruction->shape())) {
+    // Only copy sharding if the tuple tree shape of the two instruction is
+    // compatible because copying it between differently shaped instructions
+    // can produce invalid shardings.
     derived_instruction->set_sharding(*sharding_);
   } else {
     derived_instruction->clear_sharding();
tensorflow/compiler/xla/service/hlo_instruction_test.cc

@@ -749,6 +749,45 @@ TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) {
   EXPECT_TRUE(ShapeUtil::Equal(tuple_clone->shape(), tuple->shape()));
 }
 
+TEST_F(HloInstructionTest, PreserveShardingThroughCompatibleClone) {
+
+  HloSharding sharding = HloSharding::AssignDevice(5);
+  HloComputation::Builder builder(TestName());
+  auto* constant = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
+          {1, 2},
+          {3, 4},
+      })));
+  auto* tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
+  tuple->set_sharding(sharding);
+  // Compatible with original shape as tuple tree structure is identical
+  auto clone_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
+  clone_shape = ShapeUtil::MakeTupleShape({clone_shape, clone_shape});
+  auto tuple_clone = tuple->CloneWithNewOperands(clone_shape, {});
+  EXPECT_EQ(tuple_clone->sharding(), sharding);
+}
+
+TEST_F(HloInstructionTest, DoNotPreserveShardingThroughIncompatibleClone) {
+
+  HloSharding sharding = HloSharding::AssignDevice(5);
+  HloComputation::Builder builder(TestName());
+  auto* constant = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
+          {1, 2},
+          {3, 4},
+      })));
+  auto* tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
+  tuple->set_sharding(sharding);
+  // Incompatible with original shape as tuple tree structure is different
+  auto clone_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
+  clone_shape = ShapeUtil::MakeTupleShape({clone_shape, clone_shape,
+                                           clone_shape});
+  auto tuple_clone = tuple->CloneWithNewOperands(clone_shape, {});
+  EXPECT_FALSE(tuple_clone->has_sharding());
+}
+
 TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
   // Create a fusion instruction containing a single unary operation.
   const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
tensorflow/compiler/xla/shape.cc

@@ -141,9 +141,11 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
     }
   }
 
-  if (!ShapeUtil::SameDimensions(lhs, rhs)) {
-    VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
-    return false;
+  if (!ignore_dimensions_) {
+    if (!ShapeUtil::SameDimensions(lhs, rhs)) {
+      VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
+      return false;
+    }
   }
 
   if (!ignore_layout_) {
tensorflow/compiler/xla/shape.h

@@ -220,6 +220,10 @@ class Shape {
     ignore_dynamic_dimension_ = true;
     return *this;
   }
+  Equal& IgnoreDimensions() {
+    ignore_dimensions_ = true;
+    return *this;
+  }
 
  private:
  bool ignore_layout_ = false;
@@ -229,6 +233,7 @@ class Shape {
   bool ignore_element_type_ = false;
   bool ignore_fp_precision_ = false;
   bool ignore_dynamic_dimension_ = false;
+  bool ignore_dimensions_ = false;
 };
 
 // Test that all fields of the shape are the same, equivalent to Equal().
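
For reference, a minimal sketch of the new IgnoreDimensions() option on Shape::Equal, on its own and in the combination used by the new ShapeUtil helper in the next hunk. This is not part of the commit; the function name IgnoreDimensionsSketch is made up, and the standard XLA headers are assumed.

// Illustrative sketch only (not from the commit): behavior of the new
// IgnoreDimensions() knob on Shape::Equal.
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

void IgnoreDimensionsSketch() {
  const Shape a = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape b = ShapeUtil::MakeShape(F32, {4, 8});

  // Same rank, element type, and layout; only the extents differ, so the
  // shapes compare equal once dimensions are ignored.
  bool strict = Shape::Equal()(a, b);                       // false
  bool same_rank = Shape::Equal().IgnoreDimensions()(a, b);  // true

  // The combination used by CompatibleIgnoringElementTypeAndDimensions also
  // ignores element type, layout, and dynamic dimensions, so array shapes of
  // different rank and type compare equal; only the tuple tree structure of
  // the shapes still has to match.
  const Shape c = ShapeUtil::MakeShape(S32, {1, 2, 3});
  bool relaxed = Shape::Equal()
                     .IgnoreElementType()
                     .IgnoreLayout()
                     .IgnoreDimensions()
                     .IgnoreDynamicDimension()(a, c);  // true

  (void)strict;  // Suppress unused-variable warnings in this sketch.
  (void)same_rank;
  (void)relaxed;
}

}  // namespace xla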
tensorflow/compiler/xla/shape_util.cc

@@ -654,6 +654,12 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
       .IgnoreLayout()(lhs, rhs);
 }
 
+/* static */ bool ShapeUtil::CompatibleIgnoringElementTypeAndDimensions(
+    const Shape& lhs, const Shape& rhs) {
+  return Shape::Equal().IgnoreElementType().IgnoreLayout().IgnoreDimensions()
+      .IgnoreDynamicDimension()(lhs, rhs);
+}
+
 /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
                                                            const Shape& rhs) {
   return Shape::Equal()
tensorflow/compiler/xla/shape_util.h

@@ -293,6 +293,12 @@ class ShapeUtil {
   // compatibility.
   static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
 
+  // Returns true if the tuple tree shapes are identical. Leaf dimensions,
+  // element type, and layout are ignored. Tuple elements are compared
+  // recursively for compatibility.
+  static bool CompatibleIgnoringElementTypeAndDimensions(const Shape& lhs,
+                                                         const Shape& rhs);
+
   // As Compatible, but allow one of lhs and rhs to be BF16 while the other
   // being F32. Tuple elements are compared recursively for compatibility.
   static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);