From 0d10d5d09721a6f258c35cd6c35ad5a32ee73f83 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Fri, 21 Aug 2020 13:24:59 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 327864391 Change-Id: Id021118bc279f646ec693ec4af3f1f59cb63c38e --- tensorflow/compiler/xla/service/layout_assignment.cc | 2 +- tensorflow/compiler/xla/service/layout_assignment.h | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index adccda79eac..55569cfde0e 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -1891,7 +1891,7 @@ Status LayoutAssignment::RunOnComputation( ? ShapeUtil::GetSubshape(instruction->literal().shape(), buffer.index()) .layout() - : LayoutUtil::GetDefaultLayoutForShape(buffer.shape()); + : GetUnconstrainedLayout(buffer); TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer, /*mandatory=*/false)); diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index a04d056c618..def620bcee9 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -338,6 +339,9 @@ class LayoutAssignment : public HloModulePass { const ResultLayoutConstraint& layout_constraint, LayoutConstraints* constraints); + virtual Layout GetUnconstrainedLayout(const LogicalBuffer& buffer) { + return LayoutUtil::GetDefaultLayoutForShape(buffer.shape()); + } // Called after layouts of an instruction have been finalized to allow // subclasses to check for platform specific assumptions. virtual Status Verify(const HloInstruction* instruction) {