Allow creating a list from a tensor. Fix a few inconsistencies in the tensor list constructors.
PiperOrigin-RevId: 215435720
This commit is contained in:
parent
16b44d48d4
commit
8d4ef71f06
tensorflow/python/autograph
@ -24,6 +24,26 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.operators import data_structures
|
||||
from tensorflow.python.framework import tensor_util
|
||||
|
||||
|
||||
def _validate_list_constructor(elements, element_dtype, element_shape):
|
||||
"""Validates the inputs of tensor_list."""
|
||||
if element_dtype is not None and element_shape is not None:
|
||||
return
|
||||
if tensor_util.is_tensor(elements):
|
||||
return
|
||||
if isinstance(elements, (list, tuple)):
|
||||
if elements:
|
||||
return
|
||||
else:
|
||||
raise ValueError(
|
||||
'element_dtype and element_shape are required when elements are'
|
||||
' empty')
|
||||
|
||||
raise ValueError(
|
||||
'unknown type for elements: {}; only Tensor, list and tuple are'
|
||||
' allowed'.format(type(elements)))
|
||||
|
||||
|
||||
def tensor_list(elements,
|
||||
@ -52,9 +72,7 @@ def tensor_list(elements,
|
||||
Raises:
|
||||
ValueError: for invalid arguments
|
||||
"""
|
||||
if not (elements or (element_dtype and element_shape)):
|
||||
raise ValueError(
|
||||
'element_dtype and element_shape are required for empty lists')
|
||||
_validate_list_constructor(elements, element_dtype, element_shape)
|
||||
if use_tensor_array:
|
||||
return data_structures.tf_tensor_array_new(elements, element_dtype,
|
||||
element_shape)
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.autograph.lang import special_functions
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -28,12 +30,43 @@ from tensorflow.python.platform import test
|
||||
|
||||
class SpecialFunctionsTest(test.TestCase):
|
||||
|
||||
def test_tensor_list_empty_list(self):
|
||||
l = special_functions.tensor_list([],
|
||||
element_dtype=dtypes.int32,
|
||||
element_shape=())
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.test_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [])
|
||||
|
||||
l = special_functions.tensor_list((),
|
||||
element_dtype=dtypes.int32,
|
||||
element_shape=())
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.test_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [])
|
||||
|
||||
def test_tensor_list_tensor(self):
|
||||
l = special_functions.tensor_list(
|
||||
constant_op.constant([], dtype=dtypes.int32))
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.test_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [])
|
||||
|
||||
def test_tensor_list_unsupported_initializer(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'unknown type'):
|
||||
special_functions.tensor_list(np.array([1, 2, 3]))
|
||||
|
||||
def test_tensor_list_empty_list_no_type(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'element_dtype and element_shape are required'):
|
||||
special_functions.tensor_list([])
|
||||
|
||||
def test_tensor_list_from_elements(self):
|
||||
elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
|
||||
|
||||
l = special_functions.tensor_list(elements)
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
with self.test_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
||||
|
||||
def test_tensor_list_array_from_elements(self):
|
||||
@ -41,7 +74,7 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
|
||||
l = special_functions.tensor_list(elements, use_tensor_array=True)
|
||||
sl = l.stack()
|
||||
with self.cached_session() as sess:
|
||||
with self.test_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
||||
|
||||
def test_stack(self):
|
||||
|
@ -106,6 +106,14 @@ def tf_tensor_array_new(elements, element_dtype=None, element_shape=None):
|
||||
|
||||
def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
|
||||
"""Overload of new_list that stages a Tensor list creation."""
|
||||
if tensor_util.is_tensor(elements):
|
||||
if element_shape is not None:
|
||||
raise ValueError(
|
||||
'element shape may not be specified when creating list from tensor')
|
||||
element_shape = array_ops.shape(elements)[1:]
|
||||
l = list_ops.tensor_list_from_tensor(elements, element_shape=element_shape)
|
||||
return l
|
||||
|
||||
elements = tuple(ops.convert_to_tensor(el) for el in elements)
|
||||
|
||||
all_dtypes = set(el.dtype for el in elements)
|
||||
@ -115,13 +123,15 @@ def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
|
||||
raise ValueError(
|
||||
'incompatible dtype; specified: {}, inferred from {}: {}'.format(
|
||||
element_dtype, elements, inferred_dtype))
|
||||
else:
|
||||
elif all_dtypes:
|
||||
# Heterogeneous lists are ok.
|
||||
if element_dtype is not None:
|
||||
raise ValueError(
|
||||
'specified dtype {} is inconsistent with that of elements {}'.format(
|
||||
element_dtype, elements))
|
||||
inferred_dtype = dtypes.variant
|
||||
else:
|
||||
inferred_dtype = dtypes.variant
|
||||
|
||||
all_shapes = set(tuple(el.shape.as_list()) for el in elements)
|
||||
if len(all_shapes) == 1:
|
||||
@ -130,19 +140,22 @@ def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
|
||||
raise ValueError(
|
||||
'incompatible shape; specified: {}, inferred from {}: {}'.format(
|
||||
element_shape, elements, inferred_shape))
|
||||
else:
|
||||
elif all_shapes:
|
||||
# Heterogeneous lists are ok.
|
||||
if element_shape is not None:
|
||||
raise ValueError(
|
||||
'specified shape {} is inconsistent with that of elements {}'.format(
|
||||
element_shape, elements))
|
||||
inferred_shape = constant_op.constant(-1) # unknown shape, by convention
|
||||
else:
|
||||
inferred_shape = constant_op.constant(-1) # unknown shape, by convention
|
||||
|
||||
if element_dtype is None:
|
||||
element_dtype = inferred_dtype
|
||||
if element_shape is None:
|
||||
element_shape = inferred_shape
|
||||
|
||||
element_shape = ops.convert_to_tensor(element_shape, dtype=dtypes.int32)
|
||||
l = list_ops.empty_tensor_list(
|
||||
element_shape=element_shape, element_dtype=element_dtype)
|
||||
for el in elements:
|
||||
|
@ -45,6 +45,20 @@ class ListTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(t), [3, 4, 5])
|
||||
|
||||
def test_tf_tensor_list_new_empty(self):
|
||||
l = data_structures.tf_tensor_list_new([],
|
||||
element_dtype=dtypes.int32,
|
||||
element_shape=())
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(t), [])
|
||||
|
||||
def test_tf_tensor_list_new_from_tensor(self):
|
||||
l = data_structures.tf_tensor_list_new(constant_op.constant([3, 4, 5]))
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(t), [3, 4, 5])
|
||||
|
||||
def test_tf_tensor_list_new_illegal_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
data_structures.tf_tensor_list_new([3, 4.0])
|
||||
@ -56,9 +70,8 @@ class ListTest(test.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
data_structures.tf_tensor_list_new([3, 4], element_shape=(2,))
|
||||
with self.assertRaises(ValueError):
|
||||
data_structures.tf_tensor_list_new([], element_shape=(2,))
|
||||
with self.assertRaises(ValueError):
|
||||
data_structures.tf_tensor_list_new([], element_dtype=dtypes.float32)
|
||||
data_structures.tf_tensor_list_new(
|
||||
constant_op.constant([1, 2, 3]), element_shape=[1])
|
||||
|
||||
def test_tf_tensor_array_new(self):
|
||||
l = data_structures.tf_tensor_array_new([3, 4, 5])
|
||||
@ -141,6 +154,18 @@ class ListTest(test.TestCase):
|
||||
t = data_structures.list_stack(l, opts)
|
||||
self.assertAllEqual(sess.run(t), sess.run(initial_list))
|
||||
|
||||
def test_stack_tensor_list_empty(self):
|
||||
l = list_ops.empty_tensor_list(
|
||||
element_shape=-1,
|
||||
element_dtype=dtypes.variant)
|
||||
|
||||
opts = data_structures.ListStackOpts(
|
||||
element_dtype=dtypes.int32, original_call=None)
|
||||
|
||||
# TODO(mdan): Allow stacking empty lists if the dtype and shape are known.
|
||||
with self.assertRaises(ValueError):
|
||||
data_structures.list_stack(l, opts)
|
||||
|
||||
def test_stack_fallback(self):
|
||||
|
||||
def dummy_function(l):
|
||||
|
Loading…
Reference in New Issue
Block a user