Added int32 support for Conv2D and Conv2DBackpropInput ops.

PiperOrigin-RevId: 275947647
Change-Id: I53be4e110b8e369c3a17b793015e1713c4d5f5fb
Sung Jin Hwang 2019-10-21 16:18:24 -07:00 committed by TensorFlower Gardener
parent 807cf30585
commit 9d08b6bb4f
7 changed files with 364 additions and 137 deletions
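
For context, a minimal Python usage sketch of what this change enables (assuming a TensorFlow build that includes this commit; integer convolutions are limited to the NHWC data format):

import tensorflow as tf

x = tf.constant(1, shape=[2, 6, 4, 3], dtype=tf.int32)   # NHWC input
f = tf.constant(1, shape=[3, 3, 3, 2], dtype=tf.int32)   # [kh, kw, in, out] filter

# Conv2D now accepts int32 (NHWC only for integer types).
y = tf.nn.conv2d(x, f, strides=[1, 1, 1, 1], padding="SAME")

# conv2d_transpose lowers to Conv2DBackpropInput, which also accepts int32.
f_t = tf.constant(1, shape=[3, 3, 2, 3], dtype=tf.int32)  # [kh, kw, out, in]
z = tf.nn.conv2d_transpose(
    x, f_t, output_shape=[2, 6, 4, 2], strides=[1, 1, 1, 1], padding="SAME")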


@@ -102,6 +102,86 @@ struct SpatialConvolution<Device, Eigen::half, OutputKernel> {
}
};
template <typename Device, typename T>
struct SpatialConvolutionBackwardInputFunc {
void operator()(const Device& d, typename TTypes<T, 4>::Tensor input_backward,
typename TTypes<T, 4>::ConstTensor filter,
typename TTypes<T, 4>::ConstTensor output_backward,
Eigen::DenseIndex col_stride, Eigen::DenseIndex row_stride,
Eigen::DenseIndex col_dilation,
Eigen::DenseIndex row_dilation) {
input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput(
filter, output_backward, input_backward.dimension(2),
input_backward.dimension(1), col_stride, row_stride, col_dilation,
row_dilation);
}
};
// GPU version requires all tensors to be indexable by int32.
template <typename T>
struct SpatialConvolutionBackwardInputFunc<Eigen::GpuDevice, T> {
void operator()(const Eigen::GpuDevice& d,
typename TTypes<T, 4>::Tensor input_backward,
typename TTypes<T, 4>::ConstTensor filter,
typename TTypes<T, 4>::ConstTensor output_backward,
Eigen::DenseIndex col_stride, Eigen::DenseIndex row_stride,
Eigen::DenseIndex col_dilation,
Eigen::DenseIndex row_dilation) {
To32Bit(input_backward).device(d) = Eigen::SpatialConvolutionBackwardInput(
To32Bit(filter), To32Bit(output_backward), input_backward.dimension(2),
input_backward.dimension(1), col_stride, row_stride, col_dilation,
row_dilation);
}
};
template <typename Device, typename T>
struct SpatialConvolutionBackwardInputWithExplicitPaddingFunc {
void operator()(const Device& d, typename TTypes<T, 4>::Tensor input_backward,
typename TTypes<T, 4>::ConstTensor filter,
typename TTypes<T, 4>::ConstTensor output_backward,
Eigen::DenseIndex padded_cols, Eigen::DenseIndex padded_rows,
Eigen::DenseIndex col_stride, Eigen::DenseIndex row_stride,
Eigen::DenseIndex col_dilation,
Eigen::DenseIndex row_dilation, Eigen::DenseIndex pad_left,
Eigen::DenseIndex pad_top) {
// We have to slice the result of the spatial convolution backward input
// before assigning it to `input_backward`, to remove the padding.
//
// TODO(ezhulenev): Pass explicit paddings to Eigen and do not materialize
// intermediate result in memory before slicing.
input_backward.device(d) =
Eigen::SpatialConvolutionBackwardInput(
filter, output_backward, padded_cols, padded_rows, col_stride,
row_stride, col_dilation, row_dilation)
.eval()
.slice(Eigen::DSizes<Eigen::DenseIndex, 4>{0, pad_left, pad_top, 0},
input_backward.dimensions());
}
};
// GPU version requires all tensors to be indexable by int32.
template <typename T>
struct SpatialConvolutionBackwardInputWithExplicitPaddingFunc<Eigen::GpuDevice,
T> {
void operator()(const Eigen::GpuDevice& d,
typename TTypes<T, 4>::Tensor input_backward,
typename TTypes<T, 4>::ConstTensor filter,
typename TTypes<T, 4>::ConstTensor output_backward,
Eigen::DenseIndex padded_cols, Eigen::DenseIndex padded_rows,
Eigen::DenseIndex col_stride, Eigen::DenseIndex row_stride,
Eigen::DenseIndex col_dilation,
Eigen::DenseIndex row_dilation, Eigen::DenseIndex pad_left,
Eigen::DenseIndex pad_top) {
To32Bit(input_backward).device(d) =
Eigen::SpatialConvolutionBackwardInput(
To32Bit(filter), To32Bit(output_backward), padded_cols, padded_rows,
col_stride, row_stride, col_dilation, row_dilation)
.eval()
.slice(Eigen::DSizes<Eigen::DenseIndex, 4>{0, pad_left, pad_top, 0},
input_backward.dimensions());
}
};
// TODO(vrv): Figure out how to use the MatMulFunctor in matmul_op.h.
// My initial attempt to do this compiled but failed in the pytest
// due to a swigdeps error.
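
For intuition, a schematic NumPy sketch of what the explicit-padding functors above compute: the input gradient is evaluated over the padded spatial shape, and the padding rows/columns are then sliced off. The helper below is illustrative only (NHWC layout assumed), not a TensorFlow API:

import numpy as np

def slice_off_padding(grad_padded, pad_top, pad_left, unpadded_shape):
  # grad_padded stands in for the result of SpatialConvolutionBackwardInput
  # computed over the padded rows/cols; slicing removes the padding again.
  n, h, w, c = unpadded_shape
  return grad_padded[:, pad_top:pad_top + h, pad_left:pad_left + w, :]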


@@ -30,7 +30,14 @@ namespace tensorflow {
namespace functor {
// For 2d ops.
template struct PadInput<Eigen::GpuDevice, int, int, 4>;
template struct SpatialConvolution<Eigen::GpuDevice, int32>;
template struct MatMulConvFunctor<Eigen::GpuDevice, int32>;
template struct TransformFilter<Eigen::GpuDevice, int32, int, 4>;
template struct PadInput<Eigen::GpuDevice, int32, int, 4>;
template struct SpatialConvolutionBackwardInputFunc<Eigen::GpuDevice, int32>;
template struct SpatialConvolutionBackwardInputWithExplicitPaddingFunc<
Eigen::GpuDevice, int32>;
} // namespace functor
} // namespace tensorflow


@@ -19,9 +19,11 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <algorithm>
#include <limits>
#include <vector>
#include "absl/base/dynamic_annotations.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -59,8 +61,12 @@ limitations under the License.
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
namespace {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
// Returns in 'im_data' (assumed to be zero-initialized) the image patch in storage
// order (height, width, depth), constructed from patches in 'col_data', which
// is required to be in storage order (out_height * out_width, filter_height,
@@ -97,16 +103,10 @@ void Col2im(const T* col_data, const int depth, const int height,
}
}
} // namespace
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
// Computes backprop input using Eigen::SpatialConvolutionBackwardInput on CPU.
template <typename T>
struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
// Computes backprop input using Eigen::SpatialConvolutionBackwardInput on CPU
// and GPU (for int32 only).
template <typename Device, typename T>
struct LaunchConv2DBackpropInputOpImpl {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& filter,
int row_dilation, int col_dilation, int row_stride,
@@ -157,7 +157,21 @@ struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
&padding_right));
DCHECK_EQ(dims.spatial_dims[1].output_size, expected_out_cols);
const CPUDevice& d = ctx->eigen_device<CPUDevice>();
if (std::is_same<Device, GPUDevice>::value) {
int64 size = 1;
#define REQUIRES_32BIT(x) \
size *= x; \
OP_REQUIRES(ctx, \
FastBoundsCheck(x, std::numeric_limits<int32>::max()) && \
FastBoundsCheck(size, std::numeric_limits<int32>::max()), \
errors::InvalidArgument("Tensor too large"))
REQUIRES_32BIT(in_backprop->dim_size(0));
REQUIRES_32BIT(in_backprop->dim_size(1) + padding_top + padding_bottom);
REQUIRES_32BIT(in_backprop->dim_size(2) + padding_left + padding_right);
REQUIRES_32BIT(in_backprop->dim_size(3));
#undef REQUIRES_32BIT
}
auto in_backprop_t = in_backprop->tensor<T, 4>();
auto out_backprop_t = out_backprop.tensor<T, 4>();
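
The REQUIRES_32BIT block above guards the GPU path, which indexes tensors through To32Bit: every padded dimension, and the running product of dimensions, must fit in int32. A Python sketch of that check, for illustration only:

INT32_MAX = 2**31 - 1

def dims_fit_in_int32(dims):
  # dims, e.g. [batch, padded_rows, padded_cols, depth]
  size = 1
  for d in dims:
    size *= d
    if d > INT32_MAX or size > INT32_MAX:
      return False
  return True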
@@ -170,32 +184,58 @@ struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
if (padding != EXPLICIT) {
// If padding was not explicitly defined, Eigen spatial convolution
// backward input will infer correct forward paddings from input tensors.
in_backprop_t.device(d) = Eigen::SpatialConvolutionBackwardInput(
filter_t, out_backprop_t, in_backprop_t.dimension(2),
in_backprop_t.dimension(1), col_stride, row_stride, col_dilation,
row_dilation);
functor::SpatialConvolutionBackwardInputFunc<Device, T>()(
ctx->eigen_device<Device>(), in_backprop_t, filter_t, out_backprop_t,
col_stride, row_stride, col_dilation, row_dilation);
} else {
// Otherwise we have to slice the result of a spatial convolution backward
// input, before assigning it to the `in_backprop` to remove padding.
using Offsets = Eigen::DSizes<Eigen::Index, 4>;
// TODO(ezhulenev): Pass explicit paddings to Eigen and do not materialize
// intermediate result in memory before slicing.
in_backprop_t.device(d) =
Eigen::SpatialConvolutionBackwardInput(
filter_t, out_backprop_t,
in_backprop_t.dimension(2) + (padding_left + padding_right),
in_backprop_t.dimension(1) + (padding_top + padding_bottom),
col_stride, row_stride, col_dilation, row_dilation)
.eval()
.slice(Offsets(0, padding_top, padding_left, 0),
/*sizes=*/in_backprop_t.dimensions());
functor::SpatialConvolutionBackwardInputWithExplicitPaddingFunc<Device,
T>()(
ctx->eigen_device<Device>(), in_backprop_t, filter_t, out_backprop_t,
in_backprop_t.dimension(2) + (padding_left + padding_right),
in_backprop_t.dimension(1) + (padding_top + padding_bottom),
col_stride, row_stride, col_dilation, row_dilation, padding_top,
padding_left);
}
}
};
} // namespace
// Computes backprop input using Eigen::SpatialConvolutionBackwardInput on CPU.
template <typename T>
struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& filter,
int row_dilation, int col_dilation, int row_stride,
int col_stride, const Padding& padding,
const std::vector<int64>& explicit_paddings,
Tensor* in_backprop, TensorFormat data_format) {
LaunchConv2DBackpropInputOpImpl<CPUDevice, T> launcher;
launcher(ctx, use_cudnn, cudnn_use_autotune, out_backprop, filter,
row_dilation, col_dilation, row_stride, col_stride, padding,
explicit_paddings, in_backprop, data_format);
}
};
#ifdef GOOGLE_CUDA
// Computes backprop input using Eigen::SpatialConvolutionBackwardInput on GPU
// for int32 inputs.
template <>
struct LaunchConv2DBackpropInputOp<GPUDevice, int32> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& filter,
int row_dilation, int col_dilation, int row_stride,
int col_stride, const Padding& padding,
const std::vector<int64>& explicit_paddings,
Tensor* in_backprop, TensorFormat data_format) {
LaunchConv2DBackpropInputOpImpl<GPUDevice, int32> launcher;
launcher(ctx, use_cudnn, cudnn_use_autotune, out_backprop, filter,
row_dilation, col_dilation, row_stride, col_stride, padding,
explicit_paddings, in_backprop, data_format);
}
};
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
template <typename Device, class T>
struct LaunchXsmmBackwardInputConvolution {
@@ -389,16 +429,19 @@ class Conv2DBackpropInputOp : public OpKernel {
use_cudnn_ &= CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
if (std::is_same<Device, CPUDevice>::value) {
OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
errors::InvalidArgument("Conv2DBackpropInputOp [CPU] "
"only supports NHWC data format."));
if (std::is_same<Device, CPUDevice>::value ||
std::is_same<T, int32>::value) {
OP_REQUIRES(
context, data_format_ == FORMAT_NHWC,
errors::InvalidArgument("Conv2DBackpropInputOp [CPU or GPU(int32)] "
"only supports NHWC data format."));
// TODO(yangzihao): Add a CPU implementation for dilated convolution.
OP_REQUIRES(
context, (dilation_h == 1 && dilation_w == 1),
errors::InvalidArgument("Conv2DBackpropInputOp [CPU] not yet support "
"dilation rates larger than 1."));
errors::InvalidArgument(
"Conv2DBackpropInputOp [CPU or GPU(int32)] not yet support "
"dilation rates larger than 1."));
}
}
@@ -761,6 +804,7 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);
TF_CALL_int32(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#undef DEFAULT_CPU_OP
@@ -1259,6 +1303,28 @@ DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
template <>
void SpatialConvolutionBackwardInputFunc<GPUDevice, int32>::operator()(
const GPUDevice&, typename TTypes<int32, 4>::Tensor,
typename TTypes<int32, 4>::ConstTensor,
typename TTypes<int32, 4>::ConstTensor, Eigen::DenseIndex,
Eigen::DenseIndex, Eigen::DenseIndex, Eigen::DenseIndex);
extern template struct SpatialConvolutionBackwardInputFunc<GPUDevice, int32>;
template <>
void SpatialConvolutionBackwardInputWithExplicitPaddingFunc<
GPUDevice, int32>::operator()(const GPUDevice&,
typename TTypes<int32, 4>::Tensor,
typename TTypes<int32, 4>::ConstTensor,
typename TTypes<int32, 4>::ConstTensor,
Eigen::DenseIndex, Eigen::DenseIndex,
Eigen::DenseIndex, Eigen::DenseIndex,
Eigen::DenseIndex, Eigen::DenseIndex,
Eigen::DenseIndex, Eigen::DenseIndex);
extern template struct SpatialConvolutionBackwardInputWithExplicitPaddingFunc<
GPUDevice, int32>;
} // namespace functor
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
@@ -1276,6 +1342,11 @@ REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
.TypeConstraint<Eigen::half>("T")
.HostMemory("input_sizes"),
Conv2DBackpropInputOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
.Device(DEVICE_GPU)
.TypeConstraint<int32>("T")
.HostMemory("input_sizes"),
Conv2DBackpropInputOp<GPUDevice, int32>);
// To be used inside depthwise_conv_grad_op.cc.
// TODO(reedwm): Move this and the definition to depthwise_conv_grad_op.cc.
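
The GPU int32 kernel registered above keeps input_sizes in host memory. As an illustration, the raw op can be driven as below; gen_nn_ops is TensorFlow's generated wrapper module (shown only for clarity -- tf.nn.conv2d_transpose remains the usual public entry point):

import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops

f = tf.constant(1, shape=[3, 3, 3, 2], dtype=tf.int32)   # [kh, kw, in, out]
dy = tf.constant(1, shape=[2, 6, 4, 2], dtype=tf.int32)  # NHWC out_backprop
dx = gen_nn_ops.conv2d_backprop_input(
    input_sizes=[2, 6, 4, 3],  # stays on the host, per HostMemory("input_sizes")
    filter=f,
    out_backprop=dy,
    strides=[1, 1, 1, 1],
    padding="SAME",
    data_format="NHWC")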


@@ -176,6 +176,45 @@ struct LaunchConv2DOp<CPUDevice, T> {
}
};
#if GOOGLE_CUDA
template <>
struct LaunchConv2DOp<GPUDevice, int32> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& input, const Tensor& filter, int row_dilation,
int col_dilation, int row_stride, int col_stride,
const Padding& padding,
const std::vector<int64>& explicit_paddings, Tensor* output,
TensorFormat data_format) {
if (data_format != FORMAT_NHWC) {
ctx->SetStatus(
errors::Unimplemented("The Conv2D op currently only supports the "
"NHWC tensor format for integer types. "
"The op was given the format: ",
ToString(data_format)));
return;
}
const int64 in_depth = GetTensorDim(input, data_format, 'C');
OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
errors::Unimplemented(
"The Conv2D op currently does not support grouped "
"convolutions for integer types. A grouped convolution was "
"attempted to be run because the input depth of ",
in_depth, " does not match the filter input depth of ",
filter.dim_size(2)));
for (int64 explicit_padding : explicit_paddings) {
if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) {
ctx->SetStatus(errors::InvalidArgument("filter too large"));
return;
}
}
LaunchGeneric<GPUDevice, int32>()(
ctx, input, filter, row_stride, col_stride, row_dilation, col_dilation,
padding, explicit_paddings, output, data_format);
}
};
#endif // GOOGLE_CUDA
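
A compact restatement of the preconditions enforced by the GPU int32 launcher above (illustrative Python only; the real checks report Unimplemented or InvalidArgument errors through the kernel context):

def int32_conv2d_preconditions_ok(data_format, in_depth, filter_in_depth,
                                  explicit_paddings):
  # NHWC only, no grouped convolutions, and explicit paddings must fit in int.
  if data_format != "NHWC":
    return False
  if in_depth != filter_in_depth:
    return False
  return all(0 <= p < 2**31 - 1 for p in explicit_paddings)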
template <typename Device, typename T>
class LaunchDeepConvOp {
public:
@@ -569,6 +608,7 @@ class Conv2DOp : public BinaryOp<T> {
TF_CALL_half(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
#endif // USE_GEMM_FOR_CONV
// To be used inside depthwise_conv_op.cc.
@@ -1064,6 +1104,14 @@ namespace functor {
int col_stride, int row_dilation, int col_dilation, \
const Eigen::PaddingType& padding, \
const Eigen::NoOpOutputKernel& output_kernel); \
template <> \
void SpatialConvolution<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T, 4>::Tensor output, \
typename TTypes<T, 4>::ConstTensor input, \
typename TTypes<T, 4>::ConstTensor filter, int row_stride, \
int col_stride, int row_dilation, int col_dilation, int padding_top, \
int padding_bottom, int padding_left, int padding_right, \
const Eigen::NoOpOutputKernel& output_kernel); \
extern template struct SpatialConvolution<GPUDevice, T>; \
template <> \
void MatMulConvFunctor<GPUDevice, T>::operator()( \
@@ -1090,7 +1138,9 @@ namespace functor {
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(int32);
#undef DECLARE_GPU_SPEC
} // namespace functor
// Registration of the GPU implementations.
@@ -1103,6 +1153,9 @@ REGISTER_KERNEL_BUILDER(
REGISTER_KERNEL_BUILDER(
Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<double>("T"),
Conv2DOp<GPUDevice, double>);
REGISTER_KERNEL_BUILDER(
Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<int32>("T"),
Conv2DOp<GPUDevice, int32>);
// To be used inside depthwise_conv_op.cc.
template struct LaunchConv2DOp<GPUDevice, float>;


@@ -330,7 +330,7 @@ REGISTER_OP("Conv2D")
.Input("input: T")
.Input("filter: T")
.Output("output: T")
.Attr("T: {half, bfloat16, float, double}")
.Attr("T: {half, bfloat16, float, double, int32}")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr(GetPaddingAttrStringWithExplicit())
@@ -344,7 +344,7 @@ REGISTER_OP("Conv2DBackpropInput")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
.Attr("T: {half, bfloat16, float, double}")
.Attr("T: {half, bfloat16, float, double, int32}")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr(GetPaddingAttrStringWithExplicit())


@@ -38,122 +38,130 @@ class Conv2DTransposeTest(test.TestCase):
def testConv2DTransposeSingleStride(self):
with self.cached_session():
strides = [1, 1, 1, 1]
for dtype in (dtypes.float32, dtypes.int32):
strides = [1, 1, 1, 1]
# Input, output: [batch, height, width, depth]
x_shape = [2, 6, 4, 3]
y_shape = [2, 6, 4, 2]
# Input, output: [batch, height, width, depth]
x_shape = [2, 6, 4, 3]
y_shape = [2, 6, 4, 2]
# Filter: [kernel_height, kernel_width, output_depth, input_depth]
f_shape = [3, 3, 2, 3]
# Filter: [kernel_height, kernel_width, output_depth, input_depth]
f_shape = [3, 3, 2, 3]
x = constant_op.constant(
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
f = constant_op.constant(
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv2d_transpose(
x, f, y_shape, strides=strides, padding="SAME")
value = self.evaluate(output)
x = constant_op.constant(1, shape=x_shape, name="x", dtype=dtype)
f = constant_op.constant(1, shape=f_shape, name="filter", dtype=dtype)
output = nn_ops.conv2d_transpose(
x, f, y_shape, strides=strides, padding="SAME")
value = self.evaluate(output)
# We count the number of cells being added at the locations in the output.
# At the center, #cells=kernel_height * kernel_width
# At the corners, #cells=ceil(kernel_height/2) * ceil(kernel_width/2)
# At the borders, #cells=ceil(kernel_height/2)*kernel_width or
# kernel_height * ceil(kernel_width/2)
# We count the number of cells being added at the locations in the
# output.
# At the center, #cells=kernel_height * kernel_width
# At the corners, #cells=ceil(kernel_height/2) * ceil(kernel_width/2)
# At the borders, #cells=ceil(kernel_height/2)*kernel_width or
# kernel_height * ceil(kernel_width/2)
for n in xrange(x_shape[0]):
for k in xrange(f_shape[2]):
for w in xrange(y_shape[2]):
for h in xrange(y_shape[1]):
target = 4 * 3.0
h_in = h > 0 and h < y_shape[1] - 1
w_in = w > 0 and w < y_shape[2] - 1
if h_in and w_in:
target += 5 * 3.0
elif h_in or w_in:
target += 2 * 3.0
self.assertAllClose(target, value[n, h, w, k])
for n in xrange(x_shape[0]):
for k in xrange(f_shape[2]):
for w in xrange(y_shape[2]):
for h in xrange(y_shape[1]):
target = 4 * 3
h_in = h > 0 and h < y_shape[1] - 1
w_in = w > 0 and w < y_shape[2] - 1
if h_in and w_in:
target += 5 * 3
elif h_in or w_in:
target += 2 * 3
if dtype.is_integer:
self.assertAllEqual(target, value[n, h, w, k])
else:
self.assertAllClose(target, value[n, h, w, k])
def testConv2DTransposeSame(self):
with self.cached_session():
strides = [1, 2, 2, 1]
for dtype in (dtypes.float32, dtypes.int32):
strides = [1, 2, 2, 1]
# Input, output: [batch, height, width, depth]
x_shape = [2, 6, 4, 3]
y_shape = [2, 12, 8, 2]
# Input, output: [batch, height, width, depth]
x_shape = [2, 6, 4, 3]
y_shape = [2, 12, 8, 2]
# Filter: [kernel_height, kernel_width, output_depth, input_depth]
f_shape = [3, 3, 2, 3]
# Filter: [kernel_height, kernel_width, output_depth, input_depth]
f_shape = [3, 3, 2, 3]
x = constant_op.constant(
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
f = constant_op.constant(
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv2d_transpose(
x, f, y_shape, strides=strides, padding="SAME")
value = self.evaluate(output)
x = constant_op.constant(1, shape=x_shape, name="x", dtype=dtype)
f = constant_op.constant(1, shape=f_shape, name="filter", dtype=dtype)
output = nn_ops.conv2d_transpose(
x, f, y_shape, strides=strides, padding="SAME")
value = self.evaluate(output)
for n in xrange(x_shape[0]):
for k in xrange(f_shape[2]):
for w in xrange(y_shape[2]):
for h in xrange(y_shape[1]):
target = 3.0
# We add a case for locations divisible by the stride.
h_in = h % strides[1] == 0 and h > 0 and h < y_shape[1] - 1
w_in = w % strides[2] == 0 and w > 0 and w < y_shape[2] - 1
if h_in and w_in:
target += 9.0
elif h_in or w_in:
target += 3.0
self.assertAllClose(target, value[n, h, w, k])
for n in xrange(x_shape[0]):
for k in xrange(f_shape[2]):
for w in xrange(y_shape[2]):
for h in xrange(y_shape[1]):
target = 3
# We add a case for locations divisible by the stride.
h_in = h % strides[1] == 0 and h > 0 and h < y_shape[1] - 1
w_in = w % strides[2] == 0 and w > 0 and w < y_shape[2] - 1
if h_in and w_in:
target += 9
elif h_in or w_in:
target += 3
if dtype.is_integer:
self.assertAllEqual(target, value[n, h, w, k])
else:
self.assertAllClose(target, value[n, h, w, k])
def testConv2DTransposeValid(self):
with self.cached_session():
strides = [1, 2, 2, 1]
for dtype in (dtypes.float32, dtypes.int32):
strides = [1, 2, 2, 1]
# Input, output: [batch, height, width, depth]
x_shape = [2, 6, 4, 3]
y_shape = [2, 13, 9, 2]
# Input, output: [batch, height, width, depth]
x_shape = [2, 6, 4, 3]
y_shape = [2, 13, 9, 2]
# Filter: [kernel_height, kernel_width, output_depth, input_depth]
f_shape = [3, 3, 2, 3]
# Filter: [kernel_height, kernel_width, output_depth, input_depth]
f_shape = [3, 3, 2, 3]
x = constant_op.constant(
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
f = constant_op.constant(
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv2d_transpose(
x, f, y_shape, strides=strides, padding="VALID")
value = self.evaluate(output)
x = constant_op.constant(1, shape=x_shape, name="x", dtype=dtype)
f = constant_op.constant(1, shape=f_shape, name="filter", dtype=dtype)
output = nn_ops.conv2d_transpose(
x, f, y_shape, strides=strides, padding="VALID")
value = self.evaluate(output)
cache_values = np.zeros(y_shape, dtype=np.float32)
cache_values = np.zeros(y_shape, dtype=np.float32)
# The amount of padding added
pad = 1
# The amount of padding added
pad = 1
for n in xrange(x_shape[0]):
for k in xrange(f_shape[2]):
for w in xrange(pad, y_shape[2] - pad):
for h in xrange(pad, y_shape[1] - pad):
target = 3.0
# We add a case for locations divisible by the stride.
h_in = h % strides[1] == 0 and h > pad and h < y_shape[
1] - 1 - pad
w_in = w % strides[2] == 0 and w > pad and w < y_shape[
2] - 1 - pad
if h_in and w_in:
target += 9.0
elif h_in or w_in:
target += 3.0
cache_values[n, h, w, k] = target
for n in xrange(x_shape[0]):
for k in xrange(f_shape[2]):
for w in xrange(pad, y_shape[2] - pad):
for h in xrange(pad, y_shape[1] - pad):
target = 3
# We add a case for locations divisible by the stride.
h_in = h % strides[1] == 0 and h > pad and h < y_shape[
1] - 1 - pad
w_in = w % strides[2] == 0 and w > pad and w < y_shape[
2] - 1 - pad
if h_in and w_in:
target += 9
elif h_in or w_in:
target += 3
cache_values[n, h, w, k] = target
# copy values in the border
cache_values[n, :, 0, k] = cache_values[n, :, 1, k]
cache_values[n, :, -1, k] = cache_values[n, :, -2, k]
cache_values[n, 0, :, k] = cache_values[n, 1, :, k]
cache_values[n, -1, :, k] = cache_values[n, -2, :, k]
# copy values in the border
cache_values[n, :, 0, k] = cache_values[n, :, 1, k]
cache_values[n, :, -1, k] = cache_values[n, :, -2, k]
cache_values[n, 0, :, k] = cache_values[n, 1, :, k]
cache_values[n, -1, :, k] = cache_values[n, -2, :, k]
self.assertAllClose(cache_values, value)
if dtype.is_integer:
self.assertAllEqual(cache_values, value)
else:
self.assertAllClose(cache_values, value)
@test_util.run_deprecated_v1
def testGradient(self):


@@ -220,6 +220,7 @@ class Conv2DTest(test.TestCase):
strides=strides,
padding=padding,
data_format=data_format)
self.assertEqual(conv.dtype, dtype)
if data_format == "NCHW":
conv = test_util.NCHWToNHWC(conv)
@@ -336,7 +337,10 @@ class Conv2DTest(test.TestCase):
for (data_format, use_gpu) in GetTestConfigs():
if gpu_only and not use_gpu:
continue
for dtype in self._DtypesToTest(use_gpu):
dtypes_to_test = self._DtypesToTest(use_gpu)
if not test_grappler_layout_optimizer and data_format == "NHWC":
dtypes_to_test.append(dtypes.int32)
for dtype in dtypes_to_test:
result = self._SetupValuesForDevice(
tensor_in_sizes,
filter_in_sizes,
@@ -358,9 +362,13 @@ class Conv2DTest(test.TestCase):
tf_logging.debug("expected = %s", expected)
tf_logging.debug("actual = %s", value)
tol_to_use = fp16_tol if value.dtype == np.float16 else tol
self.assertAllClose(expected, np.ravel(value), atol=tol_to_use,
rtol=tol_to_use)
if np.issubdtype(value.dtype, np.integer):
self.assertAllEqual(expected, np.ravel(value))
else:
self.assertAllClose(expected, np.ravel(value), atol=tol_to_use,
rtol=tol_to_use)
self.assertShapeEqual(value, conv)
self.assertEqual(value.dtype, conv.dtype.as_numpy_dtype)
def _VerifyExplicitPaddings(self,
tensor_in_sizes,