Let tf.constant_initializer accept tuples as intiail values

following CL/180416775

PiperOrigin-RevId: 180561696
This commit is contained in:
Shanqing Cai 2018-01-02 10:24:55 -08:00 committed by TensorFlower Gardener
parent dfac1ba6e6
commit 135db42bc2
2 changed files with 18 additions and 6 deletions

View File

@ -147,6 +147,18 @@ class ConstantInitializersTest(test.TestCase):
self.assertEqual(x.dtype.base_dtype, dtypes.int32)
self.assertAllEqual(x.eval(), 7 * np.ones(shape, dtype=np.int32))
def testConstantTupleInitializer(self):
with self.test_session(use_gpu=True):
shape = [3]
x = variable_scope.get_variable(
"x",
shape=shape,
dtype=dtypes.int32,
initializer=init_ops.constant_initializer((10, 20, 30)))
x.initializer.run()
self.assertEqual(x.dtype.base_dtype, dtypes.int32)
self.assertAllEqual(x.eval(), [10, 20, 30])
def _testNDimConstantInitializer(self, name, value, shape, expected):
with self.test_session(use_gpu=True):
init = init_ops.constant_initializer(value, dtype=dtypes.int32)

View File

@ -130,9 +130,9 @@ class Constant(Initializer):
tensor shape, the initializer will raise a `ValueError`.
Args:
value: A Python scalar, list of values, or a N-dimensional numpy array. All
elements of the initialized variable will be set to the corresponding
value in the `value` argument.
value: A Python scalar, list or tuple of values, or a N-dimensional numpy
array. All elements of the initialized variable will be set to the
corresponding value in the `value` argument.
dtype: The data type.
verify_shape: Boolean that enables verification of the shape of `value`. If
`True`, the initializer will throw an error if the shape of `value` is not
@ -192,10 +192,10 @@ class Constant(Initializer):
"""
def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False):
if not (np.isscalar(value) or isinstance(value, (list, np.ndarray))):
if not (np.isscalar(value) or isinstance(value, (list, tuple, np.ndarray))):
raise TypeError(
"Invalid type for initial value: %s (expected Python scalar, list of "
"values, or numpy.ndarray)." % type(value))
"Invalid type for initial value: %s (expected Python scalar, list or "
"tuple of values, or numpy.ndarray)." % type(value))
self.value = value
self.dtype = dtypes.as_dtype(dtype)