Remove dead helper function TransposeInputForGroupConvolutionBackpropFilter
PiperOrigin-RevId: 352078583 Change-Id: I5f66d4db72718761364d7ee9336dbd27afe3801d
This commit is contained in:
parent
1c729468d6
commit
ed73ad7d46
@ -85,32 +85,6 @@ xla::XlaOp TransposeFilterForGroupConvolutionBackpropInput(
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the transposed input for use in BackpropFilter of group convolution.
|
||||
xla::XlaOp TransposeInputForGroupConvolutionBackpropFilter(
|
||||
xla::XlaOp input, const xla::Shape& input_shape, int64 num_groups,
|
||||
int batch_dim, int depth_dim) {
|
||||
// 1. Reshape the depth_dim C into [G, C/G]
|
||||
int num_dims = input_shape.dimensions_size();
|
||||
std::vector<int64> reshape_dims = xla::SpanToVector(input_shape.dimensions());
|
||||
reshape_dims[depth_dim] = reshape_dims[depth_dim] / num_groups;
|
||||
reshape_dims.insert(reshape_dims.begin() + depth_dim, num_groups);
|
||||
xla::XlaOp result = xla::Reshape(input, reshape_dims);
|
||||
|
||||
// 2. Transpose G to the axis before N, e.g.: [G, N, H, W, C/G]
|
||||
std::vector<int64> transpose_dims(num_dims + 1);
|
||||
std::iota(transpose_dims.begin(), transpose_dims.end(),
|
||||
0); // e.g.: [0, 1, 2, 3, 4] -> [N, H, W, G, C/G]
|
||||
transpose_dims.erase(transpose_dims.begin() + depth_dim);
|
||||
transpose_dims.insert(
|
||||
transpose_dims.begin() + batch_dim,
|
||||
depth_dim); // e.g.: [3, 0, 1, 2, 4] -> [G, N, H, W, C/G]
|
||||
result = xla::Transpose(result, transpose_dims);
|
||||
|
||||
// 3. Merge [G, N] to [G*N]
|
||||
result = xla::Collapse(result, {batch_dim, batch_dim + 1});
|
||||
return result;
|
||||
}
|
||||
|
||||
// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
|
||||
// build a depthwise convolution.
|
||||
xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,
|
||||
|
Loading…
x
Reference in New Issue
Block a user