[MLIR] Fix CreateView() to not drop slice offset incorrectly.
- A slice can be mapped to memref for the allocation only if slice offset is 0. - Minor comment tweaks. PiperOrigin-RevId: 334494277 Change-Id: Iac9251162b8470b8705eb0f2866d010cf83785b1
This commit is contained in:
parent
ca8309afca
commit
9fa8f4ad49
@ -294,7 +294,7 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
|
||||
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
|
||||
assignment_.GetUniqueSlice(instr, *current_shape_index));
|
||||
Value alloc = allocations_[slice.allocation()];
|
||||
if (alloc.getType() == out_type) {
|
||||
if (alloc.getType() == out_type && slice.offset() == 0) {
|
||||
values->push_back(alloc);
|
||||
return Status::OK();
|
||||
}
|
||||
@ -337,16 +337,16 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
|
||||
Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr,
|
||||
SmallVectorImpl<Value>* values) {
|
||||
// Cache generated ViewOp and StaticMemRefCastOp by instruction. We could have
|
||||
// gone fancier to do the following cacheing:
|
||||
// %range = ViewOp(%allocation, %offset) : memref<i8xSIZE>
|
||||
// %typed_range = ViewOp(%range) : memref<f32x...>
|
||||
// gone fancier to do the following caching:
|
||||
// %slice = ViewOp(%allocation, %offset) : memref<i8xSIZE>
|
||||
// %typed_slice = ViewOp(%slice) : memref<f32x...>
|
||||
//
|
||||
// where %range is cached. This in theory gives easier time for alias
|
||||
// analysis, since the identity of %range defines alias. However,
|
||||
// %typed_range can't be cached, as different buffers with different types and
|
||||
// where %slice is cached. This in theory gives easier time for alias
|
||||
// analysis, since the identity of %slice defines alias. However,
|
||||
// %typed_slice can't be cached, as different buffers with different types and
|
||||
// shapes may still alias. Creating two ViewOps doesn't seem to worth the
|
||||
// effort for a slightly easier aliasing, so we don't over optimize here.
|
||||
auto result = slices_.try_emplace(instr, llvm::SmallVector<Value, 4>{});
|
||||
auto result = slices_.try_emplace(instr, llvm::SmallVector<Value, 1>{});
|
||||
llvm::SmallVectorImpl<Value>& new_values = result.first->second;
|
||||
if (result.second) {
|
||||
::xla::ShapeIndex shape_index;
|
||||
@ -373,7 +373,7 @@ Status LhloDialectEmitter::Initialize() {
|
||||
|
||||
if (computation_.IsEntryComputation()) {
|
||||
// Sort the rather arbitrarily ordered allocations to match the input/output
|
||||
// parameters. Specifically We want to sort buffer allocations in the
|
||||
// parameters. Specifically we want to sort buffer allocations in the
|
||||
// following order:
|
||||
// * Parameters always order before non-parameters.
|
||||
// * Different parameters order by parameter number.
|
||||
|
@ -102,8 +102,7 @@ Shape TypeToShape(mlir::Type type) {
|
||||
if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID)
|
||||
return ShapeUtil::MakeShape(ptype, {});
|
||||
|
||||
if (type.isBF16() || type.isF32() || type.isF64() ||
|
||||
type.isa<mlir::IntegerType>()) {
|
||||
if (type.isIntOrFloat()) {
|
||||
auto* context = type.getContext();
|
||||
mlir::emitError(mlir::UnknownLoc::get(context))
|
||||
<< "lowering should have been handled by primitive type lowering for "
|
||||
|
@ -1396,16 +1396,14 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
|
||||
std::vector<xla::Shape> output_shapes(sort_op.output().size());
|
||||
|
||||
for (int i = 0; i < operand_count; i++) {
|
||||
operand_shapes[i] =
|
||||
TypeToShape(sort_op.operands()[i].getType().cast<mlir::MemRefType>());
|
||||
operand_shapes[i] = TypeToShape(sort_op.operands()[i].getType());
|
||||
}
|
||||
|
||||
// Craft n + 1 slices, where the first n are output parameters, and the last
|
||||
// is the on-device tuple storage. We don't need n operands because sorting
|
||||
// kernels are always in-place.
|
||||
for (int i = 0; i < operand_count; i++) {
|
||||
output_shapes[i] =
|
||||
TypeToShape(sort_op.output()[i].getType().cast<mlir::MemRefType>());
|
||||
output_shapes[i] = TypeToShape(sort_op.output()[i].getType());
|
||||
MlirBufferSlice slice;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
slice.buffer_slice,
|
||||
|
Loading…
Reference in New Issue
Block a user