From e42a2b8dc9e172023759914304d5f451c7dc3950 Mon Sep 17 00:00:00 2001
From: Deven Desai
Date: Wed, 6 Jan 2021 13:41:56 +0000
Subject: [PATCH] [ROCm] Fix for breakage in XLA Conv Op functionality

The following commit breaks Conv Op functionality (in the XLA backend) for ROCm platform.

https://github.com/tensorflow/tensorflow/commit/8684c6b2e95601542c6c5c006bde5dd50f589a50

The cause seems to be that the `scratch_size` field in the new `GpuConvDescriptor` is not getting correctly populated in the new MLIR path. It is being used correctly in the convolution runner code.

declaration:
https://github.com/tensorflow/tensorflow/commit/8684c6b2e95601542c6c5c006bde5dd50f589a50#diff-6453912dbc4ee715a56da9d7b218b52795dea2aa631a482101fc6d58c573d9ccR122-R135

use (get access) in conv runner:
https://github.com/tensorflow/tensorflow/commit/8684c6b2e95601542c6c5c006bde5dd50f589a50#diff-a01181d08b28a9c7432f22439622f16725126184283a73822c70b2151098a8adR277

set access in non-MLIR(?) based path:
https://github.com/tensorflow/tensorflow/commit/8684c6b2e95601542c6c5c006bde5dd50f589a50#diff-a01181d08b28a9c7432f22439622f16725126184283a73822c70b2151098a8adR450

This commit merely adds the missing "set" in the MLIR based path
---
 tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index c860e565baf..a110255de03 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1122,6 +1122,7 @@ Status IrEmitterUnnested::EmitConvolutionThunkFromMlir(MlirEmitterInput input) {
   descriptor.result_shape = apply_layout(TypeToShape(conv_result.getType()),
                                          op.backend_config().result_layout());
   descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers());
+  descriptor.scratch_size = input.extra_slice->shape.tuple_shapes(1).dimensions(0);
   mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue();
   mlir::DenseIntElementsAttr padding = op.padding().getValue();
   mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue();