Add int64
type multiples
support for tf.tile
(#13884)
* Add `int64` type `multiples` support for `tf.tile` In the doc of `tf.tile` (tf.tile.__doc__) both `int32` and `int64` are supported for `multiples`. However, the kernel for `int64` is not registered yet. This fix adds the support of `int64` `multiples` so that the behavior matches the description of the docs. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Update functors for int64 multiples support in `tf.tile` Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Update test cases for int64 of multiples in `tf.tile` Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add GPU and non-GPU tests Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * format with clang-format -i Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Move Tmultiples after T (as it is auxiliary) And use `use_gpu=True` Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
40c475b48c
commit
690003cc01
tensorflow
core/kernels
python/kernel_tests
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_KERNELS_TILE_FUNCTOR_H_
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -29,13 +30,13 @@ namespace internal {
|
||||
template <typename Device, typename T>
|
||||
void TileSimple(const Device& d, Tensor* out, const Tensor& in);
|
||||
|
||||
template <typename Device, typename T, int NDIM>
|
||||
template <typename Device, typename T, typename Tmultiples, int NDIM>
|
||||
void TileUsingEigen(const Device& d, Tensor* out, const Tensor& in,
|
||||
const gtl::ArraySlice<int32>& broadcast_array) {
|
||||
const gtl::ArraySlice<Tmultiples>& broadcast_array) {
|
||||
auto x = in.tensor<T, NDIM>();
|
||||
auto y = out->tensor<T, NDIM>();
|
||||
|
||||
Eigen::array<int32, NDIM> b;
|
||||
Eigen::array<Tmultiples, NDIM> b;
|
||||
for (int i = 0; i < NDIM; ++i) b[i] = broadcast_array[i];
|
||||
if (Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
|
||||
// Use 32bit indexing to speed up the computations
|
||||
@ -45,9 +46,9 @@ void TileUsingEigen(const Device& d, Tensor* out, const Tensor& in,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Device, typename T>
|
||||
template <typename Device, typename T, typename Tmultiples>
|
||||
void TileUsingEigen(const Device& d, Tensor* out, const Tensor& in,
|
||||
const gtl::ArraySlice<int32>&) {
|
||||
const gtl::ArraySlice<Tmultiples>&) {
|
||||
auto x = in.tensor<T, 0>();
|
||||
auto y = out->tensor<T, 0>();
|
||||
// In the scalar case we simply copy the input.
|
||||
@ -58,34 +59,42 @@ void TileUsingEigen(const Device& d, Tensor* out, const Tensor& in,
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename Device, typename T>
|
||||
template <typename Device, typename T, typename Tmultiples>
|
||||
struct Tile {
|
||||
void operator()(const Device& d, Tensor* out, const Tensor& in,
|
||||
const gtl::ArraySlice<int32> broadcast_array) const {
|
||||
const gtl::ArraySlice<Tmultiples> broadcast_array) const {
|
||||
switch (in.dims()) {
|
||||
case 0:
|
||||
internal::TileUsingEigen<Device, T>(d, out, in, broadcast_array);
|
||||
internal::TileUsingEigen<Device, T, Tmultiples>(d, out, in,
|
||||
broadcast_array);
|
||||
break;
|
||||
case 1:
|
||||
internal::TileUsingEigen<Device, T, 1>(d, out, in, broadcast_array);
|
||||
internal::TileUsingEigen<Device, T, Tmultiples, 1>(d, out, in,
|
||||
broadcast_array);
|
||||
break;
|
||||
case 2:
|
||||
internal::TileUsingEigen<Device, T, 2>(d, out, in, broadcast_array);
|
||||
internal::TileUsingEigen<Device, T, Tmultiples, 2>(d, out, in,
|
||||
broadcast_array);
|
||||
break;
|
||||
case 3:
|
||||
internal::TileUsingEigen<Device, T, 3>(d, out, in, broadcast_array);
|
||||
internal::TileUsingEigen<Device, T, Tmultiples, 3>(d, out, in,
|
||||
broadcast_array);
|
||||
break;
|
||||
case 4:
|
||||
internal::TileUsingEigen<Device, T, 4>(d, out, in, broadcast_array);
|
||||
internal::TileUsingEigen<Device, T, Tmultiples, 4>(d, out, in,
|
||||
broadcast_array);
|
||||
break;
|
||||
case 5:
|
||||
internal::TileUsingEigen<Device, T, 5>(d, out, in, broadcast_array);
|
||||
internal::TileUsingEigen<Device, T, Tmultiples, 5>(d, out, in,
|
||||
broadcast_array);
|
||||
break;
|
||||
case 6:
|
||||
internal::TileUsingEigen<Device, T, 6>(d, out, in, broadcast_array);
|
||||
internal::TileUsingEigen<Device, T, Tmultiples, 6>(d, out, in,
|
||||
broadcast_array);
|
||||
break;
|
||||
case 7:
|
||||
internal::TileUsingEigen<Device, T, 7>(d, out, in, broadcast_array);
|
||||
internal::TileUsingEigen<Device, T, Tmultiples, 7>(d, out, in,
|
||||
broadcast_array);
|
||||
break;
|
||||
default:
|
||||
internal::TileSimple<Device, T>(d, out, in);
|
||||
|
@ -15,10 +15,10 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/kernels/tile_functor.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/kernels/tile_functor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -51,7 +51,9 @@ namespace functor {
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
// Register functors used for Tile functor.
|
||||
#define DEFINE_TYPE(T) template struct Tile<CPUDevice, T>;
|
||||
#define DEFINE_TYPE(T) \
|
||||
template struct Tile<CPUDevice, T, int32>; \
|
||||
template struct Tile<CPUDevice, T, int64>;
|
||||
|
||||
TF_CALL_bool(DEFINE_TYPE);
|
||||
TF_CALL_float(DEFINE_TYPE);
|
||||
@ -70,7 +72,9 @@ TF_CALL_string(DEFINE_TYPE);
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
typedef Eigen::SyclDevice SYCLDevice;
|
||||
|
||||
#define DEFINE_TYPE(T) template struct Tile<SYCLDevice, T>;
|
||||
#define DEFINE_TYPE(T) \
|
||||
template struct Tile<SYCLDevice, T, int32>; \
|
||||
template struct Tile<SYCLDevice, T, int64>;
|
||||
|
||||
TF_CALL_bool(DEFINE_TYPE);
|
||||
TF_CALL_float(DEFINE_TYPE);
|
||||
@ -81,7 +85,7 @@ TF_CALL_int16(DEFINE_TYPE);
|
||||
TF_CALL_int64(DEFINE_TYPE);
|
||||
|
||||
#undef DEFINE_TYPE
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
} // end namespace functor
|
||||
} // end namespace tensorflow
|
||||
|
@ -18,10 +18,11 @@ limitations under the License.
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/kernels/tile_functor.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/kernels/tile_functor.h"
|
||||
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
@ -60,7 +61,8 @@ void TileSimple(const Device& d, Tensor* out, const Tensor& in) {
|
||||
host_buf[ndims + i] = out_strides[i];
|
||||
host_buf[ndims * 2 + i] = in.dim_size(i);
|
||||
}
|
||||
// Copies the input strides, output strides and input dimension sizes to the device.
|
||||
// Copies the input strides, output strides and input dimension sizes to the
|
||||
// device.
|
||||
auto num_bytes = sizeof(int64) * host_buf.size();
|
||||
auto dev_buf = d.allocate(num_bytes);
|
||||
// NOTE: host_buf is not allocated by CudaHostAllocator, and
|
||||
@ -84,7 +86,9 @@ namespace functor {
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
// Register functors used for Tile functor.
|
||||
#define DEFINE_TYPE(T) template struct Tile<GPUDevice, T>;
|
||||
#define DEFINE_TYPE(T) \
|
||||
template struct Tile<GPUDevice, T, int32>; \
|
||||
template struct Tile<GPUDevice, T, int64>;
|
||||
|
||||
TF_CALL_int16(DEFINE_TYPE);
|
||||
TF_CALL_int32(DEFINE_TYPE);
|
||||
|
@ -42,14 +42,14 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
typedef Eigen::SyclDevice SYCLDevice;
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
// Forward declarations of functors that will be defined in tile_ops_impl.h
|
||||
namespace functor {
|
||||
template <typename Device, typename T>
|
||||
template <typename Device, typename T, typename Tmultiple>
|
||||
struct Tile {
|
||||
void operator()(const Device& d, Tensor* out, const Tensor& in,
|
||||
const gtl::ArraySlice<int32> broadcast_array) const;
|
||||
const gtl::ArraySlice<Tmultiple> broadcast_array) const;
|
||||
};
|
||||
|
||||
template <typename Device, typename T, int NDIM>
|
||||
@ -80,7 +80,7 @@ struct ReduceAndReshape {
|
||||
} // namespace functor
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
template <typename Device>
|
||||
template <typename Device, typename Tmultiples>
|
||||
class TileOp : public OpKernel {
|
||||
public:
|
||||
explicit TileOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
@ -105,8 +105,8 @@ class TileOp : public OpKernel {
|
||||
return;
|
||||
}
|
||||
|
||||
const gtl::ArraySlice<int32> multiples_array(multiples.flat<int32>().data(),
|
||||
input_dims);
|
||||
const gtl::ArraySlice<Tmultiples> multiples_array(
|
||||
multiples.flat<Tmultiples>().data(), input_dims);
|
||||
TensorShape output_shape;
|
||||
for (int i = 0; i < input_dims; ++i) {
|
||||
OP_REQUIRES(
|
||||
@ -125,10 +125,10 @@ class TileOp : public OpKernel {
|
||||
// If there's no output, there's nothing to do.
|
||||
if (output_shape.num_elements() == 0) return;
|
||||
|
||||
#define HANDLE_TYPE(DT) \
|
||||
if (context->input(0).dtype() == DT) { \
|
||||
HandleCase<DT>(context, multiples_array, result); \
|
||||
return; \
|
||||
#define HANDLE_TYPE(DT) \
|
||||
if (context->input(0).dtype() == DT) { \
|
||||
HandleCase<DT>(context, multiples_array, result); \
|
||||
return; \
|
||||
}
|
||||
|
||||
#define HANDLE_TYPE_NAME(T) HANDLE_TYPE(DataTypeToEnum<T>::value)
|
||||
@ -158,27 +158,27 @@ class TileOp : public OpKernel {
|
||||
private:
|
||||
template <DataType DT>
|
||||
void HandleCaseImpl(OpKernelContext* context,
|
||||
const gtl::ArraySlice<int32>& multiples_array,
|
||||
const gtl::ArraySlice<Tmultiples>& multiples_array,
|
||||
Tensor* result) {
|
||||
typedef typename EnumToDataType<DT>::Type T;
|
||||
functor::Tile<Device, T>() (
|
||||
context->eigen_device<Device>(), result,
|
||||
context->input(0), multiples_array);
|
||||
functor::Tile<Device, T, Tmultiples>()(context->eigen_device<Device>(),
|
||||
result, context->input(0),
|
||||
multiples_array);
|
||||
}
|
||||
|
||||
template <DataType DT>
|
||||
void HandleCase(OpKernelContext* context,
|
||||
const gtl::ArraySlice<int32>& multiples_array,
|
||||
const gtl::ArraySlice<Tmultiples>& multiples_array,
|
||||
Tensor* result);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TileOp);
|
||||
};
|
||||
|
||||
template <typename Device>
|
||||
template <typename Device, typename Tmultiples>
|
||||
template <DataType DT>
|
||||
inline void TileOp<Device>::HandleCase(
|
||||
OpKernelContext* context, const gtl::ArraySlice<int32>& multiples_array,
|
||||
Tensor* result) {
|
||||
inline void TileOp<Device, Tmultiples>::HandleCase(
|
||||
OpKernelContext* context,
|
||||
const gtl::ArraySlice<Tmultiples>& multiples_array, Tensor* result) {
|
||||
// TODO(vrv): print out the device name if useful. Currently disabled to avoid
|
||||
// having to use RTTI.
|
||||
LOG(FATAL) << "TileOp: Invalid combination of Device, DT: "
|
||||
@ -186,25 +186,28 @@ inline void TileOp<Device>::HandleCase(
|
||||
<< DataTypeString(DT);
|
||||
}
|
||||
|
||||
#define HANDLE_CASE(device, dtype) \
|
||||
template <> \
|
||||
template <> \
|
||||
void TileOp<device>::HandleCase<dtype>( \
|
||||
OpKernelContext * context, \
|
||||
const gtl::ArraySlice<int32>& multiples_array, Tensor* result) { \
|
||||
HandleCaseImpl<dtype>(context, multiples_array, result); \
|
||||
#define HANDLE_CASE(device, dtype, Tmultiples) \
|
||||
template <> \
|
||||
template <> \
|
||||
void TileOp<device, Tmultiples>::HandleCase<dtype>( \
|
||||
OpKernelContext * context, \
|
||||
const gtl::ArraySlice<Tmultiples>& multiples_array, Tensor* result) { \
|
||||
HandleCaseImpl<dtype>(context, multiples_array, result); \
|
||||
}
|
||||
|
||||
#define HANDLE_TYPE_NAME_CPU(T) \
|
||||
HANDLE_CASE(CPUDevice, DataTypeToEnum<T>::value);
|
||||
#define HANDLE_TYPE_NAME_CPU(T) \
|
||||
HANDLE_CASE(CPUDevice, DataTypeToEnum<T>::value, int32); \
|
||||
HANDLE_CASE(CPUDevice, DataTypeToEnum<T>::value, int64);
|
||||
|
||||
#define HANDLE_TYPE_NAME_GPU(T) \
|
||||
HANDLE_CASE(GPUDevice, DataTypeToEnum<T>::value);
|
||||
#define HANDLE_TYPE_NAME_GPU(T) \
|
||||
HANDLE_CASE(GPUDevice, DataTypeToEnum<T>::value, int32); \
|
||||
HANDLE_CASE(GPUDevice, DataTypeToEnum<T>::value, int64);
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
#define HANDLE_TYPE_NAME_SYCL(T) \
|
||||
HANDLE_CASE(SYCLDevice, DataTypeToEnum<T>::value);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
#define HANDLE_TYPE_NAME_SYCL(T) \
|
||||
HANDLE_CASE(SYCLDevice, DataTypeToEnum<T>::value, int32); \
|
||||
HANDLE_CASE(SYCLDevice, DataTypeToEnum<T>::value, int64);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
TF_CALL_bool(HANDLE_TYPE_NAME_CPU);
|
||||
TF_CALL_float(HANDLE_TYPE_NAME_CPU);
|
||||
@ -235,13 +238,13 @@ TF_CALL_double(HANDLE_TYPE_NAME_SYCL);
|
||||
TF_CALL_int16(HANDLE_TYPE_NAME_SYCL);
|
||||
TF_CALL_int32(HANDLE_TYPE_NAME_SYCL);
|
||||
TF_CALL_int64(HANDLE_TYPE_NAME_SYCL);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#undef HANDLE_TYPE_NAME_CPU
|
||||
#undef HANDLE_TYPE_NAME_GPU
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
#undef HANDLE_TYPE_NAME_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
#undef HANDLE_CASE
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@ -494,7 +497,7 @@ TF_CALL_int16(HANDLE_TYPE_NAME_SYCL);
|
||||
TF_CALL_int32(HANDLE_TYPE_NAME_SYCL);
|
||||
TF_CALL_int64(HANDLE_TYPE_NAME_SYCL);
|
||||
#undef HANDLE_TYPE_NAME_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#undef HANDLE_TYPE_NAME_CPU
|
||||
#undef HANDLE_TYPE_NAME_GPU
|
||||
@ -505,127 +508,73 @@ REGISTER_KERNEL_BUILDER(Name("Tile")
|
||||
.Device(DEVICE_CPU)
|
||||
.HostMemory("multiples")
|
||||
.TypeConstraint<int32>("Tmultiples"),
|
||||
TileOp<CPUDevice>);
|
||||
TileOp<CPUDevice, int32>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile")
|
||||
.Device(DEVICE_CPU)
|
||||
.HostMemory("multiples")
|
||||
.TypeConstraint<int64>("Tmultiples"),
|
||||
TileOp<CPUDevice, int64>);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("TileGrad").Device(DEVICE_CPU).HostMemory("multiples"),
|
||||
TileGradientOp<CPUDevice>);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define REGISTER_GPU(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tmultiples") \
|
||||
.HostMemory("multiples"), \
|
||||
TileOp<GPUDevice, int32>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int64>("Tmultiples") \
|
||||
.HostMemory("multiples"), \
|
||||
TileOp<GPUDevice, int64>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("TileGrad") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tmultiples") \
|
||||
.HostMemory("multiples"), \
|
||||
TileGradientOp<GPUDevice>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<float>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<double>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<Eigen::half>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<int16>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<int32>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<complex64>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<complex128>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileOp<GPUDevice>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TileGrad")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<float>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileGradientOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("TileGrad")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<double>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileGradientOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("TileGrad")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<Eigen::half>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileGradientOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("TileGrad")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<int16>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileGradientOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("TileGrad")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<int32>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileGradientOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("TileGrad")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<complex64>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileGradientOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("TileGrad")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<complex128>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileGradientOp<GPUDevice>);
|
||||
TF_CALL_float(REGISTER_GPU);
|
||||
TF_CALL_double(REGISTER_GPU);
|
||||
TF_CALL_half(REGISTER_GPU);
|
||||
TF_CALL_int16(REGISTER_GPU);
|
||||
TF_CALL_int32(REGISTER_GPU);
|
||||
TF_CALL_complex64(REGISTER_GPU);
|
||||
TF_CALL_complex128(REGISTER_GPU)
|
||||
|
||||
#undef REGISTER_GPU
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile")
|
||||
.Device(DEVICE_SYCL)
|
||||
.TypeConstraint<float>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileOp<SYCLDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile")
|
||||
.Device(DEVICE_SYCL)
|
||||
.TypeConstraint<double>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileOp<SYCLDevice>);
|
||||
#define REGISTER_SYCL(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tmultiples") \
|
||||
.HostMemory("multiples"), \
|
||||
TileOp<SYCLDevice, int32>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int64>("Tmultiples") \
|
||||
.HostMemory("multiples"), \
|
||||
TileOp<SYCLDevice, int64>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("TileGrad") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tmultiples") \
|
||||
.HostMemory("multiples"), \
|
||||
TileGradientOp<SYCLDevice>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TileGrad")
|
||||
.Device(DEVICE_SYCL)
|
||||
.TypeConstraint<float>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileGradientOp<SYCLDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("TileGrad")
|
||||
.Device(DEVICE_SYCL)
|
||||
.TypeConstraint<double>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileGradientOp<SYCLDevice>);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
TF_CALL_float(REGISTER_SYCL);
|
||||
TF_CALL_double(REGISTER_SYCL);
|
||||
|
||||
#undef REGISTER_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -411,14 +411,16 @@ class TileTest(test.TestCase):
|
||||
self.assertEqual(7, result)
|
||||
|
||||
def testSimple(self):
|
||||
with self.test_session():
|
||||
inp = np.random.rand(4, 1).astype(np.float32)
|
||||
a = constant_op.constant(inp)
|
||||
tiled = array_ops.tile(a, [1, 4])
|
||||
result = tiled.eval()
|
||||
self.assertEqual(result.shape, (4, 4))
|
||||
self.assertEqual([4, 4], tiled.get_shape())
|
||||
self.assertTrue((result == np.tile(inp, (1, 4))).all())
|
||||
# multiples could be int32 or int64
|
||||
for dtype in [dtypes.int32, dtypes.int64]:
|
||||
with self.test_session(use_gpu=True):
|
||||
inp = np.random.rand(4, 1).astype(np.float32)
|
||||
a = constant_op.constant(inp)
|
||||
tiled = array_ops.tile(a, constant_op.constant([1, 4], dtype=dtype))
|
||||
result = tiled.eval()
|
||||
self.assertEqual(result.shape, (4, 4))
|
||||
self.assertEqual([4, 4], tiled.get_shape())
|
||||
self.assertTrue((result == np.tile(inp, (1, 4))).all())
|
||||
|
||||
def testIdentityTileAndGrad(self):
|
||||
with self.test_session():
|
||||
|
Loading…
Reference in New Issue
Block a user