diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py
index a49985f0446..0f0ea50fde9 100644
--- a/tensorflow/compiler/tests/depthwise_conv_op_test.py
+++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py
@@ -68,21 +68,21 @@ def ConfigsToTest():
     Tuple (input_size, filter_size, out_size, stride, padding), the depthwise
     convolution parameters.
   """
-  input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 9, 27, 8],
-                 [4, 31, 31, 7], [4, 35, 35, 2], [4, 147, 147, 2],
-                 [3, 299, 299, 3], [5, 183, 183, 1]]
-  filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [3, 3, 8, 1],
-                  [3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8], [2, 2, 3,
-                                                             8], [5, 5, 1, 2]]
-  out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 9, 27, 8],
-               [4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16],
+  input_sizes = [[4, 5, 5, 48], [2, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48],
+                 [4, 9, 27, 8], [4, 31, 31, 7], [4, 35, 35, 2],
+                 [4, 147, 147, 2], [3, 299, 299, 3], [5, 183, 183, 1]]
+  filter_sizes = [[1, 1, 48, 2], [2, 2, 48, 8], [1, 3, 84, 1], [3, 1, 48, 4],
+                  [3, 3, 8, 1], [3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8],
+                  [2, 2, 3, 8], [5, 5, 1, 2]]
+  out_sizes = [[4, 5, 5, 96], [2, 5, 5, 384], [4, 8, 8, 84], [4, 17, 17, 192],
+               [4, 9, 27, 8], [4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16],
                [3, 150, 150, 24], [5, 92, 92, 2]]
-  strides = [1, 1, 1, 1, 1, 1, 3, 2, 2]
+  strides = [1, 1, 1, 1, 1, 1, 1, 3, 2, 2]
   # pylint: disable=invalid-name
   VALID = "VALID"
   SAME = "SAME"
   # pylint: enable=invalid-name
-  paddings = [SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME]
+  paddings = [SAME, SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME]
   for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
                            paddings):
     yield i, f, o, s, p
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
index 4f79ce109fb..dda0d79337a 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -512,22 +512,26 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
         filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims),
         feature_group_count = in_depth / filter_in_depth;
 
+  // In the case of depthwise convolutions, the computation can be done by the
+  // batch_group_count parameter.
+  bool use_batch_group_count = in_depth > 1 && in_depth == filter_in_depth &&
+                               (feature_group_count != 1 || attrs.depthwise);
+
+  if (use_batch_group_count) {
+    feature_group_count = 1;
+  }
+
   // The activations (inputs) form the LHS of the convolution.
   // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
   // For the gradient computation, we need to:
   // 1. In the case of group convolution, move the num_groups dimension before
   // the batch dimension
   // 2. Swap the roles of the batch and feature dimensions.
-  if (feature_group_count != 1 && !attrs.depthwise) {
+  if (!use_batch_group_count && feature_group_count != 1 && !attrs.depthwise) {
     activations = TransposeInputForGroupConvolutionBackpropFilter(
         activations, input_shape, feature_group_count, n_dim, c_dim);
   }
 
-  // In the case of depthwise convolution with no multiplier,
-  // the computation can be done by the batch_group_count parameter.
-  bool use_batch_group_count =
-      filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise;
-
   std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
   std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
   std::vector<int64> window_strides(attrs.num_spatial_dims);
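Note on the two changes above: the new test configuration is the first whose depthwise channel multiplier exceeds one (48 input channels times multiplier 8 gives 384 output channels), and the rewritten use_batch_group_count condition no longer carries the removed comment's "no multiplier" restriction. A minimal sketch of the shape arithmetic every config tuple must satisfy (hypothetical helper, plain Python, not part of the patch):

    def check_depthwise_config(input_size, filter_size, out_size, stride,
                               padding):
      """Sanity-checks one ConfigsToTest tuple against depthwise_conv2d
      shape rules: input [n, h, w, c], filter [fh, fw, c, m],
      output [n, oh, ow, c * m]."""
      n, h, w, c = input_size
      fh, fw, fc, m = filter_size
      assert fc == c, "filter in_channels must match input channels"
      if padding == "SAME":
        oh, ow = -(-h // stride), -(-w // stride)  # ceil(h / stride)
      else:  # VALID: ceil((h - fh + 1) / stride)
        oh, ow = -(-(h - fh + 1) // stride), -(-(w - fw + 1) // stride)
      assert out_size == [n, oh, ow, c * m], (out_size, [n, oh, ow, c * m])

    # The newly added config: multiplier 8 expands 48 channels to 384.
    check_depthwise_config([2, 5, 5, 48], [2, 2, 48, 8], [2, 5, 5, 384],
                           1, "SAME")
    # The stride-3 VALID config: ceil((147 - 3 + 1) / 3) = 49.
    check_depthwise_config([4, 147, 147, 2], [3, 3, 2, 8], [4, 49, 49, 16],
                           3, "VALID")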
diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc
index f942d6768df..06bcd773f44 100644
--- a/tensorflow/compiler/xla/service/convolution_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc
@@ -218,14 +218,127 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
   int64 input_batch_dimension = dim_numbers.input_batch_dimension();
   int64 output_batch_dimension = dim_numbers.output_batch_dimension();
+  const int64 kernel_output_feature_dimension =
+      dim_numbers.kernel_output_feature_dimension();
   int64 output_feature_dimension = dim_numbers.output_feature_dimension();
 
   int64 input_batch = activation->shape().dimensions(input_batch_dimension);
+  const int64 output_feature =
+      filter->shape().dimensions(kernel_output_feature_dimension);
+
+  VLOG(2) << "is_cost_viable_ " << is_cost_viable_(convolution);
+  const bool cost_too_high = !is_cost_viable_(convolution);
+
+  if (output_feature != batch_group_count) {
+    const int64 group_size = output_feature / batch_group_count;
+
+    VLOG(2) << "Need to insert a spatial dimension in activations and in the "
+               "kernel to deal with backprop of grouped convolutions "
+            << " group size " << group_size;
+
+    // Add spatial dimension to the activation, and reshape.
+    Shape reshaped_activation_shape = activation->shape();
+    ShapeUtil::AppendMajorDimension(1, &reshaped_activation_shape);
+    const int64 new_spatial_dim =
+        reshaped_activation_shape.dimensions().size() - 1;
+
+    activation = add(
+        HloInstruction::CreateReshape(reshaped_activation_shape, activation));
+
+    // Insert new spatial dimension after the output feature dimension on the
+    // kernel.
+    auto dims = filter->shape().dimensions();
+    std::vector<int64> new_dims;
+    for (int i = 0; i < dims.size(); i++) {
+      if (i == kernel_output_feature_dimension) {
+        new_dims.push_back(batch_group_count);
+        new_dims.push_back(group_size);
+      } else {
+        new_dims.push_back(dims[i]);
+      }
+    }
+
+    Shape reshaped_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
+        filter->shape().element_type(), new_dims);
+
+    filter = add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));
+
+    Shape new_output_shape = convolution->shape();
+    ShapeUtil::AppendMajorDimension(1, &new_output_shape);
+
+    // Edit convolution dimension numbers. Note that kernel_input_feature_dim
+    // now becomes a spatial dimension, and the newly added dimension of size
+    // 1 is the new kernel_input_feature_dim.
+    dim_numbers.add_input_spatial_dimensions(new_spatial_dim);
+
+    // Update spatial dimension numbers if they show up after the newly added
+    // spatial dimension.
+    for (auto& d : *dim_numbers.mutable_kernel_spatial_dimensions()) {
+      if (d > kernel_output_feature_dimension) {
+        ++d;
+      }
+    }
+
+    // Same for input feature dimension.
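+    // (The filter reshape split the output feature dimension into
+    // (batch_group_count, group_size), so every kernel dimension numbered
+    // above kernel_output_feature_dimension shifts up by one.)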
+    if (dim_numbers.kernel_input_feature_dimension() >
+        kernel_output_feature_dimension) {
+      dim_numbers.set_kernel_input_feature_dimension(
+          dim_numbers.kernel_input_feature_dimension() + 1);
+    }
+
+    dim_numbers.add_kernel_spatial_dimensions(kernel_output_feature_dimension +
+                                              1);
+
+    dim_numbers.add_output_spatial_dimensions(output_batch_dimension);
+
+    dim_numbers.set_output_batch_dimension(new_spatial_dim);
+
+    // Add window for the new spatial dimension.
+    Window new_window = convolution->window();
+    auto* dim = new_window.add_dimensions();
+    dim->set_window_dilation(1);
+    dim->set_base_dilation(1);
+    dim->set_stride(1);
+    dim->set_size(group_size);
+    dim->set_padding_high(group_size - 1);
+    dim->set_padding_low(group_size - 1);
+    dim->set_window_reversal(false);
+
+    auto new_convolution = add(HloInstruction::CreateConvolve(
+        new_output_shape, activation, filter, /*feature_group_count=*/1,
+        batch_group_count, new_window, dim_numbers,
+        convolution->precision_config()));
+
+    VLOG(2) << "New convolution " << new_convolution->ToString();
+
+    // This reversal is not done via set_window_reversal because GPUs don't
+    // support it.
+    auto rev = add(HloInstruction::CreateReverse(
+        new_output_shape, new_convolution, {output_batch_dimension}));
+
+    // Delete the extra spatial dimension, and reshape.
+    Shape reshaped_convolution_shape =
+        ShapeUtil::DeleteDimension(new_spatial_dim, rev->shape());
+    auto reshaped_convolution =
+        HloInstruction::CreateReshape(reshaped_convolution_shape, rev);
+
+    VLOG(2) << "Reshaped convolution " << reshaped_convolution->ToString();
+
+    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
+        convolution, std::move(reshaped_convolution)));
+
+    changed_ = true;
+
+    convolution = new_convolution;
+    dim_numbers = convolution->convolution_dimension_numbers();
+    output_batch_dimension = new_spatial_dim;
+  }
+
   // We are not yet supporting batch_group of sizes greater than 1.
   TF_RET_CHECK(input_batch == batch_group_count);
 
-  if (!is_cost_viable_(convolution) || filter_expansion_) {
+  if (cost_too_high || filter_expansion_) {
     // We first obtain the expanded the filter (which is the convolution
     // output). The batch dimension is the expanded one (which originally
     // represents kernel input feature dimension). We mask the filter to zero
@@ -238,11 +351,17 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
     auto expanded_filter_shape = ExpandedFilterShape(
         convolution->shape(), batch_group_count, output_batch_dimension);
 
+    VLOG(2) << "output_batch_dimension " << output_batch_dimension;
+    VLOG(2) << "New output shape of convolution "
+            << expanded_filter_shape.ToString();
+
     auto new_convolution = add(HloInstruction::CreateConvolve(
         expanded_filter_shape, activation, filter, /*feature_group_count=*/1,
         /*batch_group_count=*/1, convolution->window(), dim_numbers,
         convolution->precision_config()));
 
+    VLOG(2) << "Expanded convolution " << new_convolution->ToString();
+
     auto zero = add(HloInstruction::CreateConstant(
         LiteralUtil::Zero(expanded_filter_shape.element_type())));
     auto zero_filter =
@@ -354,6 +473,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
     changed_ = false;
     return Status::OK();
   }
+  VLOG(2) << "is_cost_viable_ " << is_cost_viable_(convolution);
 
   // We want to repeat 'filter' in the 'input_feature_dim' dimension
   // 'group_count' times.
   if (!is_cost_viable_(convolution) || filter_expansion_) {
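The heart of the new HandleBatchGroupCount branch is the filter reshape that splits the kernel output feature dimension into (batch_group_count, group_size); the rest is dimension renumbering and window bookkeeping for the extra spatial dimension. A minimal Python sketch of that index math (hypothetical helper mirroring the new_dims loop above, not part of the patch):

    def split_output_feature_dim(filter_dims, kernel_output_feature_dimension,
                                 batch_group_count):
      """Splits the kernel output feature dimension into
      (batch_group_count, group_size), as the reshape in the patch does."""
      output_feature = filter_dims[kernel_output_feature_dimension]
      assert output_feature % batch_group_count == 0
      group_size = output_feature // batch_group_count
      new_dims = []
      for i, d in enumerate(filter_dims):
        if i == kernel_output_feature_dimension:
          new_dims += [batch_group_count, group_size]
        else:
          new_dims.append(d)
      return new_dims

    # A [fh, fw, 1, 384] backprop filter with 48 batch groups: dimension 3
    # becomes (48, 8), and any dimension numbered above 3 shifts up by one.
    assert split_output_feature_dim([2, 2, 1, 384], 3, 48) == [2, 2, 1, 48, 8]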
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 283959e73ca..1d7f9faea38 100755
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1116,6 +1116,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:call_inliner",
         "//tensorflow/compiler/xla/service:conditional_simplifier",
+        "//tensorflow/compiler/xla/service:convolution_group_converter",
        "//tensorflow/compiler/xla/service:depthwise_convolution_converter",
         "//tensorflow/compiler/xla/service:dot_decomposer",
         "//tensorflow/compiler/xla/service:dump",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 30b204e6fd5..6709a51b849 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/call_inliner.h"
 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
+#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
 #include "tensorflow/compiler/xla/service/depthwise_convolution_converter.h"
 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
 #include "tensorflow/compiler/xla/service/dump.h"
@@ -138,11 +139,28 @@ Status GpuCompiler::OptimizeHloModule(
   // TODO(b/64094172): make Call work on GPU instead of inlining.
   pipeline.AddPass<CallInliner>();
+
+  pipeline.AddPass();
+
+  // We use the ConvolutionGroupConverter to convert backprops of filter
+  // grouped convolutions into non-grouped equivalents.
+  auto batch_group_cost_model = [](HloInstruction* conv) {
+    auto dim_numbers = conv->convolution_dimension_numbers();
+    const int64 input_batch_size = conv->operand(0)->shape().dimensions(
+        dim_numbers.input_batch_dimension());
+    return conv->batch_group_count() != input_batch_size;
+  };
+
+  pipeline.AddPass<ConvolutionGroupConverter>(
+      batch_group_cost_model,
+      /*convert_batch_groups_only=*/true,
+      /*canonicalize_depthwise_filter=*/false);
+
   auto cost_model = [](HloInstruction* conv) {
     // We need a cost model for GPUs. Currently, do nothing.
     return false;
   };
-  pipeline.AddPass();
+  pipeline.AddPass<DepthwiseConvolutionConverter>(cost_model);
 
   // Expand the sort op to support stable sorting if required.
   pipeline.AddPass<StableSortExpander>();
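For reference, batch_group_cost_model feeds the pass's is_cost_viable_ hook, and HandleBatchGroupCount falls back to the filter-expansion rewrite when that hook returns false (the cost_too_high path) or when filter expansion is forced; so on GPU a batch-grouped convolution is expanded exactly when its batch_group_count equals the input batch size. A small sketch of that decision (hypothetical helper, not part of the patch):

    def gpu_converts_batch_groups(batch_group_count, input_batch_size,
                                  filter_expansion=False):
      """Mirrors batch_group_cost_model combined with the cost_too_high
      test in ConvolutionVisitor::HandleBatchGroupCount."""
      is_cost_viable = batch_group_count != input_batch_size
      return (not is_cost_viable) or filter_expansion

    # A depthwise backprop-filter convolution whose batch_group_count equals
    # the input batch is rewritten into a non-grouped equivalent.
    assert gpu_converts_batch_groups(32, 32)
    assert not gpu_converts_batch_groups(8, 32)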
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index ec6a97e928a..4ce34ea4585 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1720,7 +1720,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
   const int64 kernel_output_features =
       rhs.dimensions(dnums.kernel_output_feature_dimension());
 
-  if (batch_group_count > 1 && kernel_output_features != batch_group_count) {
+  if (batch_group_count > 1 &&
+      kernel_output_features % batch_group_count != 0) {
     return InvalidArgument(
         "Expected output feature dimension size (value %d) to be equal to "
         "batch group count %d; got (%s, %s)\n"
@@ -1759,7 +1760,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
         dnums.DebugString());
   }
 
-  if (input_batch % batch_group_count > 0) {
+  if (input_batch % batch_group_count != 0) {
     return InvalidArgument(
         "Expected input batch dimension (value %d) to be divisible by "
         "batch_group_count (value %d); "
@@ -1793,6 +1794,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
   std::vector<int64> dimensions(num_dims);
   dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count;
   dimensions[dnums.output_feature_dimension()] = kernel_output_features;
+
+  if (batch_group_count > 1) {
+    dimensions[dnums.output_batch_dimension()] =
+        kernel_output_features / batch_group_count;
+    dimensions[dnums.output_feature_dimension()] = batch_group_count;
+  }
+
   for (int i = 0; i < num_spatial_dims; ++i) {
     dimensions[dnums.output_spatial_dimensions(i)] =
         window_output_shape.dimensions(i);
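The shape-inference changes above relax the validity checks (kernel output features need only be a multiple of batch_group_count, and the input batch must divide it evenly) and then rearrange the output dimensions when batch groups are present. A compact restatement of the new rule (hypothetical helper, not part of the patch):

    def batch_grouped_output_dims(input_batch, kernel_output_features,
                                  batch_group_count):
      """Mirrors the added logic in shape_inference.cc: for
      batch_group_count > 1, the output batch dimension carries
      kernel_output_features / batch_group_count and the output feature
      dimension carries the group count."""
      assert kernel_output_features % batch_group_count == 0
      assert input_batch % batch_group_count == 0
      out_batch = input_batch // batch_group_count
      out_feature = kernel_output_features
      if batch_group_count > 1:
        out_batch = kernel_output_features // batch_group_count
        out_feature = batch_group_count
      return out_batch, out_feature

    # 384 kernel output features with 48 batch groups: the output carries 8
    # rows in the batch dimension and 48 in the feature dimension.
    assert batch_grouped_output_dims(48, 384, 48) == (8, 48)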