Move conv grad shape related utils to a separate lib

This allows targets to use the shape utils without depending on the kernels.

* Switch to the new library in targets that don't use the kernels
* Include the new header in the conv kernels

PiperOrigin-RevId: 277570525
Change-Id: I133615d6682d17bb5f0cf49e4171dcbf5529120b
This commit is contained in:
Smit Hinsu 2019-10-30 13:17:40 -07:00 committed by TensorFlower Gardener
parent f04053d46d
commit 5cedcf31d6
11 changed files with 123 additions and 80 deletions

View File

@ -109,7 +109,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
"//tensorflow/core:framework",
"//tensorflow/core/kernels:conv_ops",
"//tensorflow/core/kernels:conv_grad_shape_utils",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",

View File

@ -42,7 +42,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
@ -1430,7 +1430,7 @@ class ConvertConv2DBackpropInputOp
int64_t filter_in_depth = filter_shape[num_spatial_dims];
int64_t feature_group_count = in_depth / filter_in_depth;
// Reuse dimension computation logic from conv_grad_ops.cc.
// Reuse dimension computation logic from conv_grad_shape_utils.cc.
tensorflow::ConvBackpropDimensions dims;
if (!tensorflow::ConvBackpropComputeDimensionsV2(
"", num_spatial_dims, ToTensorShape<int>(input_shape),
@ -1569,7 +1569,7 @@ class ConvertConv2DBackpropFilterOp
llvm::to_vector<4>(filter_shape_attr.getValues<int32_t>());
if (filter_shape.size() != num_dims) return matchFailure();
// Reuse dimension computation logic from conv_grad_ops.cc.
// Reuse dimension computation logic from conv_grad_shape_utils.cc.
tensorflow::ConvBackpropDimensions dims;
if (!tensorflow::ConvBackpropComputeDimensionsV2(
"", num_spatial_dims, ToTensorShape<int64_t>(input_shape),

View File

@ -255,7 +255,7 @@ cc_library(
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/core:framework",
"//tensorflow/core:framework_bounds_check",
"//tensorflow/core/kernels:conv_ops",
"//tensorflow/core/kernels:conv_grad_shape_utils",
"@com_google_absl//absl/types:span",
],
)

View File

@ -36,7 +36,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
@ -404,7 +404,7 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
xla::Shape expanded_filter_shape =
attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
: filter_shape;
// Reuse dimension computation logic from conv_grad_ops.cc.
// Reuse dimension computation logic from conv_grad_shape_utils.cc.
ConvBackpropDimensions dims;
TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
@ -413,7 +413,7 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
// The input gradients are computed by a convolution of the output
// gradients and the filter, with some appropriate padding. See the
// comment at the top of conv_grad_ops.h for details.
// comment at the top of conv_grad_shape_utils.h for details.
xla::ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(batch_dim);
@ -487,11 +487,11 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
const xla::Shape expanded_filter_shape =
attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
: filter_shape;
// Reuse dimension computation logic from conv_grad_ops.cc.
// Reuse dimension computation logic from conv_grad_shape_utils.cc.
ConvBackpropDimensions dims;
// The filter gradients are computed by a convolution of the input
// activations and the output gradients, with some appropriate padding.
// See the comment at the top of conv_grad_ops.h for details.
// See the comment at the top of conv_grad_shape_utils.h for details.
xla::ConvolutionDimensionNumbers dnums;
TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(

View File

@ -4313,7 +4313,6 @@ tf_kernel_library(
srcs = [
"conv_grad_filter_ops.cc",
"conv_grad_input_ops.cc",
"conv_grad_ops.cc",
"conv_grad_ops_3d.cc",
"deep_conv2d.cc",
] + select({
@ -4339,6 +4338,7 @@ tf_kernel_library(
}),
prefix = "conv_ops",
deps = [
":conv_grad_shape_utils",
":bounds_check",
":conv_2d",
":conv_3d",
@ -4372,6 +4372,24 @@ tf_kernel_library(
]),
)
cc_library(
name = "conv_grad_shape_utils",
srcs = [
"conv_grad_shape_utils.cc",
],
hdrs = [
"conv_grad_shape_utils.h",
],
deps = [
":ops_util",
"//tensorflow/core:framework",
"//tensorflow/core/lib/core:errors",
"//tensorflow/core/lib/core:stringpiece",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:macros",
],
)
tf_kernel_library(
name = "depthwise_conv_op",
srcs = ["depthwise_conv_op.cc"],
@ -6363,8 +6381,9 @@ filegroup(
"conv_2d.h",
"conv_grad_filter_ops.cc",
"conv_grad_input_ops.cc",
"conv_grad_ops.cc",
"conv_grad_ops.h",
"conv_grad_shape_utils.h",
"conv_grad_shape_utils.cc",
"conv_ops.cc",
"conv_ops_3d.cc",
"conv_ops_fused_double.cc",

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/kernels/fill_functor.h"
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
#include "tensorflow/core/kernels/xsmm_conv2d.h"

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#endif

View File

@ -161,8 +161,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
@ -212,66 +210,6 @@ struct LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T> {
Tensor* filter_backprop, TensorFormat data_format);
};
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Information about a single spatial dimension for a convolution
// backpropagation.
struct ConvBackpropSpatialDimension {
int64 input_size;
int64 filter_size;
int64 output_size;
int64 stride;
int64 dilation;
// Output size after scaling by the stride.
int64 expanded_output_size;
// Number of padding elements to be added before/after this dimension of
// the input when computing Conv?DBackpropInput.
int64 pad_before, pad_after;
};
// Computed dimensions for a backwards convolution.
struct ConvBackpropDimensions {
// Information about each spatial dimension.
gtl::InlinedVector<ConvBackpropSpatialDimension, 3> spatial_dims;
// Batch size.
int64 batch_size;
// Input and output feature depth.
int64 in_depth, out_depth;
// Convenience access methods for spatial dimensions properties.
int64 input_size(int dim) const { return spatial_dims[dim].input_size; }
int64 filter_size(int dim) const { return spatial_dims[dim].filter_size; }
int64 output_size(int dim) const { return spatial_dims[dim].output_size; }
int64 stride(int dim) const { return spatial_dims[dim].stride; }
int64 dilation(int dim) const { return spatial_dims[dim].dilation; }
// Compute padding for the given spatial dimension.
int SpatialPadding(const Padding& padding, int dim) const;
};
// Common code between implementations of Conv?DBackpropInput and
// Conv?DBackpropFilter. Verifies that the dimensions all match, and computes
// sizes/padding for the spatial dimensions. Does not support explicit padding.
Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims,
const TensorShape& input_shape,
const TensorShape& filter_shape,
const TensorShape& out_backprop_shape,
const std::vector<int32>& strides,
Padding padding, TensorFormat data_format,
ConvBackpropDimensions* dims);
// The V2 version computes the same outputs with arbitrary dilation rate and
// supports explicit padding.
// TODO(b/67112639): Merge V2 versions and the original versions eventually.
Status ConvBackpropComputeDimensionsV2(
StringPiece label, int num_spatial_dims, const TensorShape& input_shape,
const TensorShape& filter_shape, const TensorShape& out_backprop_shape,
const gtl::ArraySlice<int32>& dilations, const std::vector<int32>& strides,
Padding padding, absl::Span<const int64> explicit_paddings,
TensorFormat data_format, ConvBackpropDimensions* dims);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_OPS_H_

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_3d.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"

View File

@ -18,26 +18,21 @@ limitations under the License.
#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include <algorithm>
#include <vector>
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
namespace tensorflow {

View File

@ -0,0 +1,88 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_
#include <vector>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
// Information about a single spatial dimension for a convolution
// backpropagation.
struct ConvBackpropSpatialDimension {
int64 input_size;
int64 filter_size;
int64 output_size;
int64 stride;
int64 dilation;
// Output size after scaling by the stride.
int64 expanded_output_size;
// Number of padding elements to be added before/after this dimension of
// the input when computing Conv?DBackpropInput.
int64 pad_before, pad_after;
};
// Computed dimensions for a backwards convolution.
struct ConvBackpropDimensions {
// Information about each spatial dimension.
gtl::InlinedVector<ConvBackpropSpatialDimension, 3> spatial_dims;
// Batch size.
int64 batch_size;
// Input and output feature depth.
int64 in_depth, out_depth;
// Convenience access methods for spatial dimensions properties.
int64 input_size(int dim) const { return spatial_dims[dim].input_size; }
int64 filter_size(int dim) const { return spatial_dims[dim].filter_size; }
int64 output_size(int dim) const { return spatial_dims[dim].output_size; }
int64 stride(int dim) const { return spatial_dims[dim].stride; }
int64 dilation(int dim) const { return spatial_dims[dim].dilation; }
// Compute padding for the given spatial dimension.
int SpatialPadding(const Padding& padding, int dim) const;
};
// Common code between implementations of Conv?DBackpropInput and
// Conv?DBackpropFilter. Verifies that the dimensions all match, and computes
// sizes/padding for the spatial dimensions. Does not support explicit padding.
Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims,
const TensorShape& input_shape,
const TensorShape& filter_shape,
const TensorShape& out_backprop_shape,
const std::vector<int32>& strides,
Padding padding, TensorFormat data_format,
ConvBackpropDimensions* dims);
// The V2 version computes the same outputs with arbitrary dilation rate and
// supports explicit padding.
// TODO(b/67112639): Merge V2 versions and the original versions eventually.
Status ConvBackpropComputeDimensionsV2(
StringPiece label, int num_spatial_dims, const TensorShape& input_shape,
const TensorShape& filter_shape, const TensorShape& out_backprop_shape,
const gtl::ArraySlice<int32>& dilations, const std::vector<int32>& strides,
Padding padding, absl::Span<const int64> explicit_paddings,
TensorFormat data_format, ConvBackpropDimensions* dims);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_