[XLA] Enable tests to force result layouts for tuple results.
PiperOrigin-RevId: 204180483
This commit is contained in:
parent
bbc23f229e
commit
f71040900c
@ -67,6 +67,14 @@ void ShapeLayout::ResetLayout(const Layout& layout) {
|
|||||||
TF_CHECK_OK(ShapeUtil::ValidateShape(shape_));
|
TF_CHECK_OK(ShapeUtil::ValidateShape(shape_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ShapeLayout::ResetLayout(const Layout& layout,
|
||||||
|
ShapeIndexView shape_index) {
|
||||||
|
CHECK(ShapeUtil::IsTuple(shape_));
|
||||||
|
*ShapeUtil::GetMutableSubshape(&shape_, shape_index)->mutable_layout() =
|
||||||
|
layout;
|
||||||
|
TF_CHECK_OK(ShapeUtil::ValidateShape(shape_));
|
||||||
|
}
|
||||||
|
|
||||||
bool ShapeLayout::operator==(const ShapeLayout& other) const {
|
bool ShapeLayout::operator==(const ShapeLayout& other) const {
|
||||||
return ShapeUtil::Equal(shape_, other.shape_);
|
return ShapeUtil::Equal(shape_, other.shape_);
|
||||||
}
|
}
|
||||||
|
@ -72,6 +72,10 @@ class ShapeLayout {
|
|||||||
// tuple.
|
// tuple.
|
||||||
void ResetLayout(const Layout& layout);
|
void ResetLayout(const Layout& layout);
|
||||||
|
|
||||||
|
// Resets the layout on the shape at the provided ShapeIndex to the provided
|
||||||
|
// layout. Shape must be a tuple.
|
||||||
|
void ResetLayout(const Layout& layout, ShapeIndexView shape_index);
|
||||||
|
|
||||||
// Returns a string representation of this object.
|
// Returns a string representation of this object.
|
||||||
string ToString() const { return ShapeUtil::HumanStringWithLayout(shape_); }
|
string ToString() const { return ShapeUtil::HumanStringWithLayout(shape_); }
|
||||||
|
|
||||||
|
@ -200,6 +200,13 @@ class HloTestBase : public ::testing::Test {
|
|||||||
->ResetLayout(layout);
|
->ResetLayout(layout);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ForceResultLayout(HloModule* module, const Layout& layout,
|
||||||
|
ShapeIndexView shape_index) {
|
||||||
|
module->mutable_entry_computation_layout()
|
||||||
|
->mutable_result_layout()
|
||||||
|
->ResetLayout(layout, shape_index);
|
||||||
|
}
|
||||||
|
|
||||||
// Convenience method to clear the layout of the computation result in
|
// Convenience method to clear the layout of the computation result in
|
||||||
// 'module'.
|
// 'module'.
|
||||||
void ForceClearResultLayout(HloModule* module) {
|
void ForceClearResultLayout(HloModule* module) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user