[XLA] Enable tests to force result layouts for tuple results.
PiperOrigin-RevId: 204180483
This commit is contained in:
parent
bbc23f229e
commit
f71040900c
tensorflow/compiler/xla
@ -67,6 +67,14 @@ void ShapeLayout::ResetLayout(const Layout& layout) {
|
||||
TF_CHECK_OK(ShapeUtil::ValidateShape(shape_));
|
||||
}
|
||||
|
||||
void ShapeLayout::ResetLayout(const Layout& layout,
|
||||
ShapeIndexView shape_index) {
|
||||
CHECK(ShapeUtil::IsTuple(shape_));
|
||||
*ShapeUtil::GetMutableSubshape(&shape_, shape_index)->mutable_layout() =
|
||||
layout;
|
||||
TF_CHECK_OK(ShapeUtil::ValidateShape(shape_));
|
||||
}
|
||||
|
||||
bool ShapeLayout::operator==(const ShapeLayout& other) const {
|
||||
return ShapeUtil::Equal(shape_, other.shape_);
|
||||
}
|
||||
|
@ -72,6 +72,10 @@ class ShapeLayout {
|
||||
// tuple.
|
||||
void ResetLayout(const Layout& layout);
|
||||
|
||||
// Resets the layout on the shape at the provided ShapeIndex to the provided
|
||||
// layout. Shape must be a tuple.
|
||||
void ResetLayout(const Layout& layout, ShapeIndexView shape_index);
|
||||
|
||||
// Returns a string representation of this object.
|
||||
string ToString() const { return ShapeUtil::HumanStringWithLayout(shape_); }
|
||||
|
||||
|
@ -200,6 +200,13 @@ class HloTestBase : public ::testing::Test {
|
||||
->ResetLayout(layout);
|
||||
}
|
||||
|
||||
void ForceResultLayout(HloModule* module, const Layout& layout,
|
||||
ShapeIndexView shape_index) {
|
||||
module->mutable_entry_computation_layout()
|
||||
->mutable_result_layout()
|
||||
->ResetLayout(layout, shape_index);
|
||||
}
|
||||
|
||||
// Convenience method to clear the layout of the computation result in
|
||||
// 'module'.
|
||||
void ForceClearResultLayout(HloModule* module) {
|
||||
|
Loading…
Reference in New Issue
Block a user