Move conv grad shape related utils to a separate lib
This allows targets to use the shape utils without depending on the kernels. * Switch to the new library in targets that don't use the kernels * Include the new header in the conv kernels PiperOrigin-RevId: 277570525 Change-Id: I133615d6682d17bb5f0cf49e4171dcbf5529120b
This commit is contained in:
parent
f04053d46d
commit
5cedcf31d6
@ -109,7 +109,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/mlir/tensorflow",
|
"//tensorflow/compiler/mlir/tensorflow",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
|
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core/kernels:conv_ops",
|
"//tensorflow/core/kernels:conv_grad_shape_utils",
|
||||||
"@llvm//:support",
|
"@llvm//:support",
|
||||||
"@local_config_mlir//:Analysis",
|
"@local_config_mlir//:Analysis",
|
||||||
"@local_config_mlir//:IR",
|
"@local_config_mlir//:IR",
|
||||||
|
|||||||
@ -42,7 +42,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/kernels/conv_grad_ops.h"
|
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
|
||||||
#include "tensorflow/core/util/padding.h"
|
#include "tensorflow/core/util/padding.h"
|
||||||
#include "tensorflow/core/util/tensor_format.h"
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
|
|
||||||
@ -1430,7 +1430,7 @@ class ConvertConv2DBackpropInputOp
|
|||||||
int64_t filter_in_depth = filter_shape[num_spatial_dims];
|
int64_t filter_in_depth = filter_shape[num_spatial_dims];
|
||||||
int64_t feature_group_count = in_depth / filter_in_depth;
|
int64_t feature_group_count = in_depth / filter_in_depth;
|
||||||
|
|
||||||
// Reuse dimension computation logic from conv_grad_ops.cc.
|
// Reuse dimension computation logic from conv_grad_shape_utils.cc.
|
||||||
tensorflow::ConvBackpropDimensions dims;
|
tensorflow::ConvBackpropDimensions dims;
|
||||||
if (!tensorflow::ConvBackpropComputeDimensionsV2(
|
if (!tensorflow::ConvBackpropComputeDimensionsV2(
|
||||||
"", num_spatial_dims, ToTensorShape<int>(input_shape),
|
"", num_spatial_dims, ToTensorShape<int>(input_shape),
|
||||||
@ -1569,7 +1569,7 @@ class ConvertConv2DBackpropFilterOp
|
|||||||
llvm::to_vector<4>(filter_shape_attr.getValues<int32_t>());
|
llvm::to_vector<4>(filter_shape_attr.getValues<int32_t>());
|
||||||
if (filter_shape.size() != num_dims) return matchFailure();
|
if (filter_shape.size() != num_dims) return matchFailure();
|
||||||
|
|
||||||
// Reuse dimension computation logic from conv_grad_ops.cc.
|
// Reuse dimension computation logic from conv_grad_shape_utils.cc.
|
||||||
tensorflow::ConvBackpropDimensions dims;
|
tensorflow::ConvBackpropDimensions dims;
|
||||||
if (!tensorflow::ConvBackpropComputeDimensionsV2(
|
if (!tensorflow::ConvBackpropComputeDimensionsV2(
|
||||||
"", num_spatial_dims, ToTensorShape<int64_t>(input_shape),
|
"", num_spatial_dims, ToTensorShape<int64_t>(input_shape),
|
||||||
|
|||||||
@ -255,7 +255,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/client/lib:constants",
|
"//tensorflow/compiler/xla/client/lib:constants",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:framework_bounds_check",
|
"//tensorflow/core:framework_bounds_check",
|
||||||
"//tensorflow/core/kernels:conv_ops",
|
"//tensorflow/core/kernels:conv_grad_shape_utils",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -36,7 +36,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor_slice.h"
|
#include "tensorflow/core/framework/tensor_slice.h"
|
||||||
#include "tensorflow/core/kernels/conv_grad_ops.h"
|
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
|
||||||
#include "tensorflow/core/util/padding.h"
|
#include "tensorflow/core/util/padding.h"
|
||||||
#include "tensorflow/core/util/tensor_format.h"
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
|
|
||||||
@ -404,7 +404,7 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
|
|||||||
xla::Shape expanded_filter_shape =
|
xla::Shape expanded_filter_shape =
|
||||||
attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
|
attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
|
||||||
: filter_shape;
|
: filter_shape;
|
||||||
// Reuse dimension computation logic from conv_grad_ops.cc.
|
// Reuse dimension computation logic from conv_grad_shape_utils.cc.
|
||||||
ConvBackpropDimensions dims;
|
ConvBackpropDimensions dims;
|
||||||
TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
|
TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
|
||||||
type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
|
type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
|
||||||
@ -413,7 +413,7 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
|
|||||||
|
|
||||||
// The input gradients are computed by a convolution of the output
|
// The input gradients are computed by a convolution of the output
|
||||||
// gradients and the filter, with some appropriate padding. See the
|
// gradients and the filter, with some appropriate padding. See the
|
||||||
// comment at the top of conv_grad_ops.h for details.
|
// comment at the top of conv_grad_shape_utils.h for details.
|
||||||
|
|
||||||
xla::ConvolutionDimensionNumbers dnums;
|
xla::ConvolutionDimensionNumbers dnums;
|
||||||
dnums.set_input_batch_dimension(batch_dim);
|
dnums.set_input_batch_dimension(batch_dim);
|
||||||
@ -487,11 +487,11 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
|
|||||||
const xla::Shape expanded_filter_shape =
|
const xla::Shape expanded_filter_shape =
|
||||||
attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
|
attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
|
||||||
: filter_shape;
|
: filter_shape;
|
||||||
// Reuse dimension computation logic from conv_grad_ops.cc.
|
// Reuse dimension computation logic from conv_grad_shape_utils.cc.
|
||||||
ConvBackpropDimensions dims;
|
ConvBackpropDimensions dims;
|
||||||
// The filter gradients are computed by a convolution of the input
|
// The filter gradients are computed by a convolution of the input
|
||||||
// activations and the output gradients, with some appropriate padding.
|
// activations and the output gradients, with some appropriate padding.
|
||||||
// See the comment at the top of conv_grad_ops.h for details.
|
// See the comment at the top of conv_grad_shape_utils.h for details.
|
||||||
xla::ConvolutionDimensionNumbers dnums;
|
xla::ConvolutionDimensionNumbers dnums;
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
|
TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
|
||||||
|
|||||||
@ -4313,7 +4313,6 @@ tf_kernel_library(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"conv_grad_filter_ops.cc",
|
"conv_grad_filter_ops.cc",
|
||||||
"conv_grad_input_ops.cc",
|
"conv_grad_input_ops.cc",
|
||||||
"conv_grad_ops.cc",
|
|
||||||
"conv_grad_ops_3d.cc",
|
"conv_grad_ops_3d.cc",
|
||||||
"deep_conv2d.cc",
|
"deep_conv2d.cc",
|
||||||
] + select({
|
] + select({
|
||||||
@ -4339,6 +4338,7 @@ tf_kernel_library(
|
|||||||
}),
|
}),
|
||||||
prefix = "conv_ops",
|
prefix = "conv_ops",
|
||||||
deps = [
|
deps = [
|
||||||
|
":conv_grad_shape_utils",
|
||||||
":bounds_check",
|
":bounds_check",
|
||||||
":conv_2d",
|
":conv_2d",
|
||||||
":conv_3d",
|
":conv_3d",
|
||||||
@ -4372,6 +4372,24 @@ tf_kernel_library(
|
|||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "conv_grad_shape_utils",
|
||||||
|
srcs = [
|
||||||
|
"conv_grad_shape_utils.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"conv_grad_shape_utils.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":ops_util",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core/lib/core:errors",
|
||||||
|
"//tensorflow/core/lib/core:stringpiece",
|
||||||
|
"//tensorflow/core/platform:logging",
|
||||||
|
"//tensorflow/core/platform:macros",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "depthwise_conv_op",
|
name = "depthwise_conv_op",
|
||||||
srcs = ["depthwise_conv_op.cc"],
|
srcs = ["depthwise_conv_op.cc"],
|
||||||
@ -6363,8 +6381,9 @@ filegroup(
|
|||||||
"conv_2d.h",
|
"conv_2d.h",
|
||||||
"conv_grad_filter_ops.cc",
|
"conv_grad_filter_ops.cc",
|
||||||
"conv_grad_input_ops.cc",
|
"conv_grad_input_ops.cc",
|
||||||
"conv_grad_ops.cc",
|
|
||||||
"conv_grad_ops.h",
|
"conv_grad_ops.h",
|
||||||
|
"conv_grad_shape_utils.h",
|
||||||
|
"conv_grad_shape_utils.cc",
|
||||||
"conv_ops.cc",
|
"conv_ops.cc",
|
||||||
"conv_ops_3d.cc",
|
"conv_ops_3d.cc",
|
||||||
"conv_ops_fused_double.cc",
|
"conv_ops_fused_double.cc",
|
||||||
|
|||||||
@ -29,6 +29,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor_slice.h"
|
#include "tensorflow/core/framework/tensor_slice.h"
|
||||||
#include "tensorflow/core/kernels/conv_2d.h"
|
#include "tensorflow/core/kernels/conv_2d.h"
|
||||||
#include "tensorflow/core/kernels/conv_grad_ops.h"
|
#include "tensorflow/core/kernels/conv_grad_ops.h"
|
||||||
|
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
|
||||||
#include "tensorflow/core/kernels/fill_functor.h"
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
|
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
|
||||||
#include "tensorflow/core/kernels/xsmm_conv2d.h"
|
#include "tensorflow/core/kernels/xsmm_conv2d.h"
|
||||||
|
|||||||
@ -32,6 +32,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor_slice.h"
|
#include "tensorflow/core/framework/tensor_slice.h"
|
||||||
#include "tensorflow/core/kernels/conv_2d.h"
|
#include "tensorflow/core/kernels/conv_2d.h"
|
||||||
#include "tensorflow/core/kernels/conv_grad_ops.h"
|
#include "tensorflow/core/kernels/conv_grad_ops.h"
|
||||||
|
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
|
||||||
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
|
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
|
||||||
#include "tensorflow/core/kernels/xsmm_conv2d.h"
|
#include "tensorflow/core/kernels/xsmm_conv2d.h"
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -161,8 +161,6 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
|
||||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
|
||||||
#include "tensorflow/core/util/padding.h"
|
#include "tensorflow/core/util/padding.h"
|
||||||
#include "tensorflow/core/util/tensor_format.h"
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
|
|
||||||
@ -212,66 +210,6 @@ struct LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T> {
|
|||||||
Tensor* filter_backprop, TensorFormat data_format);
|
Tensor* filter_backprop, TensorFormat data_format);
|
||||||
};
|
};
|
||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
// Information about a single spatial dimension for a convolution
|
|
||||||
// backpropagation.
|
|
||||||
struct ConvBackpropSpatialDimension {
|
|
||||||
int64 input_size;
|
|
||||||
int64 filter_size;
|
|
||||||
int64 output_size;
|
|
||||||
int64 stride;
|
|
||||||
int64 dilation;
|
|
||||||
|
|
||||||
// Output size after scaling by the stride.
|
|
||||||
int64 expanded_output_size;
|
|
||||||
|
|
||||||
// Number of padding elements to be added before/after this dimension of
|
|
||||||
// the input when computing Conv?DBackpropInput.
|
|
||||||
int64 pad_before, pad_after;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Computed dimensions for a backwards convolution.
|
|
||||||
struct ConvBackpropDimensions {
|
|
||||||
// Information about each spatial dimension.
|
|
||||||
gtl::InlinedVector<ConvBackpropSpatialDimension, 3> spatial_dims;
|
|
||||||
|
|
||||||
// Batch size.
|
|
||||||
int64 batch_size;
|
|
||||||
|
|
||||||
// Input and output feature depth.
|
|
||||||
int64 in_depth, out_depth;
|
|
||||||
|
|
||||||
// Convenience access methods for spatial dimensions properties.
|
|
||||||
int64 input_size(int dim) const { return spatial_dims[dim].input_size; }
|
|
||||||
int64 filter_size(int dim) const { return spatial_dims[dim].filter_size; }
|
|
||||||
int64 output_size(int dim) const { return spatial_dims[dim].output_size; }
|
|
||||||
int64 stride(int dim) const { return spatial_dims[dim].stride; }
|
|
||||||
int64 dilation(int dim) const { return spatial_dims[dim].dilation; }
|
|
||||||
|
|
||||||
// Compute padding for the given spatial dimension.
|
|
||||||
int SpatialPadding(const Padding& padding, int dim) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Common code between implementations of Conv?DBackpropInput and
|
|
||||||
// Conv?DBackpropFilter. Verifies that the dimensions all match, and computes
|
|
||||||
// sizes/padding for the spatial dimensions. Does not support explicit padding.
|
|
||||||
Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims,
|
|
||||||
const TensorShape& input_shape,
|
|
||||||
const TensorShape& filter_shape,
|
|
||||||
const TensorShape& out_backprop_shape,
|
|
||||||
const std::vector<int32>& strides,
|
|
||||||
Padding padding, TensorFormat data_format,
|
|
||||||
ConvBackpropDimensions* dims);
|
|
||||||
|
|
||||||
// The V2 version computes the same outputs with arbitrary dilation rate and
|
|
||||||
// supports explicit padding.
|
|
||||||
// TODO(b/67112639): Merge V2 versions and the original versions eventually.
|
|
||||||
Status ConvBackpropComputeDimensionsV2(
|
|
||||||
StringPiece label, int num_spatial_dims, const TensorShape& input_shape,
|
|
||||||
const TensorShape& filter_shape, const TensorShape& out_backprop_shape,
|
|
||||||
const gtl::ArraySlice<int32>& dilations, const std::vector<int32>& strides,
|
|
||||||
Padding padding, absl::Span<const int64> explicit_paddings,
|
|
||||||
TensorFormat data_format, ConvBackpropDimensions* dims);
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_OPS_H_
|
#endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_OPS_H_
|
||||||
|
|||||||
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/conv_2d.h"
|
#include "tensorflow/core/kernels/conv_2d.h"
|
||||||
#include "tensorflow/core/kernels/conv_3d.h"
|
#include "tensorflow/core/kernels/conv_3d.h"
|
||||||
#include "tensorflow/core/kernels/conv_grad_ops.h"
|
#include "tensorflow/core/kernels/conv_grad_ops.h"
|
||||||
|
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
|
||||||
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||||
#include "tensorflow/core/kernels/ops_util.h"
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
|||||||
@ -18,26 +18,21 @@ limitations under the License.
|
|||||||
#define USE_EIGEN_TENSOR
|
#define USE_EIGEN_TENSOR
|
||||||
#define EIGEN_USE_THREADS
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/conv_grad_ops.h"
|
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/numeric_op.h"
|
#include "tensorflow/core/framework/numeric_op.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor_slice.h"
|
|
||||||
#include "tensorflow/core/kernels/conv_2d.h"
|
|
||||||
#include "tensorflow/core/kernels/ops_util.h"
|
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/util/padding.h"
|
#include "tensorflow/core/util/padding.h"
|
||||||
#include "tensorflow/core/util/tensor_format.h"
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
#include "tensorflow/core/util/use_cudnn.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
88
tensorflow/core/kernels/conv_grad_shape_utils.h
Normal file
88
tensorflow/core/kernels/conv_grad_shape_utils.h
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_
|
||||||
|
#define TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
|
#include "tensorflow/core/util/padding.h"
|
||||||
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
// Information about a single spatial dimension for a convolution
|
||||||
|
// backpropagation.
|
||||||
|
struct ConvBackpropSpatialDimension {
|
||||||
|
int64 input_size;
|
||||||
|
int64 filter_size;
|
||||||
|
int64 output_size;
|
||||||
|
int64 stride;
|
||||||
|
int64 dilation;
|
||||||
|
|
||||||
|
// Output size after scaling by the stride.
|
||||||
|
int64 expanded_output_size;
|
||||||
|
|
||||||
|
// Number of padding elements to be added before/after this dimension of
|
||||||
|
// the input when computing Conv?DBackpropInput.
|
||||||
|
int64 pad_before, pad_after;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Computed dimensions for a backwards convolution.
|
||||||
|
struct ConvBackpropDimensions {
|
||||||
|
// Information about each spatial dimension.
|
||||||
|
gtl::InlinedVector<ConvBackpropSpatialDimension, 3> spatial_dims;
|
||||||
|
|
||||||
|
// Batch size.
|
||||||
|
int64 batch_size;
|
||||||
|
|
||||||
|
// Input and output feature depth.
|
||||||
|
int64 in_depth, out_depth;
|
||||||
|
|
||||||
|
// Convenience access methods for spatial dimensions properties.
|
||||||
|
int64 input_size(int dim) const { return spatial_dims[dim].input_size; }
|
||||||
|
int64 filter_size(int dim) const { return spatial_dims[dim].filter_size; }
|
||||||
|
int64 output_size(int dim) const { return spatial_dims[dim].output_size; }
|
||||||
|
int64 stride(int dim) const { return spatial_dims[dim].stride; }
|
||||||
|
int64 dilation(int dim) const { return spatial_dims[dim].dilation; }
|
||||||
|
|
||||||
|
// Compute padding for the given spatial dimension.
|
||||||
|
int SpatialPadding(const Padding& padding, int dim) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Common code between implementations of Conv?DBackpropInput and
|
||||||
|
// Conv?DBackpropFilter. Verifies that the dimensions all match, and computes
|
||||||
|
// sizes/padding for the spatial dimensions. Does not support explicit padding.
|
||||||
|
Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims,
|
||||||
|
const TensorShape& input_shape,
|
||||||
|
const TensorShape& filter_shape,
|
||||||
|
const TensorShape& out_backprop_shape,
|
||||||
|
const std::vector<int32>& strides,
|
||||||
|
Padding padding, TensorFormat data_format,
|
||||||
|
ConvBackpropDimensions* dims);
|
||||||
|
|
||||||
|
// The V2 version computes the same outputs with arbitrary dilation rate and
|
||||||
|
// supports explicit padding.
|
||||||
|
// TODO(b/67112639): Merge V2 versions and the original versions eventually.
|
||||||
|
Status ConvBackpropComputeDimensionsV2(
|
||||||
|
StringPiece label, int num_spatial_dims, const TensorShape& input_shape,
|
||||||
|
const TensorShape& filter_shape, const TensorShape& out_backprop_shape,
|
||||||
|
const gtl::ArraySlice<int32>& dilations, const std::vector<int32>& strides,
|
||||||
|
Padding padding, absl::Span<const int64> explicit_paddings,
|
||||||
|
TensorFormat data_format, ConvBackpropDimensions* dims);
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_
|
||||||
Loading…
x
Reference in New Issue
Block a user