Internal change
PiperOrigin-RevId: 327864391 Change-Id: Id021118bc279f646ec693ec4af3f1f59cb63c38e
This commit is contained in:
parent
f8d80a78a3
commit
0d10d5d097
@ -1891,7 +1891,7 @@ Status LayoutAssignment::RunOnComputation(
|
|||||||
? ShapeUtil::GetSubshape(instruction->literal().shape(),
|
? ShapeUtil::GetSubshape(instruction->literal().shape(),
|
||||||
buffer.index())
|
buffer.index())
|
||||||
.layout()
|
.layout()
|
||||||
: LayoutUtil::GetDefaultLayoutForShape(buffer.shape());
|
: GetUnconstrainedLayout(buffer);
|
||||||
TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer,
|
TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer,
|
||||||
/*mandatory=*/false));
|
/*mandatory=*/false));
|
||||||
|
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/container/flat_hash_set.h"
|
#include "absl/container/flat_hash_set.h"
|
||||||
|
#include "tensorflow/compiler/xla/layout_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/call_graph.h"
|
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||||
#include "tensorflow/compiler/xla/service/computation_layout.h"
|
#include "tensorflow/compiler/xla/service/computation_layout.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
@ -338,6 +339,9 @@ class LayoutAssignment : public HloModulePass {
|
|||||||
const ResultLayoutConstraint& layout_constraint,
|
const ResultLayoutConstraint& layout_constraint,
|
||||||
LayoutConstraints* constraints);
|
LayoutConstraints* constraints);
|
||||||
|
|
||||||
|
virtual Layout GetUnconstrainedLayout(const LogicalBuffer& buffer) {
|
||||||
|
return LayoutUtil::GetDefaultLayoutForShape(buffer.shape());
|
||||||
|
}
|
||||||
// Called after layouts of an instruction have been finalized to allow
|
// Called after layouts of an instruction have been finalized to allow
|
||||||
// subclasses to check for platform specific assumptions.
|
// subclasses to check for platform specific assumptions.
|
||||||
virtual Status Verify(const HloInstruction* instruction) {
|
virtual Status Verify(const HloInstruction* instruction) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user