Add int64 padding support for MirrorPad (#13907)
* Add int64 padding support for MirrorPad

  This fix adds int64 padding support for `MirrorPad`. In `array_ops.cc`,
  `MirrorPad`/`MirrorPadGrad` are specified as supporting int64 paddings,
  but the related kernels do not have int64 paddings registered. This fix
  adds the int64 padding support, along with additional test cases for
  coverage.

  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Update template for CPU and GPU support of int64 paddings.

  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add int64 padding support for MirrorPad

  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Put eigen header first like before, just in case.
parent 690003cc01
commit 0d437c3beb

Changed under: tensorflow/core/kernels, tensorflow/python/kernel_tests
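For context, a minimal sketch of the call this change enables, written against the TF 1.x graph API of the era (the tensor, shapes, and values are illustrative only, not part of the commit):

import numpy as np
import tensorflow as tf

x = tf.constant(np.arange(9, dtype=np.float32).reshape(3, 3))
# Before this change only int32 paddings had registered MirrorPad kernels,
# so an int64 paddings tensor failed to resolve to a kernel at run time.
paddings = tf.constant([[1, 1], [2, 2]], dtype=tf.int64)
y = tf.pad(x, paddings, mode="REFLECT")  # tf.pad lowers REFLECT to MirrorPad

with tf.Session() as sess:
  print(sess.run(y).shape)  # (5, 7)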
tensorflow/core/kernels/mirror_pad_op.cc
@@ -18,10 +18,10 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
 #include "tensorflow/core/kernels/mirror_pad_op.h"
 
 #include <string>
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -35,7 +35,7 @@ limitations under the License.
 
 namespace tensorflow {
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename Tpaddings>
 class MirrorPadOp : public OpKernel {
  public:
   explicit MirrorPadOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -82,10 +82,10 @@ class MirrorPadOp : public OpKernel {
 
     // Compute the shape of the output tensor, and allocate it.
     TensorShape output_shape;
-    TTypes<int32>::ConstMatrix paddings = in1.matrix<int32>();
+    typename TTypes<Tpaddings>::ConstMatrix paddings = in1.matrix<Tpaddings>();
     for (int d = 0; d < dims; ++d) {
-      const int32 before = paddings(d, 0);  // Pad before existing elements.
-      const int32 after = paddings(d, 1);   // Pad after existing elements.
+      const Tpaddings before = paddings(d, 0);  // Pad before existing elements.
+      const Tpaddings after = paddings(d, 1);   // Pad after existing elements.
       OP_REQUIRES(context, before >= 0 && after >= 0,
                   errors::InvalidArgument("paddings must be non-negative: ",
                                           before, " ", after));
@@ -121,7 +121,7 @@ class MirrorPadOp : public OpKernel {
 
 #define MIRROR_PAD_CASE(i)                                                 \
   case i: {                                                                \
-    functor::MirrorPad<Device, T, i>()(                                    \
+    functor::MirrorPad<Device, T, Tpaddings, i>()(                         \
         context->eigen_device<Device>(), To32Bit(output->tensor<T, i>()),  \
         To32Bit(in0.tensor<T, i>()), paddings, offset_);                   \
     break;                                                                 \
@@ -152,20 +152,25 @@ using GpuDevice = Eigen::GpuDevice;
 namespace functor {
 // Forward declarations of the functor specializations defined in the sharded
 // files.
-#define DECLARE_CPU_SPEC(T, i)                                               \
-  template <>                                                                \
-  void MirrorPad<CpuDevice, T, i>::operator()(                               \
-      const CpuDevice&, typename TTypes<T, i, int32>::Tensor,                \
-      typename TTypes<T, i, int32>::ConstTensor, TTypes<int32>::ConstMatrix, \
-      int);                                                                  \
-  extern template struct MirrorPad<CpuDevice, T, i>;
+#define DECLARE_CPU_SPEC(T, Tpaddings, i)                      \
+  template <>                                                  \
+  void MirrorPad<CpuDevice, T, Tpaddings, i>::operator()(      \
+      const CpuDevice&, typename TTypes<T, i, int32>::Tensor,  \
+      typename TTypes<T, i, int32>::ConstTensor,               \
+      TTypes<Tpaddings>::ConstMatrix, int);                    \
+  extern template struct MirrorPad<CpuDevice, T, Tpaddings, i>;
 
-#define DECLARE_CPU_SPECS(T) \
-  DECLARE_CPU_SPEC(T, 1);    \
-  DECLARE_CPU_SPEC(T, 2);    \
-  DECLARE_CPU_SPEC(T, 3);    \
-  DECLARE_CPU_SPEC(T, 4);    \
-  DECLARE_CPU_SPEC(T, 5);
+#define DECLARE_CPU_SPECS(T)     \
+  DECLARE_CPU_SPEC(T, int32, 1); \
+  DECLARE_CPU_SPEC(T, int32, 2); \
+  DECLARE_CPU_SPEC(T, int32, 3); \
+  DECLARE_CPU_SPEC(T, int32, 4); \
+  DECLARE_CPU_SPEC(T, int32, 5); \
+  DECLARE_CPU_SPEC(T, int64, 1); \
+  DECLARE_CPU_SPEC(T, int64, 2); \
+  DECLARE_CPU_SPEC(T, int64, 3); \
+  DECLARE_CPU_SPEC(T, int64, 4); \
+  DECLARE_CPU_SPEC(T, int64, 5);
 
 TF_CALL_POD_TYPES(DECLARE_CPU_SPECS);
@@ -179,7 +184,13 @@ TF_CALL_POD_TYPES(DECLARE_CPU_SPECS);
                               .TypeConstraint<type>("T")          \
                               .TypeConstraint<int32>("Tpaddings") \
                               .HostMemory("paddings"),            \
-                          MirrorPadOp<CpuDevice, type>);
+                          MirrorPadOp<CpuDevice, type, int32>);   \
+  REGISTER_KERNEL_BUILDER(Name("MirrorPad")                       \
+                              .Device(DEVICE_CPU)                 \
+                              .TypeConstraint<type>("T")          \
+                              .TypeConstraint<int64>("Tpaddings") \
+                              .HostMemory("paddings"),            \
+                          MirrorPadOp<CpuDevice, type, int64>);
 
 // Note that we do register for bool type, but not in the gradient op.
 TF_CALL_POD_TYPES(REGISTER_KERNEL);
@@ -188,20 +199,25 @@ TF_CALL_POD_TYPES(REGISTER_KERNEL);
 #if GOOGLE_CUDA
 namespace functor {
 // Forward declarations of the functor specializations for GPU.
-#define DECLARE_GPU_SPEC(T, i)                                               \
-  template <>                                                                \
-  void MirrorPad<GpuDevice, T, i>::operator()(                               \
-      const GpuDevice&, typename TTypes<T, i, int32>::Tensor,                \
-      typename TTypes<T, i, int32>::ConstTensor, TTypes<int32>::ConstMatrix, \
-      int);                                                                  \
-  extern template struct MirrorPad<GpuDevice, T, i>;
+#define DECLARE_GPU_SPEC(T, Tpaddings, i)                      \
+  template <>                                                  \
+  void MirrorPad<GpuDevice, T, Tpaddings, i>::operator()(      \
+      const GpuDevice&, typename TTypes<T, i, int32>::Tensor,  \
+      typename TTypes<T, i, int32>::ConstTensor,               \
+      TTypes<Tpaddings>::ConstMatrix, int);                    \
+  extern template struct MirrorPad<GpuDevice, T, Tpaddings, i>;
 
-#define DECLARE_GPU_SPECS(T) \
-  DECLARE_GPU_SPEC(T, 1);    \
-  DECLARE_GPU_SPEC(T, 2);    \
-  DECLARE_GPU_SPEC(T, 3);    \
-  DECLARE_GPU_SPEC(T, 4);    \
-  DECLARE_GPU_SPEC(T, 5);
+#define DECLARE_GPU_SPECS(T)     \
+  DECLARE_GPU_SPEC(T, int32, 1); \
+  DECLARE_GPU_SPEC(T, int32, 2); \
+  DECLARE_GPU_SPEC(T, int32, 3); \
+  DECLARE_GPU_SPEC(T, int32, 4); \
+  DECLARE_GPU_SPEC(T, int32, 5); \
+  DECLARE_GPU_SPEC(T, int64, 1); \
+  DECLARE_GPU_SPEC(T, int64, 2); \
+  DECLARE_GPU_SPEC(T, int64, 3); \
+  DECLARE_GPU_SPEC(T, int64, 4); \
+  DECLARE_GPU_SPEC(T, int64, 5);
 
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
 #undef DECLARE_GPU_SPECS
@@ -215,14 +231,20 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
                               .TypeConstraint<T>("T")             \
                               .TypeConstraint<int32>("Tpaddings") \
                               .HostMemory("paddings"),            \
-                          MirrorPadOp<GpuDevice, T>)
+                          MirrorPadOp<GpuDevice, T, int32>);      \
+  REGISTER_KERNEL_BUILDER(Name("MirrorPad")                       \
+                              .Device(DEVICE_GPU)                 \
+                              .TypeConstraint<T>("T")             \
+                              .TypeConstraint<int64>("Tpaddings") \
+                              .HostMemory("paddings"),            \
+                          MirrorPadOp<GpuDevice, T, int64>);
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
 #undef REGISTER_GPU_KERNEL
 #endif  // GOOGLE_CUDA
 
 // Gradient op.
-template <typename Device, typename T>
+template <typename Device, typename T, typename Tpaddings>
 class MirrorPadGradOp : public OpKernel {
  public:
   explicit MirrorPadGradOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -269,10 +291,10 @@ class MirrorPadGradOp : public OpKernel {
 
     // Compute the shape of the output tensor, and allocate it.
     TensorShape output_shape;
-    TTypes<int32>::ConstMatrix paddings = in1.matrix<int32>();
+    typename TTypes<Tpaddings>::ConstMatrix paddings = in1.matrix<Tpaddings>();
     for (int d = 0; d < dims; ++d) {
-      const int32 before = paddings(d, 0);  // Pad before existing elements.
-      const int32 after = paddings(d, 1);   // Pad after existing elements.
+      const Tpaddings before = paddings(d, 0);  // Pad before existing elements.
+      const Tpaddings after = paddings(d, 1);   // Pad after existing elements.
       OP_REQUIRES(context, before >= 0 && after >= 0,
                   errors::InvalidArgument("Paddings must be non-negative: ",
                                           before, ", ", after));
@@ -308,7 +330,7 @@ class MirrorPadGradOp : public OpKernel {
 
 #define MIRROR_PAD_GRAD_CASE(k)                                            \
   case k: {                                                                \
-    functor::MirrorPadGrad<Device, T, k>()(                                \
+    functor::MirrorPadGrad<Device, T, Tpaddings, k>()(                     \
         context->eigen_device<Device>(), To32Bit(output->tensor<T, k>()),  \
         To32Bit(in0.tensor<T, k>()), paddings, offset_,                    \
         To32Bit(scratch.tensor<T, k>()));                                  \
@@ -337,33 +359,45 @@ class MirrorPadGradOp : public OpKernel {
 namespace functor {
 // Forward declarations of the functor specializations defined in the sharded
 // files.
-#define DECLARE_CPU_SPEC(T, k)                                               \
-  template <>                                                                \
-  void MirrorPadGrad<CpuDevice, T, k>::operator()(                           \
-      const CpuDevice&, typename TTypes<T, k, int32>::Tensor,                \
-      typename TTypes<T, k, int32>::ConstTensor, TTypes<int32>::ConstMatrix, \
-      int, typename TTypes<T, k, int32>::Tensor);                            \
-  extern template struct MirrorPadGrad<CpuDevice, T, k>;
+#define DECLARE_CPU_SPEC(T, Tpaddings, k)                          \
+  template <>                                                      \
+  void MirrorPadGrad<CpuDevice, T, Tpaddings, k>::operator()(      \
+      const CpuDevice&, typename TTypes<T, k, int32>::Tensor,      \
+      typename TTypes<T, k, int32>::ConstTensor,                   \
+      TTypes<Tpaddings>::ConstMatrix, int,                         \
+      typename TTypes<T, k, int32>::Tensor);                       \
+  extern template struct MirrorPadGrad<CpuDevice, T, Tpaddings, k>;
 
-#define DECLARE_CPU_SPECS(T) \
-  DECLARE_CPU_SPEC(T, 1);    \
-  DECLARE_CPU_SPEC(T, 2);    \
-  DECLARE_CPU_SPEC(T, 3);    \
-  DECLARE_CPU_SPEC(T, 4);    \
-  DECLARE_CPU_SPEC(T, 5);
+#define DECLARE_CPU_SPECS(T)     \
+  DECLARE_CPU_SPEC(T, int32, 1); \
+  DECLARE_CPU_SPEC(T, int32, 2); \
+  DECLARE_CPU_SPEC(T, int32, 3); \
+  DECLARE_CPU_SPEC(T, int32, 4); \
+  DECLARE_CPU_SPEC(T, int32, 5); \
+  DECLARE_CPU_SPEC(T, int64, 1); \
+  DECLARE_CPU_SPEC(T, int64, 2); \
+  DECLARE_CPU_SPEC(T, int64, 3); \
+  DECLARE_CPU_SPEC(T, int64, 4); \
+  DECLARE_CPU_SPEC(T, int64, 5);
 
 TF_CALL_NUMBER_TYPES(DECLARE_CPU_SPECS);
 #undef DECLARE_CPU_SPECS
 #undef DECLARE_CPU_SPEC
 }  // namespace functor
 
-#define REGISTER_KERNEL(type)                                     \
-  REGISTER_KERNEL_BUILDER(Name("MirrorPadGrad")                   \
-                              .Device(DEVICE_CPU)                 \
-                              .TypeConstraint<type>("T")          \
-                              .TypeConstraint<int32>("Tpaddings") \
-                              .HostMemory("paddings"),            \
-                          MirrorPadGradOp<CpuDevice, type>);
+#define REGISTER_KERNEL(type)                                       \
+  REGISTER_KERNEL_BUILDER(Name("MirrorPadGrad")                     \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<type>("T")            \
+                              .TypeConstraint<int32>("Tpaddings")   \
+                              .HostMemory("paddings"),              \
+                          MirrorPadGradOp<CpuDevice, type, int32>); \
+  REGISTER_KERNEL_BUILDER(Name("MirrorPadGrad")                     \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<type>("T")            \
+                              .TypeConstraint<int64>("Tpaddings")   \
+                              .HostMemory("paddings"),              \
+                          MirrorPadGradOp<CpuDevice, type, int64>);
 
 TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
 #undef REGISTER_KERNEL
@@ -371,20 +405,26 @@ TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
 #if GOOGLE_CUDA
 namespace functor {
 // Forward declarations of the functor specializations for GPU.
-#define DECLARE_GPU_SPEC(T, k)                                               \
-  template <>                                                                \
-  void MirrorPadGrad<GpuDevice, T, k>::operator()(                           \
-      const GpuDevice&, typename TTypes<T, k, int32>::Tensor,                \
-      typename TTypes<T, k, int32>::ConstTensor, TTypes<int32>::ConstMatrix, \
-      int, typename TTypes<T, k, int32>::Tensor);                            \
-  extern template struct MirrorPadGrad<GpuDevice, T, k>;
+#define DECLARE_GPU_SPEC(T, Tpaddings, k)                          \
+  template <>                                                      \
+  void MirrorPadGrad<GpuDevice, T, Tpaddings, k>::operator()(      \
+      const GpuDevice&, typename TTypes<T, k, int32>::Tensor,      \
+      typename TTypes<T, k, int32>::ConstTensor,                   \
+      TTypes<Tpaddings>::ConstMatrix, int,                         \
+      typename TTypes<T, k, int32>::Tensor);                       \
+  extern template struct MirrorPadGrad<GpuDevice, T, Tpaddings, k>;
 
-#define DECLARE_GPU_SPECS(T) \
-  DECLARE_GPU_SPEC(T, 1);    \
-  DECLARE_GPU_SPEC(T, 2);    \
-  DECLARE_GPU_SPEC(T, 3);    \
-  DECLARE_GPU_SPEC(T, 4);    \
-  DECLARE_GPU_SPEC(T, 5);
+#define DECLARE_GPU_SPECS(T)     \
+  DECLARE_GPU_SPEC(T, int32, 1); \
+  DECLARE_GPU_SPEC(T, int32, 2); \
+  DECLARE_GPU_SPEC(T, int32, 3); \
+  DECLARE_GPU_SPEC(T, int32, 4); \
+  DECLARE_GPU_SPEC(T, int32, 5); \
+  DECLARE_GPU_SPEC(T, int64, 1); \
+  DECLARE_GPU_SPEC(T, int64, 2); \
+  DECLARE_GPU_SPEC(T, int64, 3); \
+  DECLARE_GPU_SPEC(T, int64, 4); \
+  DECLARE_GPU_SPEC(T, int64, 5);
 
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
 #undef DECLARE_GPU_SPECS
@@ -398,7 +438,13 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
                               .TypeConstraint<T>("T")             \
                               .TypeConstraint<int32>("Tpaddings") \
                               .HostMemory("paddings"),            \
-                          MirrorPadGradOp<GpuDevice, T>)
+                          MirrorPadGradOp<GpuDevice, T, int32>);  \
+  REGISTER_KERNEL_BUILDER(Name("MirrorPadGrad")                   \
+                              .Device(DEVICE_GPU)                 \
+                              .TypeConstraint<T>("T")             \
+                              .TypeConstraint<int64>("Tpaddings") \
+                              .HostMemory("paddings"),            \
+                          MirrorPadGradOp<GpuDevice, T, int64>);
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
 #undef REGISTER_GPU_KERNEL
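Because MirrorPadGrad is now registered for int64 paddings as well, backpropagation through a mirror-padded tensor works with either index type. A small sketch (again TF 1.x style; the expected gradient values are worked out by hand for SYMMETRIC mode, which replicates border elements):

import tensorflow as tf

x = tf.ones([4, 4])
paddings = tf.constant([[1, 1], [1, 1]], dtype=tf.int64)
y = tf.pad(x, paddings, mode="SYMMETRIC")
# Backprop through MirrorPad dispatches to MirrorPadGrad, which now has
# int64-paddings kernels on both CPU and GPU.
grad = tf.gradients(tf.reduce_sum(y), x)[0]

with tf.Session() as sess:
  # Each element's gradient counts its copies in the padded output:
  # corners 4, edges 2, interior 1.
  print(sess.run(grad))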
tensorflow/core/kernels/mirror_pad_op.h
@@ -64,9 +64,8 @@ class TensorMirrorPadOp
       StorageKind;
   typedef typename Eigen::internal::traits<TensorMirrorPadOp>::Index Index;
 
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
-  TensorMirrorPadOp(const XprType& expr, const PaddingDimensions& padding_dims,
-                    Index offset)
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMirrorPadOp(
+      const XprType& expr, const PaddingDimensions& padding_dims, Index offset)
       : xpr_(expr), padding_dims_(padding_dims), offset_(offset) {}
 
   EIGEN_DEVICE_FUNC
@@ -336,12 +335,12 @@ namespace functor {
 
 // offset argument must be either 0 or 1. This controls whether the boundary
 // values are replicated (offset == 0) or not replicated (offset == 1).
-template <typename Device, typename T, int Dims>
+template <typename Device, typename T, typename Tpaddings, int Dims>
 struct MirrorPad {
   void operator()(const Device& device,
                   typename TTypes<T, Dims, int32>::Tensor output,
                   typename TTypes<T, Dims, int32>::ConstTensor input,
-                  TTypes<int32>::ConstMatrix padding, int offset) {
+                  typename TTypes<Tpaddings>::ConstMatrix padding, int offset) {
     Eigen::array<Eigen::IndexPair<int32>, Dims> padding_dims;
 
     for (int i = 0; i < Dims; ++i) {
@@ -363,12 +362,12 @@ struct MirrorPad {
 
 // offset argument must be either 0 or 1. This controls whether the boundary
 // values are replicated (offset == 0) or not replicated (offset == 1).
-template <typename Device, typename T, int Dims>
+template <typename Device, typename T, typename Tpaddings, int Dims>
 struct MirrorPadGrad {
   void operator()(const Device& device,
                   typename TTypes<T, Dims, int32>::Tensor output,
                   typename TTypes<T, Dims, int32>::ConstTensor input,
-                  TTypes<int32>::ConstMatrix paddings, int offset,
+                  typename TTypes<Tpaddings>::ConstMatrix paddings, int offset,
                   typename TTypes<T, Dims, int32>::Tensor scratch) {
     // Copy the gradient input into the scratch buffer.
     scratch.device(device) = input;
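The offset parameter described in these comments is what distinguishes the two mirroring modes that tf.pad exposes; a short illustration (expected outputs worked out by hand):

import tensorflow as tf

x = tf.constant([[1, 2, 3]])
pads = tf.constant([[0, 0], [2, 2]], dtype=tf.int64)

reflect = tf.pad(x, pads, mode="REFLECT")      # offset == 1: border not repeated
symmetric = tf.pad(x, pads, mode="SYMMETRIC")  # offset == 0: border repeated

with tf.Session() as sess:
  print(sess.run(reflect))    # [[3 2 1 2 3 2 1]]
  print(sess.run(symmetric))  # [[2 1 1 2 3 3 2]]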
tensorflow/core/kernels/mirror_pad_op_cpu_impl.h
@@ -25,13 +25,17 @@ namespace tensorflow {
 
 using CpuDevice = Eigen::ThreadPoolDevice;
 
-#define DEFINE_CPU_SPECS(T) \
-  template struct functor::MirrorPad<CpuDevice, T, CPU_PROVIDED_IXDIM>;
+#define DEFINE_CPU_SPECS(T)                                                    \
+  template struct functor::MirrorPad<CpuDevice, T, int32, CPU_PROVIDED_IXDIM>; \
+  template struct functor::MirrorPad<CpuDevice, T, int64, CPU_PROVIDED_IXDIM>;
 TF_CALL_POD_TYPES(DEFINE_CPU_SPECS);
 #undef DEFINE_CPU_SPECS
 
-#define DEFINE_CPU_SPECS(T) \
-  template struct functor::MirrorPadGrad<CpuDevice, T, CPU_PROVIDED_IXDIM>;
+#define DEFINE_CPU_SPECS(T)                                    \
+  template struct functor::MirrorPadGrad<CpuDevice, T, int32,  \
+                                         CPU_PROVIDED_IXDIM>;  \
+  template struct functor::MirrorPadGrad<CpuDevice, T, int64,  \
+                                         CPU_PROVIDED_IXDIM>;
 TF_CALL_NUMBER_TYPES(DEFINE_CPU_SPECS);
 #undef DEFINE_CPU_SPECS
tensorflow/core/kernels/mirror_pad_op_gpu.cu.cc
@@ -25,17 +25,27 @@ namespace tensorflow {
 
 using GpuDevice = Eigen::GpuDevice;
 
-#define DEFINE_GPU_SPECS(T)                                 \
-  template struct functor::MirrorPad<GpuDevice, T, 1>;      \
-  template struct functor::MirrorPad<GpuDevice, T, 2>;      \
-  template struct functor::MirrorPad<GpuDevice, T, 3>;      \
-  template struct functor::MirrorPad<GpuDevice, T, 4>;      \
-  template struct functor::MirrorPad<GpuDevice, T, 5>;      \
-  template struct functor::MirrorPadGrad<GpuDevice, T, 1>;  \
-  template struct functor::MirrorPadGrad<GpuDevice, T, 2>;  \
-  template struct functor::MirrorPadGrad<GpuDevice, T, 3>;  \
-  template struct functor::MirrorPadGrad<GpuDevice, T, 4>;  \
-  template struct functor::MirrorPadGrad<GpuDevice, T, 5>;
+#define DEFINE_GPU_SPECS(T)                                        \
+  template struct functor::MirrorPad<GpuDevice, T, int32, 1>;      \
+  template struct functor::MirrorPad<GpuDevice, T, int32, 2>;      \
+  template struct functor::MirrorPad<GpuDevice, T, int32, 3>;      \
+  template struct functor::MirrorPad<GpuDevice, T, int32, 4>;      \
+  template struct functor::MirrorPad<GpuDevice, T, int32, 5>;      \
+  template struct functor::MirrorPad<GpuDevice, T, int64, 1>;      \
+  template struct functor::MirrorPad<GpuDevice, T, int64, 2>;      \
+  template struct functor::MirrorPad<GpuDevice, T, int64, 3>;      \
+  template struct functor::MirrorPad<GpuDevice, T, int64, 4>;      \
+  template struct functor::MirrorPad<GpuDevice, T, int64, 5>;      \
+  template struct functor::MirrorPadGrad<GpuDevice, T, int32, 1>;  \
+  template struct functor::MirrorPadGrad<GpuDevice, T, int32, 2>;  \
+  template struct functor::MirrorPadGrad<GpuDevice, T, int32, 3>;  \
+  template struct functor::MirrorPadGrad<GpuDevice, T, int32, 4>;  \
+  template struct functor::MirrorPadGrad<GpuDevice, T, int32, 5>;  \
+  template struct functor::MirrorPadGrad<GpuDevice, T, int64, 1>;  \
+  template struct functor::MirrorPadGrad<GpuDevice, T, int64, 2>;  \
+  template struct functor::MirrorPadGrad<GpuDevice, T, int64, 3>;  \
+  template struct functor::MirrorPadGrad<GpuDevice, T, int64, 4>;  \
+  template struct functor::MirrorPadGrad<GpuDevice, T, int64, 5>;
 
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
 #undef DEFINE_GPU_SPECS
tensorflow/python/kernel_tests/pad_op_test.py
@@ -193,6 +193,25 @@ class PadOpTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError, "Unknown padding mode"):
       array_ops.pad(x, [[1, 0], [2, 1]], mode="weird").eval()
 
+  def testPaddingTypes(self):
+    paddings = [[1, 0], [2, 3], [0, 2]]
+    inputs = np.random.randint(-100, 100, (4, 4, 3)).astype(np.float32)
+    for mode in ("CONSTANT", "REFLECT", "SYMMETRIC", "reflect", "symmetric",
+                 "constant"):
+      for padding_dtype in [dtypes.int32, dtypes.int64]:
+        np_val = self._npPad(inputs,
+                             paddings,
+                             mode=mode,
+                             constant_values=0)
+        with self.test_session(use_gpu=True):
+          tf_val = array_ops.pad(inputs,
+                                 constant_op.constant(paddings, padding_dtype),
+                                 mode=mode,
+                                 constant_values=0)
+          out = tf_val.eval()
+        self.assertAllEqual(np_val, out)
+        self.assertShapeEqual(np_val, tf_val)
+
   def testIntTypes(self):
     # TODO(touts): Figure out why the padding tests do not work on GPU
     # for int types and rank > 2.