[OpenCL] Fixes Split op (#10322)

* [OpenCL] Fixes Split op

  Split should always go through the SYCL device

* [OpenCL] Removes half from registered types
This commit is contained in:
Luke Iwanski 2017-06-01 03:30:39 +01:00 committed by Benoit Steiner
parent 9634414007
commit 95d90ab2e0
2 changed files with 3 additions and 9 deletions

View File

@ -50,16 +50,12 @@ void Split<Eigen::SyclDevice, T>::operator()(
typename TTypes<T, 3>::ConstTensor input, typename TTypes<T, 3>::ConstTensor input,
const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices, const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) { const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
if (output.size() < 131072) {
output = input.slice(slice_indices, slice_sizes);
} else {
output.device(d) = input.slice(slice_indices, slice_sizes); output.device(d) = input.slice(slice_indices, slice_sizes);
}
} }
#define DEFINE_SYCL_KERNELS(T) template struct Split<Eigen::SyclDevice, T>; #define DEFINE_SYCL_KERNELS(T) template struct Split<Eigen::SyclDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_SYCL_KERNELS) TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_SYCL_KERNELS);
#endif // TENSORFLOW_USE_SYCL #endif // TENSORFLOW_USE_SYCL
} // namespace functor } // namespace functor

View File

@ -247,7 +247,6 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL
template <typename T> template <typename T>
class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> { class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
public: public:
@ -312,7 +311,6 @@ class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
} }
} }
}; };
#endif // TENSORFLOW_USE_SYCL #endif // TENSORFLOW_USE_SYCL
#define REGISTER_SPLIT(type) \ #define REGISTER_SPLIT(type) \
@ -351,7 +349,7 @@ TF_CALL_complex128(REGISTER_GPU);
.HostMemory("split_dim"), \ .HostMemory("split_dim"), \
SplitOpSYCL<type>) SplitOpSYCL<type>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);
#undef REGISTER_SYCL #undef REGISTER_SYCL
#endif // TENSORFLOW_USE_SYCL #endif // TENSORFLOW_USE_SYCL