Merge pull request #46219 from ROCmSoftwarePlatform:google_upstream_rocm_xla_conv_fix

PiperOrigin-RevId: 350551602
Change-Id: Idd671009b63261f3973a3ce15cc124562a825450
This commit is contained in:
TensorFlower Gardener 2021-01-07 07:08:10 -08:00
commit 714d3ed498

View File

@ -1122,6 +1122,8 @@ Status IrEmitterUnnested::EmitConvolutionThunkFromMlir(MlirEmitterInput input) {
descriptor.result_shape = apply_layout(TypeToShape(conv_result.getType()),
op.backend_config().result_layout());
descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers());
descriptor.scratch_size =
input.extra_slice->shape.tuple_shapes(1).dimensions(0);
mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue();
mlir::DenseIntElementsAttr padding = op.padding().getValue();
mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue();