[OpenCL] Fixes double memcpy bug (#151) (#12173)

* [OpenCL] Fixes double memcpy bug (#151)

As the debug CopyOp is called on a Tensor without type, we need to use
the DataType enum to get type information, and use this to pass the type
on to Eigen. This is a workaround for Eigen's need to have a type when
calling memcpy. If the Eigen memcpy can be provided without a type
requirement, then the memcpy in sycl_util is unnecessary.

* Acts on feedback from: 32cb12a900 (r132496277)
This commit is contained in:
Luke Iwanski 2017-08-11 00:17:40 +01:00 committed by Rasmus Munk Larsen
parent 92111fdd1a
commit e1e81d9ba9
2 changed files with 50 additions and 12 deletions

View File

@ -21,17 +21,60 @@ limitations under the License.
#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
#include "tensorflow/core/common_runtime/device.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
// For DMA helper
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
inline void* GetBase(const Tensor* src) {
return const_cast<void*>(DMAHelper::base(src));
// Read-only overload: returns whatever DMAHelper::base yields for `src`,
// typed as a pointer-to-const so callers cannot write through it.
inline void const* GetBase(const Tensor* src) { return DMAHelper::base(src); }
// Mutable overload: returns whatever DMAHelper::base yields for `dst` as a
// writable untyped pointer.
inline void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
// Copies the contents of `src_tensor` into the buffer of `dst_tensor` using
// `device.memcpy`.
//
// device:     the Eigen SYCL device whose memcpy performs the copy.
// src_tensor: tensor to copy from; its TotalBytes() is the byte count and its
//             dtype() selects the pointer type used for the source.
// dst_tensor: tensor whose base pointer receives the data.
//
// The source pointer is untyped, so it is cast to a concrete type chosen from
// the tensor's DataType before being handed to memcpy (per the commit
// description, Eigen's memcpy requires a typed pointer). A dtype not listed in
// the switch ends in LOG(FATAL).
// NOTE(review): nothing here checks that dst_tensor's buffer can hold `size`
// bytes -- presumably the caller guarantees this; confirm.
inline void SYCLmemcpy(Eigen::SyclDevice const& device,
Tensor const& src_tensor, Tensor* dst_tensor) {
// Byte count is taken from the source tensor only.
const size_t size = src_tensor.TotalBytes();
void* dst_ptr = GetBase(dst_tensor);
void const* src_ptr = GetBase(&src_tensor);
// Helper macro: issue the memcpy with the source pointer viewed as `T const*`.
// It is undefined again at the end of the function so it does not leak.
#define COPY_WITH_TYPE(T) \
device.memcpy(dst_ptr, static_cast<T const*>(src_ptr), size);
// Dtypes are grouped by the type used for the cast; presumably every dtype in
// a group has the same element width as that type -- TODO confirm.
switch (src_tensor.dtype()) {
case DT_COMPLEX128:
COPY_WITH_TYPE(cl::sycl::cl_ulong2);
break;
case DT_DOUBLE:
case DT_COMPLEX64:
case DT_INT64:
COPY_WITH_TYPE(cl::sycl::cl_ulong);
break;
case DT_FLOAT:
case DT_INT32:
case DT_QINT32:
COPY_WITH_TYPE(cl::sycl::cl_uint);
break;
case DT_INT16:
case DT_UINT16:
case DT_BFLOAT16:
case DT_QINT16:
case DT_QUINT16:
case DT_HALF:
COPY_WITH_TYPE(cl::sycl::cl_ushort);
break;
case DT_BOOL:
COPY_WITH_TYPE(bool);
break;
case DT_UINT8:
case DT_INT8:
case DT_QINT8:
case DT_QUINT8:
COPY_WITH_TYPE(cl::sycl::cl_uchar);
break;
default:
// Any dtype without an explicit case above is treated as a fatal error.
LOG(FATAL) << "Unknown data type " << src_tensor.dtype();
break;
}
// NOTE(review): the next line looks like a removed line left over from the
// rendered diff (the same overload appears earlier) rather than part of this
// function body -- confirm against the actual header.
inline void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
#undef COPY_WITH_TYPE
}
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_

View File

@ -94,12 +94,7 @@ class CopyOp : public OpKernel {
!context->input_alloc_attr(0).on_host();
if (off_host_input) {
auto size = src_tensor.NumElements() * sizeof(src_tensor.dtype());
auto dst_ptr = GetBase(copied_tensor);
auto src_ptr = GetBase(&src_tensor);
typedef decltype(src_tensor.dtype()) ttype;
context->eigen_sycl_device().memcpy(
dst_ptr, static_cast<const ttype*>(src_ptr), size);
SYCLmemcpy(context->eigen_sycl_device(), src_tensor, copied_tensor);
} else {
*copied_tensor = tensor::DeepCopy(src_tensor);
}