Add int64 axis support for reduction ops. (#13891)

* Add int64 axis support for reduction ops.

This fix is a follow up to PR 13863. In PR 13863 the
program crash is fixed if int64 axis is passed to reduction ops,
e.g. reduce_sum, reduce_max, etc. However, 13863 does not
process the case of int64 support, it merely fixes the crash.

This fix adds the support for int64 axis of reduction ops.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add int64 axis support for mean, prod, sum

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add int64 axis support for min and max.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add int64 axis support for reduce_all and reduce_any

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add test cases for int64 axis support of reduce_any and reduce_all

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2017-10-21 23:01:42 -07:00 committed by Vijay Vasudevan
parent 17096081ee
commit 1c1dad105a
10 changed files with 391 additions and 148 deletions

View File

@ -22,7 +22,13 @@ REGISTER_KERNEL_BUILDER(
.TypeConstraint<int32>("Tidx")
.Device(DEVICE_CPU)
.HostMemory("reduction_indices"),
ReductionOp<CPUDevice, bool, Eigen::internal::AndReducer>);
ReductionOp<CPUDevice, bool, int32, Eigen::internal::AndReducer>);
REGISTER_KERNEL_BUILDER(
Name("All")
.TypeConstraint<int64>("Tidx")
.Device(DEVICE_CPU)
.HostMemory("reduction_indices"),
ReductionOp<CPUDevice, bool, int64, Eigen::internal::AndReducer>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
@ -30,7 +36,13 @@ REGISTER_KERNEL_BUILDER(
.TypeConstraint<int32>("Tidx")
.Device(DEVICE_GPU)
.HostMemory("reduction_indices"),
ReductionOp<GPUDevice, bool, Eigen::internal::AndReducer>);
ReductionOp<GPUDevice, bool, int32, Eigen::internal::AndReducer>);
REGISTER_KERNEL_BUILDER(
Name("All")
.TypeConstraint<int64>("Tidx")
.Device(DEVICE_GPU)
.HostMemory("reduction_indices"),
ReductionOp<GPUDevice, bool, int64, Eigen::internal::AndReducer>);
#endif
} // namespace tensorflow

View File

@ -22,7 +22,13 @@ REGISTER_KERNEL_BUILDER(
.TypeConstraint<int32>("Tidx")
.Device(DEVICE_CPU)
.HostMemory("reduction_indices"),
ReductionOp<CPUDevice, bool, Eigen::internal::OrReducer>);
ReductionOp<CPUDevice, bool, int32, Eigen::internal::OrReducer>);
REGISTER_KERNEL_BUILDER(
Name("Any")
.TypeConstraint<int64>("Tidx")
.Device(DEVICE_CPU)
.HostMemory("reduction_indices"),
ReductionOp<CPUDevice, bool, int64, Eigen::internal::OrReducer>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
@ -30,7 +36,13 @@ REGISTER_KERNEL_BUILDER(
.TypeConstraint<int32>("Tidx")
.Device(DEVICE_GPU)
.HostMemory("reduction_indices"),
ReductionOp<GPUDevice, bool, Eigen::internal::OrReducer>);
ReductionOp<GPUDevice, bool, int32, Eigen::internal::OrReducer>);
REGISTER_KERNEL_BUILDER(
Name("Any")
.TypeConstraint<int64>("Tidx")
.Device(DEVICE_GPU)
.HostMemory("reduction_indices"),
ReductionOp<GPUDevice, bool, int64, Eigen::internal::OrReducer>);
#endif
} // namespace tensorflow

View File

