[ROCm] Fix for breakage in XLA Conv Op functionality
The following commit breaks Conv Op functionality (in the XLA backend) for ROCm platform.8684c6b2e9
The cause seems to be that the `scratch_size` field in the new `GpuConvDescriptor` is not getting correctly populated in the new MLIR path. It is being used correctly in the convolution runner code. declaration:8684c6b2e9 (diff-6453912dbc4ee715a56da9d7b218b52795dea2aa631a482101fc6d58c573d9ccR122-R135)
use (get access) in conv runner:8684c6b2e9 (diff-a01181d08b28a9c7432f22439622f16725126184283a73822c70b2151098a8adR277)
set access in non-MLIR(?) based path:8684c6b2e9 (diff-a01181d08b28a9c7432f22439622f16725126184283a73822c70b2151098a8adR450)
This commit merely adds the missing "set" in the MLIR based path
This commit is contained in:
parent
2f4a5dffed
commit
e42a2b8dc9
@ -1122,6 +1122,7 @@ Status IrEmitterUnnested::EmitConvolutionThunkFromMlir(MlirEmitterInput input) {
|
||||
descriptor.result_shape = apply_layout(TypeToShape(conv_result.getType()),
|
||||
op.backend_config().result_layout());
|
||||
descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers());
|
||||
descriptor.scratch_size = input.extra_slice->shape.tuple_shapes(1).dimensions(0);
|
||||
mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue();
|
||||
mlir::DenseIntElementsAttr padding = op.padding().getValue();
|
||||
mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue();
|
||||
|
Loading…
x
Reference in New Issue
Block a user