Make comments consistent on GroupedFilterShapeForDepthwiseConvolution
PiperOrigin-RevId: 323704723 Change-Id: I00d6fb12ea053a6a8826a976833d752718980e34
This commit is contained in:
parent
966bbfa4c5
commit
d684ef2f40
@ -44,7 +44,7 @@ namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Returns the expanded size of a filter used for depthwise convolution.
|
||||
// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
|
||||
// If `shape` is [H, W, ..., M, N] returns [H, W, ..., 1, M*N].
|
||||
xla::Shape GroupedFilterShapeForDepthwiseConvolution(
|
||||
const xla::Shape& filter_shape) {
|
||||
int64 input_feature_dim = filter_shape.dimensions_size() - 2;
|
||||
@ -52,7 +52,7 @@ xla::Shape GroupedFilterShapeForDepthwiseConvolution(
|
||||
int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim);
|
||||
int64 input_feature = filter_shape.dimensions(input_feature_dim);
|
||||
|
||||
// Create a [H, W, ..., 1, N*M] reshape of the filter.
|
||||
// Create a [H, W, ..., 1, M*N] reshape of the filter.
|
||||
xla::Shape grouped_filter_shape = filter_shape;
|
||||
grouped_filter_shape.set_dimensions(input_feature_dim, 1);
|
||||
grouped_filter_shape.set_dimensions(output_feature_dim,
|
||||
|
Loading…
x
Reference in New Issue
Block a user