From 135db42bc234153f0ea212fbcf0d75f1040712d7 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Tue, 2 Jan 2018 10:24:55 -0800 Subject: [PATCH] Let tf.constant_initializer accept tuples as intiail values following CL/180416775 PiperOrigin-RevId: 180561696 --- tensorflow/python/kernel_tests/init_ops_test.py | 12 ++++++++++++ tensorflow/python/ops/init_ops.py | 12 ++++++------ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py index 9f4590a6c60..19a7d2f9d51 100644 --- a/tensorflow/python/kernel_tests/init_ops_test.py +++ b/tensorflow/python/kernel_tests/init_ops_test.py @@ -147,6 +147,18 @@ class ConstantInitializersTest(test.TestCase): self.assertEqual(x.dtype.base_dtype, dtypes.int32) self.assertAllEqual(x.eval(), 7 * np.ones(shape, dtype=np.int32)) + def testConstantTupleInitializer(self): + with self.test_session(use_gpu=True): + shape = [3] + x = variable_scope.get_variable( + "x", + shape=shape, + dtype=dtypes.int32, + initializer=init_ops.constant_initializer((10, 20, 30))) + x.initializer.run() + self.assertEqual(x.dtype.base_dtype, dtypes.int32) + self.assertAllEqual(x.eval(), [10, 20, 30]) + def _testNDimConstantInitializer(self, name, value, shape, expected): with self.test_session(use_gpu=True): init = init_ops.constant_initializer(value, dtype=dtypes.int32) diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py index 1279de331ae..5dc43d65b95 100644 --- a/tensorflow/python/ops/init_ops.py +++ b/tensorflow/python/ops/init_ops.py @@ -130,9 +130,9 @@ class Constant(Initializer): tensor shape, the initializer will raise a `ValueError`. Args: - value: A Python scalar, list of values, or a N-dimensional numpy array. All - elements of the initialized variable will be set to the corresponding - value in the `value` argument. + value: A Python scalar, list or tuple of values, or a N-dimensional numpy + array. All elements of the initialized variable will be set to the + corresponding value in the `value` argument. dtype: The data type. verify_shape: Boolean that enables verification of the shape of `value`. If `True`, the initializer will throw an error if the shape of `value` is not @@ -192,10 +192,10 @@ class Constant(Initializer): """ def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False): - if not (np.isscalar(value) or isinstance(value, (list, np.ndarray))): + if not (np.isscalar(value) or isinstance(value, (list, tuple, np.ndarray))): raise TypeError( - "Invalid type for initial value: %s (expected Python scalar, list of " - "values, or numpy.ndarray)." % type(value)) + "Invalid type for initial value: %s (expected Python scalar, list or " + "tuple of values, or numpy.ndarray)." % type(value)) self.value = value self.dtype = dtypes.as_dtype(dtype)