Generate error (instead of segfault) when trying to copy string tensor

to GPU in EagerTensor constructor.

PiperOrigin-RevId: 168457320
This commit is contained in:
A. Unique TensorFlower 2017-09-12 15:23:40 -07:00 committed by TensorFlower Gardener
parent 655f26fc70
commit 00c865566f
3 changed files with 26 additions and 3 deletions

View File

@ -222,6 +222,15 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
return nullptr;
}
tensorflow::Tensor* src = &(h->t);
if (!dst_cpu && !tensorflow::DataTypeCanUseMemcpy(src->dtype())) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat("Can't copy Tensor with type ",
tensorflow::DataTypeString(src->dtype()),
" to device ", DeviceName(dstd), ".")
.c_str());
return nullptr;
}
if (src_cpu) {
tensorflow::Tensor dst(
dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(),

View File

@ -92,13 +92,15 @@ py_library(
srcs_version = "PY2AND3",
)
py_test(
cuda_py_test(
name = "tensor_test",
srcs = ["tensor_test.py"],
srcs_version = "PY2AND3",
deps = [
additional_deps = [
":context",
":tensor",
":test",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//third_party/py/numpy",
],
)

View File

@ -20,9 +20,12 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.eager import tensor
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@ -136,6 +139,15 @@ class TFETensorTest(test_util.TensorFlowTestCase):
t_np = t.numpy()
self.assertTrue(np.all(t_np == t_np_orig), "%s vs %s" % (t_np, t_np_orig))
def testStringTensorOnGPU(self):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
with ops.device("/device:GPU:0"):
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Can't copy Tensor with type string to device"):
tensor.Tensor("test string")
if __name__ == "__main__":
test.main()