Make comments consistent on GroupedFilterShapeForDepthwiseConvolution

PiperOrigin-RevId: 323704723
Change-Id: I00d6fb12ea053a6a8826a976833d752718980e34
This commit is contained in:
Sean Silva 2020-07-28 19:40:25 -07:00 committed by TensorFlower Gardener
parent 966bbfa4c5
commit d684ef2f40

View File

@ -44,7 +44,7 @@ namespace tensorflow {
namespace {
// Returns the expanded size of a filter used for depthwise convolution.
// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
// If `shape` is [H, W, ..., M, N] returns [H, W, ..., 1, M*N].
xla::Shape GroupedFilterShapeForDepthwiseConvolution(
const xla::Shape& filter_shape) {
int64 input_feature_dim = filter_shape.dimensions_size() - 2;
@ -52,7 +52,7 @@ xla::Shape GroupedFilterShapeForDepthwiseConvolution(
int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim);
int64 input_feature = filter_shape.dimensions(input_feature_dim);
// Create a [H, W, ..., 1, N*M] reshape of the filter.
// Create a [H, W, ..., 1, M*N] reshape of the filter.
xla::Shape grouped_filter_shape = filter_shape;
grouped_filter_shape.set_dimensions(input_feature_dim, 1);
grouped_filter_shape.set_dimensions(output_feature_dim,