@ -57,13 +57,12 @@ gtl::InlinedVector<int32, 8> ReductionHelper::permutation() {
return perm;
}
Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
const bool keep_dims) {
// bitmap[i] indicates whether to reduce data along i-th axis.
gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
auto axis_vec = axis.flat<int32>();
template <typename Tperm>
Status SimplifyHelper(const Tensor& data, const Tensor& axis,
gtl::InlinedVector<bool, 4>& bitmap) {
auto axis_vec = axis.flat<Tperm>();
for (int64 i = 0; i < axis.NumElements(); ++i) {
int32 index = axis_vec(i);
Tperm index = axis_vec(i);
if (index < -data.dims() || index >= data.dims()) {
return errors::InvalidArgument("Invalid reduction dimension (", index,
" for input with ", data.dims(),
@ -72,7 +71,18 @@ Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
index = (index + data.dims()) % data.dims();
bitmap[index] = true;
}
return Status::OK();
}
Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
const bool keep_dims) {
// bitmap[i] indicates whether to reduce data along i-th axis.
gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
if (axis.dtype() == DT_INT32) {
TF_RETURN_IF_ERROR(SimplifyHelper<int32>(data, axis, bitmap));
} else {
TF_RETURN_IF_ERROR(SimplifyHelper<int64>(data, axis, bitmap));
}
// Output tensor's dim sizes.
out_shape_.clear();
for (int i = 0; i < data.dims(); ++i) {

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@ -42,7 +43,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
#endif // TENSORFLOW_USE_SYCL
template <typename Device>
struct Constants {
@ -68,11 +69,13 @@ struct ConstantsBase {
const Eigen::IndexList<Eigen::type2index<1>> kOne;
const Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<2>> kZeroTwo;
};
template<> struct Constants<CPUDevice> : ConstantsBase{};
template <>
struct Constants<CPUDevice> : ConstantsBase {};
#ifdef TENSORFLOW_USE_SYCL
template<> struct Constants<SYCLDevice> : ConstantsBase{};
#endif // TENSORFLOW_USE_SYCL
#endif // EIGEN_HAS_INDEX_LIST
template <>
struct Constants<SYCLDevice> : ConstantsBase {};
#endif // TENSORFLOW_USE_SYCL
#endif // EIGEN_HAS_INDEX_LIST
class ReductionHelper {
public:
@ -131,12 +134,13 @@ class ReductionHelper {
// For operations where the output is a reduction function along some
// dimensions of the input.
template <typename Device, class T, typename Reducer>
template <typename Device, class T, typename Tperm, typename Reducer>
class ReductionOp : public OpKernel {
public:
explicit ReductionOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
const DataType dt = DataTypeToEnum<T>::v();
OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt}));
const DataType pt = DataTypeToEnum<Tperm>::v();
OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, pt}, {dt}));
OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
}
@ -266,20 +270,19 @@ struct ReduceFunctorBase {
}
template <typename OUT_T>
static void FillIdentity(const Device& d, OUT_T out,
const Reducer& reducer) {
static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer) {
FillIdentityEigenImpl(d, out, reducer);
}
};
template <typename Reducer>
struct ReduceFunctor<CPUDevice, Reducer>
: ReduceFunctorBase<CPUDevice, Reducer>{};
: ReduceFunctorBase<CPUDevice, Reducer> {};
#if TENSORFLOW_USE_SYCL
template <typename Reducer>
struct ReduceFunctor<SYCLDevice, Reducer>
: ReduceFunctorBase<SYCLDevice, Reducer>{};
#endif // TENSORFLOW_USE_SYCL
: ReduceFunctorBase<SYCLDevice, Reducer> {};
#endif // TENSORFLOW_USE_SYCL
} // namespace functor
} // namespace tensorflow

View File

@ -17,26 +17,39 @@ limitations under the License.
namespace tensorflow {
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Max") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ReductionOp<CPUDevice, type, Eigen::internal::MaxReducer<type>>);
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Max") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ReductionOp<CPUDevice, type, int32, Eigen::internal::MaxReducer<type>>); \
REGISTER_KERNEL_BUILDER( \
Name("Max") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx"), \
ReductionOp<CPUDevice, type, int64, Eigen::internal::MaxReducer<type>>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Max") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, Eigen::internal::MaxReducer<type>>);
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Max") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int32, Eigen::internal::MaxReducer<type>>); \
REGISTER_KERNEL_BUILDER( \
Name("Max") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, Eigen::internal::MaxReducer<type>>);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
REGISTER_GPU_KERNELS(int64);
@ -52,21 +65,37 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("output")
.TypeConstraint<int32>("T")
.TypeConstraint<int32>("Tidx"),
ReductionOp<CPUDevice, int32, Eigen::internal::MaxReducer<int32>>);
ReductionOp<CPUDevice, int32, int32, Eigen::internal::MaxReducer<int32>>);
REGISTER_KERNEL_BUILDER(
Name("Max")
.Device(DEVICE_GPU)
.HostMemory("reduction_indices")
.HostMemory("input")
.HostMemory("output")
.TypeConstraint<int32>("T")
.TypeConstraint<int64>("Tidx"),
ReductionOp<CPUDevice, int32, int64, Eigen::internal::MaxReducer<int32>>);
#undef REGISTER_GPU_KERNELS
#endif
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Max") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, Eigen::internal::MaxReducer<type>>);
#define REGISTER_SYCL_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("Max") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, int32, \
Eigen::internal::MaxReducer<type>>); \
REGISTER_KERNEL_BUILDER(Name("Max") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, int64, \
Eigen::internal::MaxReducer<type>>);
REGISTER_SYCL_KERNELS(float);
REGISTER_SYCL_KERNELS(double);
@ -78,8 +107,17 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("output")
.TypeConstraint<int32>("T")
.TypeConstraint<int32>("Tidx"),
ReductionOp<CPUDevice, int32, Eigen::internal::MaxReducer<int32>>);
ReductionOp<CPUDevice, int32, int32, Eigen::internal::MaxReducer<int32>>);
REGISTER_KERNEL_BUILDER(
Name("Max")
.Device(DEVICE_SYCL)
.HostMemory("reduction_indices")
.HostMemory("input")
.HostMemory("output")
.TypeConstraint<int32>("T")
.TypeConstraint<int64>("Tidx"),
ReductionOp<CPUDevice, int32, int64, Eigen::internal::MaxReducer<int32>>);
#undef REGISTER_SYCL_KERNELS
#endif // TENSORFLOW_USE_SYCL
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow

