Allow creating a list from a tensor. Fix a few inconsistencies in the tensor list constructors.

PiperOrigin-RevId: 215435720
This commit is contained in:
Dan Moldovan 2018-10-02 12:15:36 -07:00 committed by TensorFlower Gardener
parent 16b44d48d4
commit 8d4ef71f06
4 changed files with 99 additions and 10 deletions

View File

@ -24,6 +24,26 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.operators import data_structures
from tensorflow.python.framework import tensor_util
def _validate_list_constructor(elements, element_dtype, element_shape):
"""Validates the inputs of tensor_list."""
if element_dtype is not None and element_shape is not None:
return
if tensor_util.is_tensor(elements):
return
if isinstance(elements, (list, tuple)):
if elements:
return
else:
raise ValueError(
'element_dtype and element_shape are required when elements are'
' empty')
raise ValueError(
'unknown type for elements: {}; only Tensor, list and tuple are'
' allowed'.format(type(elements)))
def tensor_list(elements,
@ -52,9 +72,7 @@ def tensor_list(elements,
Raises:
ValueError: for invalid arguments
"""
if not (elements or (element_dtype and element_shape)):
raise ValueError(
'element_dtype and element_shape are required for empty lists')
_validate_list_constructor(elements, element_dtype, element_shape)
if use_tensor_array:
return data_structures.tf_tensor_array_new(elements, element_dtype,
element_shape)

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -28,12 +30,43 @@ from tensorflow.python.platform import test
class SpecialFunctionsTest(test.TestCase):
def test_tensor_list_empty_list(self):
l = special_functions.tensor_list([],
element_dtype=dtypes.int32,
element_shape=())
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.test_session() as sess:
self.assertAllEqual(sess.run(sl), [])
l = special_functions.tensor_list((),
element_dtype=dtypes.int32,
element_shape=())
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.test_session() as sess:
self.assertAllEqual(sess.run(sl), [])
def test_tensor_list_tensor(self):
l = special_functions.tensor_list(
constant_op.constant([], dtype=dtypes.int32))
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.test_session() as sess:
self.assertAllEqual(sess.run(sl), [])
def test_tensor_list_unsupported_initializer(self):
with self.assertRaisesRegexp(ValueError, 'unknown type'):
special_functions.tensor_list(np.array([1, 2, 3]))
def test_tensor_list_empty_list_no_type(self):
with self.assertRaisesRegexp(
ValueError, 'element_dtype and element_shape are required'):
special_functions.tensor_list([])
def test_tensor_list_from_elements(self):
elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
l = special_functions.tensor_list(elements)
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
with self.test_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_tensor_list_array_from_elements(self):
@ -41,7 +74,7 @@ class SpecialFunctionsTest(test.TestCase):
l = special_functions.tensor_list(elements, use_tensor_array=True)
sl = l.stack()
with self.cached_session() as sess:
with self.test_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_stack(self):

View File

@ -106,6 +106,14 @@ def tf_tensor_array_new(elements, element_dtype=None, element_shape=None):
def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
"""Overload of new_list that stages a Tensor list creation."""
if tensor_util.is_tensor(elements):
if element_shape is not None:
raise ValueError(
'element shape may not be specified when creating list from tensor')
element_shape = array_ops.shape(elements)[1:]
l = list_ops.tensor_list_from_tensor(elements, element_shape=element_shape)
return l
elements = tuple(ops.convert_to_tensor(el) for el in elements)
all_dtypes = set(el.dtype for el in elements)
@ -115,13 +123,15 @@ def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
raise ValueError(
'incompatible dtype; specified: {}, inferred from {}: {}'.format(
element_dtype, elements, inferred_dtype))
else:
elif all_dtypes:
# Heterogeneous lists are ok.
if element_dtype is not None:
raise ValueError(
'specified dtype {} is inconsistent with that of elements {}'.format(
element_dtype, elements))
inferred_dtype = dtypes.variant
else:
inferred_dtype = dtypes.variant
all_shapes = set(tuple(el.shape.as_list()) for el in elements)
if len(all_shapes) == 1:
@ -130,19 +140,22 @@ def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
raise ValueError(
'incompatible shape; specified: {}, inferred from {}: {}'.format(
element_shape, elements, inferred_shape))
else:
elif all_shapes:
# Heterogeneous lists are ok.
if element_shape is not None:
raise ValueError(
'specified shape {} is inconsistent with that of elements {}'.format(
element_shape, elements))
inferred_shape = constant_op.constant(-1) # unknown shape, by convention
else:
inferred_shape = constant_op.constant(-1) # unknown shape, by convention
if element_dtype is None:
element_dtype = inferred_dtype
if element_shape is None:
element_shape = inferred_shape
element_shape = ops.convert_to_tensor(element_shape, dtype=dtypes.int32)
l = list_ops.empty_tensor_list(
element_shape=element_shape, element_dtype=element_dtype)
for el in elements:

View File

@ -45,6 +45,20 @@ class ListTest(test.TestCase):
with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [3, 4, 5])
def test_tf_tensor_list_new_empty(self):
l = data_structures.tf_tensor_list_new([],
element_dtype=dtypes.int32,
element_shape=())
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [])
def test_tf_tensor_list_new_from_tensor(self):
l = data_structures.tf_tensor_list_new(constant_op.constant([3, 4, 5]))
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [3, 4, 5])
def test_tf_tensor_list_new_illegal_input(self):
with self.assertRaises(ValueError):
data_structures.tf_tensor_list_new([3, 4.0])
@ -56,9 +70,8 @@ class ListTest(test.TestCase):
with self.assertRaises(ValueError):
data_structures.tf_tensor_list_new([3, 4], element_shape=(2,))
with self.assertRaises(ValueError):
data_structures.tf_tensor_list_new([], element_shape=(2,))
with self.assertRaises(ValueError):
data_structures.tf_tensor_list_new([], element_dtype=dtypes.float32)
data_structures.tf_tensor_list_new(
constant_op.constant([1, 2, 3]), element_shape=[1])
def test_tf_tensor_array_new(self):
l = data_structures.tf_tensor_array_new([3, 4, 5])
@ -141,6 +154,18 @@ class ListTest(test.TestCase):
t = data_structures.list_stack(l, opts)
self.assertAllEqual(sess.run(t), sess.run(initial_list))
def test_stack_tensor_list_empty(self):
l = list_ops.empty_tensor_list(
element_shape=-1,
element_dtype=dtypes.variant)
opts = data_structures.ListStackOpts(
element_dtype=dtypes.int32, original_call=None)
# TODO(mdan): Allow stacking empty lists if the dtype and shape are known.
with self.assertRaises(ValueError):
data_structures.list_stack(l, opts)
def test_stack_fallback(self):
def dummy_function(l):