PiperOrigin-RevId: 311133851
Change-Id: I5e85bed33bf2295752f2f862d5a4295d5e2a2817
This commit is contained in:
A. Unique TensorFlower 2020-05-12 08:54:02 -07:00 committed by TensorFlower Gardener
parent 6bb9ca398d
commit 1658986a76
2 changed files with 16 additions and 13 deletions

View File

@ -8,7 +8,9 @@
// CHECK-SAME: ) {
func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
// The only expected instruction is a copy from the input into the output.
// CHECK: %[[OUTPUT:.*]] = std.view %[[ARG1]][][] : memref<16xi8> to memref<2x2xf32>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C02:.*]] = constant 0 : index
// CHECK: %[[OUTPUT:.*]] = std.view %[[ARG1]][%[[C02]]][] : memref<16xi8> to memref<2x2xf32>
// CHECK: xla_lhlo.copy
// CHECK-SAME: %[[ARG0]], %[[OUTPUT]]
return %value : tensor<2x2xf32>

View File

@ -251,17 +251,15 @@ Value LhloDialectEmitter::GetOrCreateView(
// Create the view for this slice size, possible with an affine map to model
// the offset. The result is cached in the slices_ map.
SmallVector<AffineMap, 1> offset_map;
if (slice.offset()) {
offset_map.push_back(AffineMap::get(
/*dimCount=*/1, /*symbolCount=*/0,
{getAffineDimExpr(0, builder_.getContext()) + slice.offset()},
builder_.getContext()));
}
auto slice_type = MemRefType::get({slice.size()}, i8_type_, offset_map);
// The std.view result type does not carry the static offset: this is not
// useful information. Rather, the view op must have the static offset.
auto slice_type = MemRefType::get({slice.size()}, i8_type_, {});
auto slice_view = builder_.create<ViewOp>(
alloc_buffer.getLoc(), slice_type, alloc_buffer, /*operands=*/llvm::None);
Value byte_shift =
builder_.create<ConstantIndexOp>(alloc_buffer.getLoc(), slice.offset());
auto slice_view =
builder_.create<ViewOp>(alloc_buffer.getLoc(), slice_type, alloc_buffer,
byte_shift, /*sizes=*/ArrayRef<Value>{});
slices_.insert({slice_key, slice_view});
return slice_view;
}
@ -277,9 +275,12 @@ StatusOr<Value> LhloDialectEmitter::GetOrCreateView(
Value slice_view = GetOrCreateView(out_slice);
TF_ASSIGN_OR_RETURN(Type out_type, ::xla::ConvertShapeToType<MemRefType>(
target_shape, builder_));
Value byte_shift =
builder_.create<ConstantIndexOp>(builder_.getUnknownLoc(), 0);
if (slice_view.getType() != out_type)
slice_view = builder_.create<ViewOp>(builder_.getUnknownLoc(), out_type,
slice_view, llvm::None);
slice_view =
builder_.create<ViewOp>(builder_.getUnknownLoc(), out_type, slice_view,
byte_shift, /*sizes=*/ArrayRef<Value>{});
return slice_view;
}