[OpenCL] Fixes Split op (#10322)
* [OpenCL] Fixes Split op: Split should always go through the SYCL device * [OpenCL] Removes half from registered types
This commit is contained in:
parent
9634414007
commit
95d90ab2e0
@ -50,16 +50,12 @@ void Split<Eigen::SyclDevice, T>::operator()(
|
|||||||
typename TTypes<T, 3>::ConstTensor input,
|
typename TTypes<T, 3>::ConstTensor input,
|
||||||
const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
|
const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
|
||||||
const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
|
const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
|
||||||
if (output.size() < 131072) {
|
|
||||||
output = input.slice(slice_indices, slice_sizes);
|
|
||||||
} else {
|
|
||||||
output.device(d) = input.slice(slice_indices, slice_sizes);
|
output.device(d) = input.slice(slice_indices, slice_sizes);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DEFINE_SYCL_KERNELS(T) template struct Split<Eigen::SyclDevice, T>;
|
#define DEFINE_SYCL_KERNELS(T) template struct Split<Eigen::SyclDevice, T>;
|
||||||
|
|
||||||
TF_CALL_GPU_NUMBER_TYPES(DEFINE_SYCL_KERNELS)
|
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_SYCL_KERNELS);
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
|
@ -247,7 +247,6 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
|
|||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
#ifdef TENSORFLOW_USE_SYCL
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
|
class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
|
||||||
public:
|
public:
|
||||||
@ -312,8 +311,7 @@ class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
|
||||||
|
|
||||||
#define REGISTER_SPLIT(type) \
|
#define REGISTER_SPLIT(type) \
|
||||||
REGISTER_KERNEL_BUILDER(Name("Split") \
|
REGISTER_KERNEL_BUILDER(Name("Split") \
|
||||||
@ -351,7 +349,7 @@ TF_CALL_complex128(REGISTER_GPU);
|
|||||||
.HostMemory("split_dim"), \
|
.HostMemory("split_dim"), \
|
||||||
SplitOpSYCL<type>)
|
SplitOpSYCL<type>)
|
||||||
|
|
||||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL);
|
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);
|
||||||
#undef REGISTER_SYCL
|
#undef REGISTER_SYCL
|
||||||
|
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
Loading…
Reference in New Issue
Block a user