View File

@ -17,26 +17,39 @@ limitations under the License.
namespace tensorflow {
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Mean") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ReductionOp<CPUDevice, type, Eigen::internal::MeanReducer<type>>);
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("Mean") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ReductionOp<CPUDevice, type, int32, \
Eigen::internal::MeanReducer<type>>); \
REGISTER_KERNEL_BUILDER(Name("Mean") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx"), \
ReductionOp<CPUDevice, type, int64, \
Eigen::internal::MeanReducer<type>>);
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Mean") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, Eigen::internal::MeanReducer<type>>);
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("Mean") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int32, \
Eigen::internal::MeanReducer<type>>); \
REGISTER_KERNEL_BUILDER(Name("Mean") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, \
Eigen::internal::MeanReducer<type>>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_complex64(REGISTER_GPU_KERNELS);
TF_CALL_complex128(REGISTER_GPU_KERNELS);
@ -45,17 +58,24 @@ TF_CALL_complex128(REGISTER_GPU_KERNELS);
#endif
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Mean") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, Eigen::internal::MeanReducer<type>>);
#define REGISTER_SYCL_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("Mean") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, int32, \
Eigen::internal::MeanReducer<type>>); \
REGISTER_KERNEL_BUILDER(Name("Mean") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, int64, \
Eigen::internal::MeanReducer<type>>);
REGISTER_SYCL_KERNELS(float);
REGISTER_SYCL_KERNELS(double);
#undef REGISTER_SYCL_KERNELS
#endif // TENSORFLOW_USE_SYCL
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow

View File

@ -17,26 +17,39 @@ limitations under the License.
namespace tensorflow {
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Min") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ReductionOp<CPUDevice, type, Eigen::internal::MinReducer<type>>);
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Min") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ReductionOp<CPUDevice, type, int32, Eigen::internal::MinReducer<type>>); \
REGISTER_KERNEL_BUILDER( \
Name("Min") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx"), \
ReductionOp<CPUDevice, type, int64, Eigen::internal::MinReducer<type>>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Min") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, Eigen::internal::MinReducer<type>>);
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Min") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int32, Eigen::internal::MinReducer<type>>); \
REGISTER_KERNEL_BUILDER( \
Name("Min") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, Eigen::internal::MinReducer<type>>);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
@ -51,21 +64,37 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("output")
.TypeConstraint<int32>("T")
.TypeConstraint<int32>("Tidx"),
ReductionOp<CPUDevice, int32, Eigen::internal::MinReducer<int32>>);
ReductionOp<CPUDevice, int32, int32, Eigen::internal::MinReducer<int32>>);
REGISTER_KERNEL_BUILDER(
Name("Min")
.Device(DEVICE_GPU)
.HostMemory("reduction_indices")
.HostMemory("input")
.HostMemory("output")
.TypeConstraint<int32>("T")
.TypeConstraint<int64>("Tidx"),
ReductionOp<CPUDevice, int32, int64, Eigen::internal::MinReducer<int32>>);
#undef REGISTER_GPU_KERNELS
#endif
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Min") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, Eigen::internal::MinReducer<type>>);
#define REGISTER_SYCL_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("Min") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, int32, \
Eigen::internal::MinReducer<type>>); \
REGISTER_KERNEL_BUILDER(Name("Min") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, int64, \
Eigen::internal::MinReducer<type>>);
REGISTER_SYCL_KERNELS(float);
REGISTER_SYCL_KERNELS(double);
@ -77,8 +106,17 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("output")
.TypeConstraint<int32>("T")
.TypeConstraint<int32>("Tidx"),
ReductionOp<CPUDevice, int32, Eigen::internal::MinReducer<int32>>);
ReductionOp<CPUDevice, int32, int32, Eigen::internal::MinReducer<int32>>);
REGISTER_KERNEL_BUILDER(
Name("Min")
.Device(DEVICE_SYCL)
.HostMemory("reduction_indices")
.HostMemory("input")
.HostMemory("output")
.TypeConstraint<int32>("T")
.TypeConstraint<int64>("Tidx"),
ReductionOp<CPUDevice, int32, int64, Eigen::internal::MinReducer<int32>>);
#undef REGISTER_SYCL_KERNELS
#endif // TENSORFLOW_USE_SYCL
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow

