Add int64 type multiples support for tf.tile ()

* Add `int64` type `multiples` support for `tf.tile`

In the documentation for `tf.tile` (`tf.tile.__doc__`), both `int32`
and `int64` are listed as supported types for `multiples`. However,
the kernel for `int64` was never registered.

This fix adds support for `int64` `multiples` so that the behavior
matches the documentation.
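
For illustration, a minimal sketch of the user-visible behavior this enables
(TF 1.x-era API; the tensor values and shapes are arbitrary, not taken from
the commit):

```python
import tensorflow as tf

x = tf.constant([[1, 2], [3, 4]])
# `multiples` as an int64 tensor: before this change only the int32 kernel
# existed, so running this graph failed with "No OpKernel was registered".
multiples = tf.constant([2, 3], dtype=tf.int64)
tiled = tf.tile(x, multiples)  # result shape: (4, 6)

with tf.Session() as sess:
    print(sess.run(tiled))
```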

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Update functors for int64 multiples support in `tf.tile`

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Update test cases for int64 of multiples in `tf.tile`

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add GPU and non-GPU tests

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Format with `clang-format -i`

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Move `Tmultiples` after `T` (as it is auxiliary)

And use `use_gpu=True`

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Yong Tang 2017-10-22 22:33:18 -07:00 committed by Vijay Vasudevan
parent 40c475b48c
commit 690003cc01
5 changed files with 147 additions and 179 deletions

tensorflow/core/kernels/tile_functor.h

@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_KERNELS_TILE_FUNCTOR_H_
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/platform/types.h"

@@ -29,13 +30,13 @@ namespace internal {
 template <typename Device, typename T>
 void TileSimple(const Device& d, Tensor* out, const Tensor& in);
 
-template <typename Device, typename T, int NDIM>
+template <typename Device, typename T, typename Tmultiples, int NDIM>
 void TileUsingEigen(const Device& d, Tensor* out, const Tensor& in,
-                    const gtl::ArraySlice<int32>& broadcast_array) {
+                    const gtl::ArraySlice<Tmultiples>& broadcast_array) {
   auto x = in.tensor<T, NDIM>();
   auto y = out->tensor<T, NDIM>();
 
-  Eigen::array<int32, NDIM> b;
+  Eigen::array<Tmultiples, NDIM> b;
   for (int i = 0; i < NDIM; ++i) b[i] = broadcast_array[i];
   if (Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
     // Use 32bit indexing to speed up the computations

@@ -45,9 +46,9 @@ void TileUsingEigen(const Device& d, Tensor* out, const Tensor& in,
   }
 }
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename Tmultiples>
 void TileUsingEigen(const Device& d, Tensor* out, const Tensor& in,
-                    const gtl::ArraySlice<int32>&) {
+                    const gtl::ArraySlice<Tmultiples>&) {
   auto x = in.tensor<T, 0>();
   auto y = out->tensor<T, 0>();
   // In the scalar case we simply copy the input.

@@ -58,34 +59,42 @@ void TileUsingEigen(const Device& d, Tensor* out, const Tensor& in,
 namespace functor {
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename Tmultiples>
 struct Tile {
   void operator()(const Device& d, Tensor* out, const Tensor& in,
-                  const gtl::ArraySlice<int32> broadcast_array) const {
+                  const gtl::ArraySlice<Tmultiples> broadcast_array) const {
     switch (in.dims()) {
       case 0:
-        internal::TileUsingEigen<Device, T>(d, out, in, broadcast_array);
+        internal::TileUsingEigen<Device, T, Tmultiples>(d, out, in,
+                                                        broadcast_array);
         break;
       case 1:
-        internal::TileUsingEigen<Device, T, 1>(d, out, in, broadcast_array);
+        internal::TileUsingEigen<Device, T, Tmultiples, 1>(d, out, in,
+                                                           broadcast_array);
         break;
       case 2:
-        internal::TileUsingEigen<Device, T, 2>(d, out, in, broadcast_array);
+        internal::TileUsingEigen<Device, T, Tmultiples, 2>(d, out, in,
+                                                           broadcast_array);
        break;
       case 3:
-        internal::TileUsingEigen<Device, T, 3>(d, out, in, broadcast_array);
+        internal::TileUsingEigen<Device, T, Tmultiples, 3>(d, out, in,
+                                                           broadcast_array);
         break;
       case 4:
-        internal::TileUsingEigen<Device, T, 4>(d, out, in, broadcast_array);
+        internal::TileUsingEigen<Device, T, Tmultiples, 4>(d, out, in,
+                                                           broadcast_array);
         break;
       case 5:
-        internal::TileUsingEigen<Device, T, 5>(d, out, in, broadcast_array);
+        internal::TileUsingEigen<Device, T, Tmultiples, 5>(d, out, in,
+                                                           broadcast_array);
         break;
       case 6:
-        internal::TileUsingEigen<Device, T, 6>(d, out, in, broadcast_array);
+        internal::TileUsingEigen<Device, T, Tmultiples, 6>(d, out, in,
+                                                           broadcast_array);
         break;
       case 7:
-        internal::TileUsingEigen<Device, T, 7>(d, out, in, broadcast_array);
+        internal::TileUsingEigen<Device, T, Tmultiples, 7>(d, out, in,
+                                                           broadcast_array);
         break;
       default:
         internal::TileSimple<Device, T>(d, out, in);

tensorflow/core/kernels/tile_functor_cpu.cc

@@ -15,10 +15,10 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
-#include "tensorflow/core/kernels/tile_functor.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/tile_functor.h"
 
 namespace tensorflow {

@@ -51,7 +51,9 @@ namespace functor {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
 // Register functors used for Tile functor.
-#define DEFINE_TYPE(T) template struct Tile<CPUDevice, T>;
+#define DEFINE_TYPE(T)                       \
+  template struct Tile<CPUDevice, T, int32>; \
+  template struct Tile<CPUDevice, T, int64>;
 
 TF_CALL_bool(DEFINE_TYPE);
 TF_CALL_float(DEFINE_TYPE);

@@ -70,7 +72,9 @@ TF_CALL_string(DEFINE_TYPE);
 #ifdef TENSORFLOW_USE_SYCL
 typedef Eigen::SyclDevice SYCLDevice;
 
-#define DEFINE_TYPE(T) template struct Tile<SYCLDevice, T>;
+#define DEFINE_TYPE(T)                        \
+  template struct Tile<SYCLDevice, T, int32>; \
+  template struct Tile<SYCLDevice, T, int64>;
 
 TF_CALL_bool(DEFINE_TYPE);
 TF_CALL_float(DEFINE_TYPE);

@@ -81,7 +85,7 @@ TF_CALL_int16(DEFINE_TYPE);
 TF_CALL_int64(DEFINE_TYPE);
 
 #undef DEFINE_TYPE
-#endif // TENSORFLOW_USE_SYCL
+#endif  // TENSORFLOW_USE_SYCL
 
 }  // end namespace functor
 }  // end namespace tensorflow

tensorflow/core/kernels/tile_functor_gpu.cu.cc

@@ -18,10 +18,11 @@ limitations under the License.
 #define EIGEN_USE_GPU
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/tile_functor.h"
-#include "tensorflow/core/kernels/ops_util.h"
-#include "tensorflow/core/util/cuda_kernel_helper.h"
+
 #include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/tile_functor.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
 
 namespace tensorflow {
 namespace internal {

@@ -60,7 +61,8 @@ void TileSimple(const Device& d, Tensor* out, const Tensor& in) {
     host_buf[ndims + i] = out_strides[i];
     host_buf[ndims * 2 + i] = in.dim_size(i);
   }
-  // Copies the input strides, output strides and input dimension sizes to the device.
+  // Copies the input strides, output strides and input dimension sizes to the
+  // device.
   auto num_bytes = sizeof(int64) * host_buf.size();
   auto dev_buf = d.allocate(num_bytes);
   // NOTE: host_buf is not allocated by CudaHostAllocator, and

@@ -84,7 +86,9 @@ namespace functor {
 typedef Eigen::GpuDevice GPUDevice;
 
 // Register functors used for Tile functor.
-#define DEFINE_TYPE(T) template struct Tile<GPUDevice, T>;
+#define DEFINE_TYPE(T)                       \
+  template struct Tile<GPUDevice, T, int32>; \
+  template struct Tile<GPUDevice, T, int64>;
 
 TF_CALL_int16(DEFINE_TYPE);
 TF_CALL_int32(DEFINE_TYPE);

tensorflow/core/kernels/tile_ops.cc

@@ -42,14 +42,14 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 #ifdef TENSORFLOW_USE_SYCL
 typedef Eigen::SyclDevice SYCLDevice;
-#endif // TENSORFLOW_USE_SYCL
+#endif  // TENSORFLOW_USE_SYCL
 
 // Forward declarations of functors that will be defined in tile_ops_impl.h
 namespace functor {
-template <typename Device, typename T>
+template <typename Device, typename T, typename Tmultiple>
 struct Tile {
   void operator()(const Device& d, Tensor* out, const Tensor& in,
-                  const gtl::ArraySlice<int32> broadcast_array) const;
+                  const gtl::ArraySlice<Tmultiple> broadcast_array) const;
 };
 
 template <typename Device, typename T, int NDIM>

@@ -80,7 +80,7 @@ struct ReduceAndReshape {
 }  // namespace functor
 
 // --------------------------------------------------------------------------
-template <typename Device>
+template <typename Device, typename Tmultiples>
 class TileOp : public OpKernel {
  public:
  explicit TileOp(OpKernelConstruction* context) : OpKernel(context) {}

@@ -105,8 +105,8 @@ class TileOp : public OpKernel {
       return;
     }
 
-    const gtl::ArraySlice<int32> multiples_array(multiples.flat<int32>().data(),
-                                                 input_dims);
+    const gtl::ArraySlice<Tmultiples> multiples_array(
+        multiples.flat<Tmultiples>().data(), input_dims);
 
     TensorShape output_shape;
     for (int i = 0; i < input_dims; ++i) {
       OP_REQUIRES(

@@ -125,10 +125,10 @@ class TileOp : public OpKernel {
     // If there's no output, there's nothing to do.
     if (output_shape.num_elements() == 0) return;
 
-#define HANDLE_TYPE(DT)                                \
-  if (context->input(0).dtype() == DT) {               \
-    HandleCase<DT>(context, multiples_array, result);  \
-    return;                                            \
+#define HANDLE_TYPE(DT)                               \
+  if (context->input(0).dtype() == DT) {              \
+    HandleCase<DT>(context, multiples_array, result); \
+    return;                                           \
   }
 
 #define HANDLE_TYPE_NAME(T) HANDLE_TYPE(DataTypeToEnum<T>::value)

@@ -158,27 +158,27 @@ class TileOp : public OpKernel {
  private:
   template <DataType DT>
   void HandleCaseImpl(OpKernelContext* context,
-                      const gtl::ArraySlice<int32>& multiples_array,
+                      const gtl::ArraySlice<Tmultiples>& multiples_array,
                       Tensor* result) {
     typedef typename EnumToDataType<DT>::Type T;
-    functor::Tile<Device, T>() (
-        context->eigen_device<Device>(), result,
-        context->input(0), multiples_array);
+    functor::Tile<Device, T, Tmultiples>()(context->eigen_device<Device>(),
+                                           result, context->input(0),
+                                           multiples_array);
   }
 
   template <DataType DT>
   void HandleCase(OpKernelContext* context,
-                  const gtl::ArraySlice<int32>& multiples_array,
+                  const gtl::ArraySlice<Tmultiples>& multiples_array,
                   Tensor* result);
 
   TF_DISALLOW_COPY_AND_ASSIGN(TileOp);
 };
 
-template <typename Device>
+template <typename Device, typename Tmultiples>
 template <DataType DT>
-inline void TileOp<Device>::HandleCase(
-    OpKernelContext* context, const gtl::ArraySlice<int32>& multiples_array,
-    Tensor* result) {
+inline void TileOp<Device, Tmultiples>::HandleCase(
+    OpKernelContext* context,
+    const gtl::ArraySlice<Tmultiples>& multiples_array, Tensor* result) {
   // TODO(vrv): print out the device name if useful. Currently disabled to avoid
   // having to use RTTI.
   LOG(FATAL) << "TileOp: Invalid combination of Device, DT: "

@@ -186,25 +186,28 @@ inline void TileOp<Device>::HandleCase(
             << DataTypeString(DT);
 }
 
-#define HANDLE_CASE(device, dtype)                                     \
-  template <>                                                          \
-  template <>                                                          \
-  void TileOp<device>::HandleCase<dtype>(                              \
-      OpKernelContext * context,                                       \
-      const gtl::ArraySlice<int32>& multiples_array, Tensor* result) { \
-    HandleCaseImpl<dtype>(context, multiples_array, result);           \
+#define HANDLE_CASE(device, dtype, Tmultiples)                              \
+  template <>                                                               \
+  template <>                                                               \
+  void TileOp<device, Tmultiples>::HandleCase<dtype>(                       \
+      OpKernelContext * context,                                            \
+      const gtl::ArraySlice<Tmultiples>& multiples_array, Tensor* result) { \
+    HandleCaseImpl<dtype>(context, multiples_array, result);                \
   }
 
-#define HANDLE_TYPE_NAME_CPU(T) \
-  HANDLE_CASE(CPUDevice, DataTypeToEnum<T>::value);
+#define HANDLE_TYPE_NAME_CPU(T)                            \
+  HANDLE_CASE(CPUDevice, DataTypeToEnum<T>::value, int32); \
+  HANDLE_CASE(CPUDevice, DataTypeToEnum<T>::value, int64);
 
-#define HANDLE_TYPE_NAME_GPU(T) \
-  HANDLE_CASE(GPUDevice, DataTypeToEnum<T>::value);
+#define HANDLE_TYPE_NAME_GPU(T)                            \
+  HANDLE_CASE(GPUDevice, DataTypeToEnum<T>::value, int32); \
+  HANDLE_CASE(GPUDevice, DataTypeToEnum<T>::value, int64);
 
 #ifdef TENSORFLOW_USE_SYCL
-#define HANDLE_TYPE_NAME_SYCL(T) \
-  HANDLE_CASE(SYCLDevice, DataTypeToEnum<T>::value);
-#endif // TENSORFLOW_USE_SYCL
+#define HANDLE_TYPE_NAME_SYCL(T)                            \
+  HANDLE_CASE(SYCLDevice, DataTypeToEnum<T>::value, int32); \
+  HANDLE_CASE(SYCLDevice, DataTypeToEnum<T>::value, int64);
+#endif  // TENSORFLOW_USE_SYCL
 
 TF_CALL_bool(HANDLE_TYPE_NAME_CPU);
 TF_CALL_float(HANDLE_TYPE_NAME_CPU);

@@ -235,13 +238,13 @@ TF_CALL_double(HANDLE_TYPE_NAME_SYCL);
 TF_CALL_int16(HANDLE_TYPE_NAME_SYCL);
 TF_CALL_int32(HANDLE_TYPE_NAME_SYCL);
 TF_CALL_int64(HANDLE_TYPE_NAME_SYCL);
-#endif // TENSORFLOW_USE_SYCL
+#endif  // TENSORFLOW_USE_SYCL
 
 #undef HANDLE_TYPE_NAME_CPU
 #undef HANDLE_TYPE_NAME_GPU
 #ifdef TENSORFLOW_USE_SYCL
 #undef HANDLE_TYPE_NAME_SYCL
-#endif // TENSORFLOW_USE_SYCL
+#endif  // TENSORFLOW_USE_SYCL
 #undef HANDLE_CASE
 
 // --------------------------------------------------------------------------

@@ -494,7 +497,7 @@ TF_CALL_int16(HANDLE_TYPE_NAME_SYCL);
 TF_CALL_int32(HANDLE_TYPE_NAME_SYCL);
 TF_CALL_int64(HANDLE_TYPE_NAME_SYCL);
 #undef HANDLE_TYPE_NAME_SYCL
-#endif // TENSORFLOW_USE_SYCL
+#endif  // TENSORFLOW_USE_SYCL
 
 #undef HANDLE_TYPE_NAME_CPU
 #undef HANDLE_TYPE_NAME_GPU

@@ -505,127 +508,73 @@ REGISTER_KERNEL_BUILDER(Name("Tile")
                             .Device(DEVICE_CPU)
                             .HostMemory("multiples")
                             .TypeConstraint<int32>("Tmultiples"),
-                        TileOp<CPUDevice>);
+                        TileOp<CPUDevice, int32>);
+REGISTER_KERNEL_BUILDER(Name("Tile")
+                            .Device(DEVICE_CPU)
+                            .HostMemory("multiples")
+                            .TypeConstraint<int64>("Tmultiples"),
+                        TileOp<CPUDevice, int64>);
 REGISTER_KERNEL_BUILDER(
     Name("TileGrad").Device(DEVICE_CPU).HostMemory("multiples"),
     TileGradientOp<CPUDevice>);
 
 #if GOOGLE_CUDA
-
-REGISTER_KERNEL_BUILDER(Name("Tile")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<float>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileOp<GPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("Tile")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<double>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileOp<GPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("Tile")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<Eigen::half>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileOp<GPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("Tile")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<int16>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileOp<GPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("Tile")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<int32>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileOp<GPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("Tile")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<complex64>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileOp<GPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("Tile")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<complex128>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileOp<GPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("TileGrad")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<float>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileGradientOp<GPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("TileGrad")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<double>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileGradientOp<GPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("TileGrad")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<Eigen::half>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileGradientOp<GPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("TileGrad")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<int16>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileGradientOp<GPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("TileGrad")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<int32>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileGradientOp<GPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("TileGrad")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<complex64>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileGradientOp<GPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("TileGrad")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<complex128>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileGradientOp<GPUDevice>);
-
+#define REGISTER_GPU(type)                                    \
+  REGISTER_KERNEL_BUILDER(Name("Tile")                        \
+                              .Device(DEVICE_GPU)             \
+                              .TypeConstraint<type>("T")      \
+                              .TypeConstraint<int32>("Tmultiples") \
+                              .HostMemory("multiples"),       \
+                          TileOp<GPUDevice, int32>);          \
+  REGISTER_KERNEL_BUILDER(Name("Tile")                        \
+                              .Device(DEVICE_GPU)             \
+                              .TypeConstraint<type>("T")      \
+                              .TypeConstraint<int64>("Tmultiples") \
+                              .HostMemory("multiples"),       \
+                          TileOp<GPUDevice, int64>);          \
+  REGISTER_KERNEL_BUILDER(Name("TileGrad")                    \
+                              .Device(DEVICE_GPU)             \
+                              .TypeConstraint<type>("T")      \
+                              .TypeConstraint<int32>("Tmultiples") \
+                              .HostMemory("multiples"),       \
+                          TileGradientOp<GPUDevice>);
+
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+TF_CALL_half(REGISTER_GPU);
+TF_CALL_int16(REGISTER_GPU);
+TF_CALL_int32(REGISTER_GPU);
+TF_CALL_complex64(REGISTER_GPU);
+TF_CALL_complex128(REGISTER_GPU)
+
+#undef REGISTER_GPU
 #endif  // GOOGLE_CUDA
 
 #ifdef TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("Tile")
-                            .Device(DEVICE_SYCL)
-                            .TypeConstraint<float>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileOp<SYCLDevice>);
-REGISTER_KERNEL_BUILDER(Name("Tile")
-                            .Device(DEVICE_SYCL)
-                            .TypeConstraint<double>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileOp<SYCLDevice>);
-REGISTER_KERNEL_BUILDER(Name("TileGrad")
-                            .Device(DEVICE_SYCL)
-                            .TypeConstraint<float>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileGradientOp<SYCLDevice>);
-REGISTER_KERNEL_BUILDER(Name("TileGrad")
-                            .Device(DEVICE_SYCL)
-                            .TypeConstraint<double>("T")
-                            .TypeConstraint<int32>("Tmultiples")
-                            .HostMemory("multiples"),
-                        TileGradientOp<SYCLDevice>);
+#define REGISTER_SYCL(type)                                   \
+  REGISTER_KERNEL_BUILDER(Name("Tile")                        \
+                              .Device(DEVICE_SYCL)            \
+                              .TypeConstraint<type>("T")      \
+                              .TypeConstraint<int32>("Tmultiples") \
+                              .HostMemory("multiples"),       \
+                          TileOp<SYCLDevice, int32>);         \
+  REGISTER_KERNEL_BUILDER(Name("Tile")                        \
+                              .Device(DEVICE_SYCL)            \
+                              .TypeConstraint<type>("T")      \
+                              .TypeConstraint<int64>("Tmultiples") \
+                              .HostMemory("multiples"),       \
+                          TileOp<SYCLDevice, int64>);         \
+  REGISTER_KERNEL_BUILDER(Name("TileGrad")                    \
+                              .Device(DEVICE_SYCL)            \
+                              .TypeConstraint<type>("T")      \
+                              .TypeConstraint<int32>("Tmultiples") \
+                              .HostMemory("multiples"),       \
+                          TileGradientOp<SYCLDevice>);
+
+TF_CALL_float(REGISTER_SYCL);
+TF_CALL_double(REGISTER_SYCL);
+
+#undef REGISTER_SYCL
 #endif  // TENSORFLOW_USE_SYCL
 
 }  // namespace tensorflow
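
As a quick sanity check of the new GPU registrations above (a hypothetical
snippet, not part of the commit; assumes a CUDA-enabled build with a visible
GPU), pinning the op to a GPU with soft placement disabled now succeeds for
int64 `multiples`, where it previously failed for lack of a matching kernel:

```python
import tensorflow as tf

with tf.device('/gpu:0'):
    x = tf.constant([[1.0, 2.0]])
    # With Tmultiples=int64 now registered for DEVICE_GPU, this op can be
    # placed on the GPU; before the change only the int32 variant existed.
    y = tf.tile(x, tf.constant([2, 2], dtype=tf.int64))

# allow_soft_placement=False surfaces an error if no matching kernel exists.
with tf.Session(config=tf.ConfigProto(allow_soft_placement=False)) as sess:
    print(sess.run(y))  # (1, 2) input tiled to shape (2, 4)
```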

tensorflow/python/kernel_tests/shape_ops_test.py

@@ -411,14 +411,16 @@ class TileTest(test.TestCase):
       self.assertEqual(7, result)
 
   def testSimple(self):
-    with self.test_session():
-      inp = np.random.rand(4, 1).astype(np.float32)
-      a = constant_op.constant(inp)
-      tiled = array_ops.tile(a, [1, 4])
-      result = tiled.eval()
-      self.assertEqual(result.shape, (4, 4))
-      self.assertEqual([4, 4], tiled.get_shape())
-      self.assertTrue((result == np.tile(inp, (1, 4))).all())
+    # multiples could be int32 or int64
+    for dtype in [dtypes.int32, dtypes.int64]:
+      with self.test_session(use_gpu=True):
+        inp = np.random.rand(4, 1).astype(np.float32)
+        a = constant_op.constant(inp)
+        tiled = array_ops.tile(a, constant_op.constant([1, 4], dtype=dtype))
+        result = tiled.eval()
+        self.assertEqual(result.shape, (4, 4))
+        self.assertEqual([4, 4], tiled.get_shape())
+        self.assertTrue((result == np.tile(inp, (1, 4))).all())
 
   def testIdentityTileAndGrad(self):
     with self.test_session():