Generate error (instead of segfault) when trying to copy string tensor
to GPU in EagerTensor constructor. PiperOrigin-RevId: 168457320
This commit is contained in:
parent
655f26fc70
commit
00c865566f
@ -222,6 +222,15 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::Tensor* src = &(h->t);
|
||||
if (!dst_cpu && !tensorflow::DataTypeCanUseMemcpy(src->dtype())) {
|
||||
TF_SetStatus(
|
||||
status, TF_INVALID_ARGUMENT,
|
||||
tensorflow::strings::StrCat("Can't copy Tensor with type ",
|
||||
tensorflow::DataTypeString(src->dtype()),
|
||||
" to device ", DeviceName(dstd), ".")
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
if (src_cpu) {
|
||||
tensorflow::Tensor dst(
|
||||
dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(),
|
||||
|
@ -92,13 +92,15 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
cuda_py_test(
|
||||
name = "tensor_test",
|
||||
srcs = ["tensor_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
additional_deps = [
|
||||
":context",
|
||||
":tensor",
|
||||
":test",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
@ -20,9 +20,12 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import tensor
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
|
||||
|
||||
@ -136,6 +139,15 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
t_np = t.numpy()
|
||||
self.assertTrue(np.all(t_np == t_np_orig), "%s vs %s" % (t_np, t_np_orig))
|
||||
|
||||
def testStringTensorOnGPU(self):
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest("No GPUs found")
|
||||
with ops.device("/device:GPU:0"):
|
||||
with self.assertRaisesRegexp(
|
||||
errors.InvalidArgumentError,
|
||||
"Can't copy Tensor with type string to device"):
|
||||
tensor.Tensor("test string")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user