In python function to convert a shape to a tensor, support converting to
an int64 tensor if a dimension of the shape is too large for int32. Change: 153385458
This commit is contained in:
parent
ea910532bc
commit
af36579b63
@ -129,14 +129,24 @@ def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None,
|
|||||||
if not s.is_fully_defined():
|
if not s.is_fully_defined():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot convert a partially known TensorShape to a Tensor: %s" % s)
|
"Cannot convert a partially known TensorShape to a Tensor: %s" % s)
|
||||||
|
s_list = s.as_list()
|
||||||
|
int64_value = 0
|
||||||
|
for dim in s_list:
|
||||||
|
if dim >= 2**31:
|
||||||
|
int64_value = dim
|
||||||
|
break
|
||||||
|
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
if dtype not in (dtypes.int32, dtypes.int64):
|
if dtype not in (dtypes.int32, dtypes.int64):
|
||||||
raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype)
|
raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype)
|
||||||
|
if dtype == dtypes.int32 and int64_value:
|
||||||
|
raise ValueError("Cannot convert a TensorShape to dtype int32; "
|
||||||
|
"a dimension is too large (%s)" % int64_value)
|
||||||
else:
|
else:
|
||||||
dtype = dtypes.int32
|
dtype = dtypes.int64 if int64_value else dtypes.int32
|
||||||
if name is None:
|
if name is None:
|
||||||
name = "shape_as_tensor"
|
name = "shape_as_tensor"
|
||||||
return constant(s.as_list(), dtype=dtype, name=name)
|
return constant(s_list, dtype=dtype, name=name)
|
||||||
|
|
||||||
ops.register_tensor_conversion_function(
|
ops.register_tensor_conversion_function(
|
||||||
tensor_shape.TensorShape, _tensor_shape_tensor_conversion_function, 100)
|
tensor_shape.TensorShape, _tensor_shape_tensor_conversion_function, 100)
|
||||||
|
@ -230,6 +230,29 @@ class AsTensorTest(test.TestCase):
|
|||||||
self.assertEqual(dtypes_lib.int32, x.dtype)
|
self.assertEqual(dtypes_lib.int32, x.dtype)
|
||||||
self.assertAllEqual([1, 2, 3], x.eval())
|
self.assertAllEqual([1, 2, 3], x.eval())
|
||||||
|
|
||||||
|
x = ops.convert_to_tensor(tensor_shape.TensorShape([2**31-1, 2, 3]))
|
||||||
|
self.assertEqual(dtypes_lib.int32, x.dtype)
|
||||||
|
self.assertAllEqual([2**31-1, 2, 3], x.eval())
|
||||||
|
|
||||||
|
x = ops.convert_to_tensor(tensor_shape.TensorShape([2**31-1, 2, 3]),
|
||||||
|
dtype=dtypes_lib.int32)
|
||||||
|
self.assertEqual(dtypes_lib.int32, x.dtype)
|
||||||
|
self.assertAllEqual([2**31-1, 2, 3], x.eval())
|
||||||
|
|
||||||
|
x = ops.convert_to_tensor(tensor_shape.TensorShape([2**31, 2, 3]))
|
||||||
|
self.assertEqual(dtypes_lib.int64, x.dtype)
|
||||||
|
self.assertAllEqual([2**31, 2, 3], x.eval())
|
||||||
|
|
||||||
|
x = ops.convert_to_tensor(tensor_shape.TensorShape([2**31, 2, 3]),
|
||||||
|
dtype=dtypes_lib.int64)
|
||||||
|
self.assertEqual(dtypes_lib.int64, x.dtype)
|
||||||
|
self.assertAllEqual([2**31, 2, 3], x.eval())
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, "a dimension is too large .2147483648."):
|
||||||
|
x = ops.convert_to_tensor(tensor_shape.TensorShape([2**31, 2, 3]),
|
||||||
|
dtype=dtypes_lib.int32)
|
||||||
|
|
||||||
x = ops.convert_to_tensor(
|
x = ops.convert_to_tensor(
|
||||||
tensor_shape.TensorShape([1, 2, 3]), dtype=dtypes_lib.int64)
|
tensor_shape.TensorShape([1, 2, 3]), dtype=dtypes_lib.int64)
|
||||||
self.assertEqual(dtypes_lib.int64, x.dtype)
|
self.assertEqual(dtypes_lib.int64, x.dtype)
|
||||||
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -157,6 +158,13 @@ class TruncatedNormalTest(test.TestCase):
|
|||||||
print("std(x)", np.std(x), abs(np.std(x) / stddev - 0.85))
|
print("std(x)", np.std(x), abs(np.std(x) / stddev - 0.85))
|
||||||
self.assertTrue(abs(np.std(x) / stddev - 0.85) < 0.04)
|
self.assertTrue(abs(np.std(x) / stddev - 0.85) < 0.04)
|
||||||
|
|
||||||
|
def testLargeShape(self):
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
v = variables.Variable(
|
||||||
|
array_ops.zeros(dtype=dtypes.float32, shape=[2**33, 1]))
|
||||||
|
n = random_ops.truncated_normal(v.shape)
|
||||||
|
self.assertEqual([8589934592, 1], n.shape.as_list())
|
||||||
|
|
||||||
def testNoCSE(self):
|
def testNoCSE(self):
|
||||||
with self.test_session(use_gpu=True):
|
with self.test_session(use_gpu=True):
|
||||||
shape = [2, 3, 4]
|
shape = [2, 3, 4]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user