View File

@ -17,26 +17,39 @@ limitations under the License.
namespace tensorflow {
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Prod") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ReductionOp<CPUDevice, type, Eigen::internal::ProdReducer<type>>);
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("Prod") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ReductionOp<CPUDevice, type, int32, \
Eigen::internal::ProdReducer<type>>); \
REGISTER_KERNEL_BUILDER(Name("Prod") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx"), \
ReductionOp<CPUDevice, type, int64, \
Eigen::internal::ProdReducer<type>>);
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Prod") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, Eigen::internal::ProdReducer<type>>);
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("Prod") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int32, \
Eigen::internal::ProdReducer<type>>); \
REGISTER_KERNEL_BUILDER(Name("Prod") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, \
Eigen::internal::ProdReducer<type>>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int32(REGISTER_GPU_KERNELS);
TF_CALL_complex64(REGISTER_GPU_KERNELS);
@ -46,18 +59,25 @@ TF_CALL_complex128(REGISTER_GPU_KERNELS);
#endif
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Prod") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, Eigen::internal::ProdReducer<type>>);
#define REGISTER_SYCL_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("Prod") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, int32, \
Eigen::internal::ProdReducer<type>>); \
REGISTER_KERNEL_BUILDER(Name("Prod") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, int64, \
Eigen::internal::ProdReducer<type>>);
REGISTER_SYCL_KERNELS(int32);
REGISTER_SYCL_KERNELS(float);
REGISTER_SYCL_KERNELS(double);
#undef REGISTER_SYCL_KERNELS
#endif // TENSORFLOW_USE_SYCL
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow

View File

@ -17,26 +17,39 @@ limitations under the License.
namespace tensorflow {
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Sum") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ReductionOp<CPUDevice, type, Eigen::internal::SumReducer<type>>);
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Sum") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ReductionOp<CPUDevice, type, int32, Eigen::internal::SumReducer<type>>); \
REGISTER_KERNEL_BUILDER( \
Name("Sum") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx"), \
ReductionOp<CPUDevice, type, int64, Eigen::internal::SumReducer<type>>);
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Sum") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, Eigen::internal::SumReducer<type>>);
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Sum") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int32, Eigen::internal::SumReducer<type>>); \
REGISTER_KERNEL_BUILDER( \
Name("Sum") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, Eigen::internal::SumReducer<type>>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_complex64(REGISTER_GPU_KERNELS);
TF_CALL_complex128(REGISTER_GPU_KERNELS);
@ -53,19 +66,35 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("input")
.HostMemory("output")
.HostMemory("reduction_indices"),
ReductionOp<CPUDevice, int32, Eigen::internal::SumReducer<int32>>);
ReductionOp<CPUDevice, int32, int32, Eigen::internal::SumReducer<int32>>);
REGISTER_KERNEL_BUILDER(
Name("Sum")
.Device(DEVICE_GPU)
.TypeConstraint<int32>("T")
.TypeConstraint<int64>("Tidx")
.HostMemory("input")
.HostMemory("output")
.HostMemory("reduction_indices"),
ReductionOp<CPUDevice, int32, int64, Eigen::internal::SumReducer<int32>>);
#endif
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Sum") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, Eigen::internal::SumReducer<type>>);
#define REGISTER_SYCL_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("Sum") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, int32, \
Eigen::internal::SumReducer<type>>); \
REGISTER_KERNEL_BUILDER(Name("Sum") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<SYCLDevice, type, int64, \
Eigen::internal::SumReducer<type>>);
REGISTER_SYCL_KERNELS(float);
REGISTER_SYCL_KERNELS(double);
@ -77,8 +106,17 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("input")
.HostMemory("output")
.HostMemory("reduction_indices"),
ReductionOp<CPUDevice, int32, Eigen::internal::SumReducer<int32>>);
ReductionOp<CPUDevice, int32, int32, Eigen::internal::SumReducer<int32>>);
REGISTER_KERNEL_BUILDER(
Name("Sum")
.Device(DEVICE_SYCL)
.TypeConstraint<int32>("T")
.TypeConstraint<int64>("Tidx")
.HostMemory("input")
.HostMemory("output")
.HostMemory("reduction_indices"),
ReductionOp<CPUDevice, int32, int64, Eigen::internal::SumReducer<int32>>);
#undef REGISTER_SYCL_KERNELS
#endif // TENSORFLOW_USE_SYCL
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow

