diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 03e4daa4d80..b26b241046f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1122,6 +1122,8 @@ Status IrEmitterUnnested::EmitConvolutionThunkFromMlir(MlirEmitterInput input) { descriptor.result_shape = apply_layout(TypeToShape(conv_result.getType()), op.backend_config().result_layout()); descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers()); + descriptor.scratch_size = + input.extra_slice->shape.tuple_shapes(1).dimensions(0); mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue(); mlir::DenseIntElementsAttr padding = op.padding().getValue(); mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue();