[XLA] Enable tests to force result layouts for tuple results.

PiperOrigin-RevId: 204180483
This commit is contained in:
Michael Kuperstein 2018-07-11 13:43:24 -07:00 committed by TensorFlower Gardener
parent bbc23f229e
commit f71040900c
3 changed files with 19 additions and 0 deletions

View File

@ -67,6 +67,14 @@ void ShapeLayout::ResetLayout(const Layout& layout) {
TF_CHECK_OK(ShapeUtil::ValidateShape(shape_));
}
void ShapeLayout::ResetLayout(const Layout& layout,
ShapeIndexView shape_index) {
CHECK(ShapeUtil::IsTuple(shape_));
*ShapeUtil::GetMutableSubshape(&shape_, shape_index)->mutable_layout() =
layout;
TF_CHECK_OK(ShapeUtil::ValidateShape(shape_));
}
bool ShapeLayout::operator==(const ShapeLayout& other) const {
return ShapeUtil::Equal(shape_, other.shape_);
}

View File

@ -72,6 +72,10 @@ class ShapeLayout {
// tuple.
void ResetLayout(const Layout& layout);
// Resets the layout on the shape at the provided ShapeIndex to the provided
// layout. Shape must be a tuple.
void ResetLayout(const Layout& layout, ShapeIndexView shape_index);
// Returns a string representation of this object.
string ToString() const { return ShapeUtil::HumanStringWithLayout(shape_); }

View File

@ -200,6 +200,13 @@ class HloTestBase : public ::testing::Test {
->ResetLayout(layout);
}
void ForceResultLayout(HloModule* module, const Layout& layout,
ShapeIndexView shape_index) {
module->mutable_entry_computation_layout()
->mutable_result_layout()
->ResetLayout(layout, shape_index);
}
// Convenience method to clear the layout of the computation result in
// 'module'.
void ForceClearResultLayout(HloModule* module) {