* [OpenCL] Fixes double memcpy bug (#151)
As the debg CopyOp is called on a Tensor without type, we need to use
the DataType enum to get type information, and use this to pass the type
on to Eigen. This is a workaround Eigen's need to have a type when
calling memcpy. If the Eigen memcpy can be provided without a type
requirement, then the memcpy in sycl_util is unnecessary.
* Acts on feedback from: 32cb12a900 (r132496277)
This commit is contained in:
parent
92111fdd1a
commit
e1e81d9ba9
@ -21,17 +21,60 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
// For DMA helper
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
inline void* GetBase(const Tensor* src) {
|
||||
return const_cast<void*>(DMAHelper::base(src));
|
||||
inline void const* GetBase(const Tensor* src) { return DMAHelper::base(src); }
|
||||
inline void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
|
||||
|
||||
inline void SYCLmemcpy(Eigen::SyclDevice const& device,
|
||||
Tensor const& src_tensor, Tensor* dst_tensor) {
|
||||
const size_t size = src_tensor.TotalBytes();
|
||||
void* dst_ptr = GetBase(dst_tensor);
|
||||
void const* src_ptr = GetBase(&src_tensor);
|
||||
|
||||
#define COPY_WITH_TYPE(T) \
|
||||
device.memcpy(dst_ptr, static_cast<T const*>(src_ptr), size);
|
||||
switch (src_tensor.dtype()) {
|
||||
case DT_COMPLEX128:
|
||||
COPY_WITH_TYPE(cl::sycl::cl_ulong2);
|
||||
break;
|
||||
case DT_DOUBLE:
|
||||
case DT_COMPLEX64:
|
||||
case DT_INT64:
|
||||
COPY_WITH_TYPE(cl::sycl::cl_ulong);
|
||||
break;
|
||||
case DT_FLOAT:
|
||||
case DT_INT32:
|
||||
case DT_QINT32:
|
||||
COPY_WITH_TYPE(cl::sycl::cl_uint);
|
||||
break;
|
||||
case DT_INT16:
|
||||
case DT_UINT16:
|
||||
case DT_BFLOAT16:
|
||||
case DT_QINT16:
|
||||
case DT_QUINT16:
|
||||
case DT_HALF:
|
||||
COPY_WITH_TYPE(cl::sycl::cl_ushort);
|
||||
break;
|
||||
case DT_BOOL:
|
||||
COPY_WITH_TYPE(bool);
|
||||
break;
|
||||
case DT_UINT8:
|
||||
case DT_INT8:
|
||||
case DT_QINT8:
|
||||
case DT_QUINT8:
|
||||
COPY_WITH_TYPE(cl::sycl::cl_uchar);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type " << src_tensor.dtype();
|
||||
break;
|
||||
}
|
||||
|
||||
inline void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
|
||||
|
||||
#undef COPY_WITH_TYPE
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
|
||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
|
||||
|
@ -94,12 +94,7 @@ class CopyOp : public OpKernel {
|
||||
!context->input_alloc_attr(0).on_host();
|
||||
|
||||
if (off_host_input) {
|
||||
auto size = src_tensor.NumElements() * sizeof(src_tensor.dtype());
|
||||
auto dst_ptr = GetBase(copied_tensor);
|
||||
auto src_ptr = GetBase(&src_tensor);
|
||||
typedef decltype(src_tensor.dtype()) ttype;
|
||||
context->eigen_sycl_device().memcpy(
|
||||
dst_ptr, static_cast<const ttype*>(src_ptr), size);
|
||||
SYCLmemcpy(context->eigen_sycl_device(), src_tensor, copied_tensor);
|
||||
} else {
|
||||
*copied_tensor = tensor::DeepCopy(src_tensor);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user