Generate error (instead of segfault) when trying to copy string tensor
to GPU in EagerTensor constructor. PiperOrigin-RevId: 168457320
This commit is contained in:
parent
655f26fc70
commit
00c865566f
@ -222,6 +222,15 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
tensorflow::Tensor* src = &(h->t);
|
tensorflow::Tensor* src = &(h->t);
|
||||||
|
if (!dst_cpu && !tensorflow::DataTypeCanUseMemcpy(src->dtype())) {
|
||||||
|
TF_SetStatus(
|
||||||
|
status, TF_INVALID_ARGUMENT,
|
||||||
|
tensorflow::strings::StrCat("Can't copy Tensor with type ",
|
||||||
|
tensorflow::DataTypeString(src->dtype()),
|
||||||
|
" to device ", DeviceName(dstd), ".")
|
||||||
|
.c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
if (src_cpu) {
|
if (src_cpu) {
|
||||||
tensorflow::Tensor dst(
|
tensorflow::Tensor dst(
|
||||||
dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(),
|
dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(),
|
||||||
|
@ -92,13 +92,15 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
cuda_py_test(
|
||||||
name = "tensor_test",
|
name = "tensor_test",
|
||||||
srcs = ["tensor_test.py"],
|
srcs = ["tensor_test.py"],
|
||||||
srcs_version = "PY2AND3",
|
additional_deps = [
|
||||||
deps = [
|
":context",
|
||||||
":tensor",
|
":tensor",
|
||||||
":test",
|
":test",
|
||||||
|
"//tensorflow/python:errors",
|
||||||
|
"//tensorflow/python:framework_ops",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -20,9 +20,12 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import tensor
|
from tensorflow.python.eager import tensor
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
|
||||||
|
|
||||||
@ -136,6 +139,15 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
|||||||
t_np = t.numpy()
|
t_np = t.numpy()
|
||||||
self.assertTrue(np.all(t_np == t_np_orig), "%s vs %s" % (t_np, t_np_orig))
|
self.assertTrue(np.all(t_np == t_np_orig), "%s vs %s" % (t_np, t_np_orig))
|
||||||
|
|
||||||
|
def testStringTensorOnGPU(self):
|
||||||
|
if not context.context().num_gpus():
|
||||||
|
self.skipTest("No GPUs found")
|
||||||
|
with ops.device("/device:GPU:0"):
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
errors.InvalidArgumentError,
|
||||||
|
"Can't copy Tensor with type string to device"):
|
||||||
|
tensor.Tensor("test string")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user