diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index adccda79eac..55569cfde0e 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -1891,7 +1891,7 @@ Status LayoutAssignment::RunOnComputation( ? ShapeUtil::GetSubshape(instruction->literal().shape(), buffer.index()) .layout() - : LayoutUtil::GetDefaultLayoutForShape(buffer.shape()); + : GetUnconstrainedLayout(buffer); TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer, /*mandatory=*/false)); diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index a04d056c618..def620bcee9 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -338,6 +339,9 @@ class LayoutAssignment : public HloModulePass { const ResultLayoutConstraint& layout_constraint, LayoutConstraints* constraints); + virtual Layout GetUnconstrainedLayout(const LogicalBuffer& buffer) { + return LayoutUtil::GetDefaultLayoutForShape(buffer.shape()); + } // Called after layouts of an instruction have been finalized to allow // subclasses to check for platform specific assumptions. virtual Status Verify(const HloInstruction* instruction) {