Remove dead helper function TransposeInputForGroupConvolutionBackpropFilter

PiperOrigin-RevId: 352078583
Change-Id: I5f66d4db72718761364d7ee9336dbd27afe3801d
This commit is contained in:
Smit Hinsu 2021-01-15 13:57:42 -08:00 committed by TensorFlower Gardener
parent 1c729468d6
commit ed73ad7d46

View File

@ -85,32 +85,6 @@ xla::XlaOp TransposeFilterForGroupConvolutionBackpropInput(
return result;
}
// Returns the transposed input for use in BackpropFilter of group convolution.
xla::XlaOp TransposeInputForGroupConvolutionBackpropFilter(
xla::XlaOp input, const xla::Shape& input_shape, int64 num_groups,
int batch_dim, int depth_dim) {
// 1. Reshape the depth_dim C into [G, C/G]
int num_dims = input_shape.dimensions_size();
std::vector<int64> reshape_dims = xla::SpanToVector(input_shape.dimensions());
reshape_dims[depth_dim] = reshape_dims[depth_dim] / num_groups;
reshape_dims.insert(reshape_dims.begin() + depth_dim, num_groups);
xla::XlaOp result = xla::Reshape(input, reshape_dims);
// 2. Transpose G to the axis before N, e.g.: [G, N, H, W, C/G]
std::vector<int64> transpose_dims(num_dims + 1);
std::iota(transpose_dims.begin(), transpose_dims.end(),
0); // e.g.: [0, 1, 2, 3, 4] -> [N, H, W, G, C/G]
transpose_dims.erase(transpose_dims.begin() + depth_dim);
transpose_dims.insert(
transpose_dims.begin() + batch_dim,
depth_dim); // e.g.: [3, 0, 1, 2, 4] -> [G, N, H, W, C/G]
result = xla::Transpose(result, transpose_dims);
// 3. Merge [G, N] to [G*N]
result = xla::Collapse(result, {batch_dim, batch_dim + 1});
return result;
}
// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
// build a depthwise convolution.
xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,