[OpenCL] Cleans cast operation (#10330)
* [OpenCL] Removes not needed typedef for SYCLDevice * [OpenCL] Fixes formatting * [OpenCL] use SYCLDevice for int32 cast case
This commit is contained in:
parent
f8e1cf8fa5
commit
85f9681258
@ -239,12 +239,11 @@ class SyclCastOp : public CastOpBase {
|
||||
};
|
||||
|
||||
#define REGISTER_CAST_SYCL(srctype, dsttype) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Cast") \
|
||||
.TypeConstraint<srctype>("SrcT") \
|
||||
.TypeConstraint<dsttype>("DstT") \
|
||||
REGISTER_KERNEL_BUILDER(Name("Cast") \
|
||||
.TypeConstraint<srctype>("SrcT") \
|
||||
.TypeConstraint<dsttype>("DstT") \
|
||||
.Device(DEVICE_SYCL), \
|
||||
SyclCastOp)
|
||||
|
||||
CURRY_TYPES2(REGISTER_CAST_SYCL, bool);
|
||||
CURRY_TYPES2(REGISTER_CAST_SYCL, int32);
|
||||
CURRY_TYPES2(REGISTER_CAST_SYCL, int64);
|
||||
|
@ -38,7 +38,7 @@ GetGpuCastFromInt32(DataType dst_dtype) {
|
||||
typedef Eigen::SyclDevice SYCLDevice;
|
||||
std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
|
||||
GetSyclCastFromInt32(DataType dst_dtype) {
|
||||
CURRY_TYPES3(CAST_CASE, CPUDevice, int32);
|
||||
CURRY_TYPES3(CAST_CASE, SYCLDevice, int32);
|
||||
return nullptr;
|
||||
}
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
@ -19,9 +19,6 @@ namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
typedef Eigen::SyclDevice SYCLDevice;
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
|
||||
GetCpuCastFromInt64(DataType dst_dtype) {
|
||||
|
Loading…
Reference in New Issue
Block a user