[OpenCL] Fixes Split op (#10322)
* [OpenCL] Fixes Split op Split should alway go through SYCL device * [OpenCL] Removes half from registred types
This commit is contained in:
parent
9634414007
commit
95d90ab2e0
@ -50,16 +50,12 @@ void Split<Eigen::SyclDevice, T>::operator()(
|
||||
typename TTypes<T, 3>::ConstTensor input,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
|
||||
if (output.size() < 131072) {
|
||||
output = input.slice(slice_indices, slice_sizes);
|
||||
} else {
|
||||
output.device(d) = input.slice(slice_indices, slice_sizes);
|
||||
}
|
||||
}
|
||||
|
||||
#define DEFINE_SYCL_KERNELS(T) template struct Split<Eigen::SyclDevice, T>;
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DEFINE_SYCL_KERNELS)
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_SYCL_KERNELS);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
} // namespace functor
|
||||
|
@ -247,7 +247,6 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
|
||||
template <typename T>
|
||||
class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
|
||||
public:
|
||||
@ -312,7 +311,6 @@ class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#define REGISTER_SPLIT(type) \
|
||||
@ -351,7 +349,7 @@ TF_CALL_complex128(REGISTER_GPU);
|
||||
.HostMemory("split_dim"), \
|
||||
SplitOpSYCL<type>)
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL);
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);
|
||||
#undef REGISTER_SYCL
|
||||
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
Loading…
Reference in New Issue
Block a user