In python function to convert a shape to a tensor, support converting to

an int64 tensor if a dimension of the shape is too large for int32.
Change: 153385458
This commit is contained in:
A. Unique TensorFlower 2017-04-17 13:03:06 -08:00 committed by TensorFlower Gardener
parent ea910532bc
commit af36579b63
3 changed files with 43 additions and 2 deletions

View File

@ -129,14 +129,24 @@ def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None,
if not s.is_fully_defined():
raise ValueError(
"Cannot convert a partially known TensorShape to a Tensor: %s" % s)
s_list = s.as_list()
int64_value = 0
for dim in s_list:
if dim >= 2**31:
int64_value = dim
break
if dtype is not None:
if dtype not in (dtypes.int32, dtypes.int64):
raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype)
if dtype == dtypes.int32 and int64_value:
raise ValueError("Cannot convert a TensorShape to dtype int32; "
"a dimension is too large (%s)" % int64_value)
else:
dtype = dtypes.int32
dtype = dtypes.int64 if int64_value else dtypes.int32
if name is None:
name = "shape_as_tensor"
return constant(s.as_list(), dtype=dtype, name=name)
return constant(s_list, dtype=dtype, name=name)
ops.register_tensor_conversion_function(
tensor_shape.TensorShape, _tensor_shape_tensor_conversion_function, 100)

View File

@ -230,6 +230,29 @@ class AsTensorTest(test.TestCase):
self.assertEqual(dtypes_lib.int32, x.dtype)
self.assertAllEqual([1, 2, 3], x.eval())
x = ops.convert_to_tensor(tensor_shape.TensorShape([2**31-1, 2, 3]))
self.assertEqual(dtypes_lib.int32, x.dtype)
self.assertAllEqual([2**31-1, 2, 3], x.eval())
x = ops.convert_to_tensor(tensor_shape.TensorShape([2**31-1, 2, 3]),
dtype=dtypes_lib.int32)
self.assertEqual(dtypes_lib.int32, x.dtype)
self.assertAllEqual([2**31-1, 2, 3], x.eval())
x = ops.convert_to_tensor(tensor_shape.TensorShape([2**31, 2, 3]))
self.assertEqual(dtypes_lib.int64, x.dtype)
self.assertAllEqual([2**31, 2, 3], x.eval())
x = ops.convert_to_tensor(tensor_shape.TensorShape([2**31, 2, 3]),
dtype=dtypes_lib.int64)
self.assertEqual(dtypes_lib.int64, x.dtype)
self.assertAllEqual([2**31, 2, 3], x.eval())
with self.assertRaisesRegexp(
ValueError, "a dimension is too large .2147483648."):
x = ops.convert_to_tensor(tensor_shape.TensorShape([2**31, 2, 3]),
dtype=dtypes_lib.int32)
x = ops.convert_to_tensor(
tensor_shape.TensorShape([1, 2, 3]), dtype=dtypes_lib.int64)
self.assertEqual(dtypes_lib.int64, x.dtype)

View File

@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -157,6 +158,13 @@ class TruncatedNormalTest(test.TestCase):
print("std(x)", np.std(x), abs(np.std(x) / stddev - 0.85))
self.assertTrue(abs(np.std(x) / stddev - 0.85) < 0.04)
def testLargeShape(self):
with self.test_session(use_gpu=True):
v = variables.Variable(
array_ops.zeros(dtype=dtypes.float32, shape=[2**33, 1]))
n = random_ops.truncated_normal(v.shape)
self.assertEqual([8589934592, 1], n.shape.as_list())
def testNoCSE(self):
with self.test_session(use_gpu=True):
shape = [2, 3, 4]