[MLIR] Fix CreateView() to not drop slice offset incorrectly.

- A slice can be mapped to memref for the allocation only if slice offset is 0.
- Minor comment tweaks.

PiperOrigin-RevId: 334494277
Change-Id: Iac9251162b8470b8705eb0f2866d010cf83785b1
This commit is contained in:
Rahul Joshi 2020-09-29 17:11:23 -07:00 committed by TensorFlower Gardener
parent ca8309afca
commit 9fa8f4ad49
3 changed files with 12 additions and 15 deletions

View File

@@ -294,7 +294,7 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
assignment_.GetUniqueSlice(instr, *current_shape_index));
Value alloc = allocations_[slice.allocation()];
if (alloc.getType() == out_type) {
if (alloc.getType() == out_type && slice.offset() == 0) {
values->push_back(alloc);
return Status::OK();
}
@@ -337,16 +337,16 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr,
Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr,
SmallVectorImpl<Value>* values) {
// Cache generated ViewOp and StaticMemRefCastOp by instruction. We could have
// gone fancier to do the following cacheing:
// %range = ViewOp(%allocation, %offset) : memref<i8xSIZE>
// %typed_range = ViewOp(%range) : memref<f32x...>
// gone fancier to do the following caching:
// %slice = ViewOp(%allocation, %offset) : memref<i8xSIZE>
// %typed_slice = ViewOp(%slice) : memref<f32x...>
//
// where %range is cached. This in theory gives easier time for alias
// analysis, since the identity of %range defines alias. However,
// %typed_range can't be cached, as different buffers with different types and
// where %slice is cached. This in theory gives easier time for alias
// analysis, since the identity of %slice defines alias. However,
// %typed_slice can't be cached, as different buffers with different types and
// shapes may still alias. Creating two ViewOps doesn't seem to worth the
// effort for a slightly easier aliasing, so we don't over optimize here.
auto result = slices_.try_emplace(instr, llvm::SmallVector<Value, 4>{});
auto result = slices_.try_emplace(instr, llvm::SmallVector<Value, 1>{});
llvm::SmallVectorImpl<Value>& new_values = result.first->second;
if (result.second) {
::xla::ShapeIndex shape_index;
@@ -373,7 +373,7 @@ Status LhloDialectEmitter::Initialize() {
if (computation_.IsEntryComputation()) {
// Sort the rather arbitrarily ordered allocations to match the input/output
// parameters. Specifically We want to sort buffer allocations in the
// parameters. Specifically we want to sort buffer allocations in the
// following order:
// * Parameters always order before non-parameters.
// * Different parameters order by parameter number.

View File

@@ -102,8 +102,7 @@ Shape TypeToShape(mlir::Type type) {
if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID)
return ShapeUtil::MakeShape(ptype, {});
if (type.isBF16() || type.isF32() || type.isF64() ||
type.isa<mlir::IntegerType>()) {
if (type.isIntOrFloat()) {
auto* context = type.getContext();
mlir::emitError(mlir::UnknownLoc::get(context))
<< "lowering should have been handled by primitive type lowering for "

View File

@@ -1396,16 +1396,14 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
std::vector<xla::Shape> output_shapes(sort_op.output().size());
for (int i = 0; i < operand_count; i++) {
operand_shapes[i] =
TypeToShape(sort_op.operands()[i].getType().cast<mlir::MemRefType>());
operand_shapes[i] = TypeToShape(sort_op.operands()[i].getType());
}
// Craft n + 1 slices, where the first n are output parameters, and the last
// is the on-device tuple storage. We don't need n operands because sorting
// kernels are always in-place.
for (int i = 0; i < operand_count; i++) {
output_shapes[i] =
TypeToShape(sort_op.output()[i].getType().cast<mlir::MemRefType>());
output_shapes[i] = TypeToShape(sort_op.output()[i].getType());
MlirBufferSlice slice;
TF_ASSIGN_OR_RETURN(
slice.buffer_slice,