View File

@ -163,6 +163,13 @@ class SumReductionTest(BaseReductionTest):
reduction_axes = tuple(reduction_axes)
return np.sum(x, axis=reduction_axes, keepdims=keep_dims)
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
with self.test_session(use_gpu=True) as sess:
v = math_ops.reduce_sum([0, 0], constant_op.constant(0, dtype=dtype))
tf_v = sess.run(v)
self.assertAllEqual(tf_v, 0)
def testInfinity(self):
for dtype in [np.float32, np.float64]:
for special_value_x in [-np.inf, np.inf]:
@ -193,6 +200,7 @@ class SumReductionTest(BaseReductionTest):
tf_out_mean = sess.run(tf_mean)
self.assertAllClose(tf_out_mean, 1.)
def testFloat32(self):
for rank in range(1, _MAX_RANK + 1):
np_arr = self._makeIncremental((2,) * rank, dtypes.float32)
@ -369,6 +377,13 @@ class MeanReductionTest(BaseReductionTest):
return np_sum // count
return np_sum / count
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
with self.test_session(use_gpu=True) as sess:
v = math_ops.reduce_mean([0, 0], constant_op.constant(0, dtype=dtype))
tf_v = sess.run(v)
self.assertAllEqual(tf_v, 0)
def testInfinity(self):
for dtype in [np.float32, np.float64]:
for special_value_x in [-np.inf, np.inf]:
@ -435,6 +450,13 @@ class ProdReductionTest(BaseReductionTest):
reduction_axes = tuple(reduction_axes)
return np.prod(x, axis=reduction_axes, keepdims=keep_dims)
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
with self.test_session(use_gpu=True) as sess:
v = math_ops.reduce_prod([0, 0], constant_op.constant(0, dtype=dtype))
tf_v = sess.run(v)
self.assertAllEqual(tf_v, 0)
def testInfinity(self):
for dtype in [np.float32, np.float64]:
for special_value_x in [-np.inf, np.inf]:
@ -531,6 +553,13 @@ class MinReductionTest(test.TestCase):
self._compare(x, reduction_axes, True, use_gpu=True)
self._compare(x, reduction_axes, True, use_gpu=False)
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
with self.test_session(use_gpu=True) as sess:
v = math_ops.reduce_min([0, 0], constant_op.constant(0, dtype=dtype))
tf_v = sess.run(v)
self.assertAllEqual(tf_v, 0)
def testInfinity(self):
for dtype in [np.float32, np.float64]:
for special_value_x in [-np.inf, np.inf]:
@ -637,6 +666,13 @@ class MaxReductionTest(test.TestCase):
self._compare(x, reduction_axes, True, use_gpu=True)
self._compare(x, reduction_axes, True, use_gpu=False)
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
with self.test_session(use_gpu=True) as sess:
v = math_ops.reduce_max([0, 0], constant_op.constant(0, dtype=dtype))
tf_v = sess.run(v)
self.assertAllEqual(tf_v, 0)
def testInfinity(self):
for dtype in [np.float32, np.float64]:
for special_value_x in [-np.inf, np.inf]:
@ -757,6 +793,14 @@ class AllReductionTest(test.TestCase):
self._compare(x, reduction_axes, True, use_gpu=True)
self._compare(x, reduction_axes, True, use_gpu=False)
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
with self.test_session(use_gpu=True) as sess:
v = math_ops.reduce_all([True, True],
constant_op.constant(0, dtype=dtype))
tf_v = sess.run(v)
self.assertAllEqual(tf_v, True)
def testAll3D(self):
# Create a 3D array of bools and reduce across all possible
# dimensions
@ -798,6 +842,14 @@ class AnyReductionTest(test.TestCase):
self._compare(x, reduction_axes, True, use_gpu=True)
self._compare(x, reduction_axes, True, use_gpu=False)
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
with self.test_session(use_gpu=True) as sess:
v = math_ops.reduce_any([True, True],
constant_op.constant(0, dtype=dtype))
tf_v = sess.run(v)
self.assertAllEqual(tf_v, True)
def testAll3D(self):
# Create a 3D array of bools and reduce across all possible
# dimensions