[ROCm] Fix for breakage in XLA Conv Op functionality

The following commit breaks Conv Op functionality (in the XLA backend) for ROCm platform.

8684c6b2e9

The cause seems to be that the `scratch_size` field in the new `GpuConvDescriptor` is not getting correctly populated in the new MLIR path. It is being used correctly in the convolution runner code.

declaration:

8684c6b2e9 (diff-6453912dbc4ee715a56da9d7b218b52795dea2aa631a482101fc6d58c573d9ccR122-R135)

use (get access) in conv runner:

8684c6b2e9 (diff-a01181d08b28a9c7432f22439622f16725126184283a73822c70b2151098a8adR277)

set access in non-MLIR(?) based path:

8684c6b2e9 (diff-a01181d08b28a9c7432f22439622f16725126184283a73822c70b2151098a8adR450)

This commit merely adds the missing "set" in the MLIR based path
This commit is contained in:
Deven Desai 2021-01-06 13:41:56 +00:00
parent 2f4a5dffed
commit e42a2b8dc9

View File

@ -1122,6 +1122,7 @@ Status IrEmitterUnnested::EmitConvolutionThunkFromMlir(MlirEmitterInput input) {
descriptor.result_shape = apply_layout(TypeToShape(conv_result.getType()),
op.backend_config().result_layout());
descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers());
descriptor.scratch_size = input.extra_slice->shape.tuple_shapes(1).dimensions(0);
mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue();
mlir::DenseIntElementsAttr padding = op.padding().getValue();
mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue();