In python function to convert a shape to a tensor, support converting to
an int64 tensor if a dimension of the shape is too large for int32. Change: 153385458
This commit is contained in:
parent
ea910532bc
commit
af36579b63
@ -129,14 +129,24 @@ def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None,
|
||||
if not s.is_fully_defined():
|
||||
raise ValueError(
|
||||
"Cannot convert a partially known TensorShape to a Tensor: %s" % s)
|
||||
s_list = s.as_list()
|
||||
int64_value = 0
|
||||
for dim in s_list:
|
||||
if dim >= 2**31:
|
||||
int64_value = dim
|
||||
break
|
||||
|
||||
if dtype is not None:
|
||||
if dtype not in (dtypes.int32, dtypes.int64):
|
||||
raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype)
|
||||
if dtype == dtypes.int32 and int64_value:
|
||||
raise ValueError("Cannot convert a TensorShape to dtype int32; "
|
||||
"a dimension is too large (%s)" % int64_value)
|
||||
else:
|
||||
dtype = dtypes.int32
|
||||
dtype = dtypes.int64 if int64_value else dtypes.int32
|
||||
if name is None:
|
||||
name = "shape_as_tensor"
|
||||
return constant(s.as_list(), dtype=dtype, name=name)
|
||||
return constant(s_list, dtype=dtype, name=name)
|
||||
|
||||
ops.register_tensor_conversion_function(
|
||||
tensor_shape.TensorShape, _tensor_shape_tensor_conversion_function, 100)
|
||||
|
@ -230,6 +230,29 @@ class AsTensorTest(test.TestCase):
|
||||
self.assertEqual(dtypes_lib.int32, x.dtype)
|
||||
self.assertAllEqual([1, 2, 3], x.eval())
|
||||
|
||||
x = ops.convert_to_tensor(tensor_shape.TensorShape([2**31-1, 2, 3]))
|
||||
self.assertEqual(dtypes_lib.int32, x.dtype)
|
||||
self.assertAllEqual([2**31-1, 2, 3], x.eval())
|
||||
|
||||
x = ops.convert_to_tensor(tensor_shape.TensorShape([2**31-1, 2, 3]),
|
||||
dtype=dtypes_lib.int32)
|
||||
self.assertEqual(dtypes_lib.int32, x.dtype)
|
||||
self.assertAllEqual([2**31-1, 2, 3], x.eval())
|
||||
|
||||
x = ops.convert_to_tensor(tensor_shape.TensorShape([2**31, 2, 3]))
|
||||
self.assertEqual(dtypes_lib.int64, x.dtype)
|
||||
self.assertAllEqual([2**31, 2, 3], x.eval())
|
||||
|
||||
x = ops.convert_to_tensor(tensor_shape.TensorShape([2**31, 2, 3]),
|
||||
dtype=dtypes_lib.int64)
|
||||
self.assertEqual(dtypes_lib.int64, x.dtype)
|
||||
self.assertAllEqual([2**31, 2, 3], x.eval())
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "a dimension is too large .2147483648."):
|
||||
x = ops.convert_to_tensor(tensor_shape.TensorShape([2**31, 2, 3]),
|
||||
dtype=dtypes_lib.int32)
|
||||
|
||||
x = ops.convert_to_tensor(
|
||||
tensor_shape.TensorShape([1, 2, 3]), dtype=dtypes_lib.int64)
|
||||
self.assertEqual(dtypes_lib.int64, x.dtype)
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -157,6 +158,13 @@ class TruncatedNormalTest(test.TestCase):
|
||||
print("std(x)", np.std(x), abs(np.std(x) / stddev - 0.85))
|
||||
self.assertTrue(abs(np.std(x) / stddev - 0.85) < 0.04)
|
||||
|
||||
def testLargeShape(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
v = variables.Variable(
|
||||
array_ops.zeros(dtype=dtypes.float32, shape=[2**33, 1]))
|
||||
n = random_ops.truncated_normal(v.shape)
|
||||
self.assertEqual([8589934592, 1], n.shape.as_list())
|
||||
|
||||
def testNoCSE(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
shape = [2, 3, 4]
|
||||
|
Loading…
Reference in New Issue
Block a user