Let tf.constant_initializer accept tuples as intiail values
following CL/180416775 PiperOrigin-RevId: 180561696
This commit is contained in:
parent
dfac1ba6e6
commit
135db42bc2
@ -147,6 +147,18 @@ class ConstantInitializersTest(test.TestCase):
|
||||
self.assertEqual(x.dtype.base_dtype, dtypes.int32)
|
||||
self.assertAllEqual(x.eval(), 7 * np.ones(shape, dtype=np.int32))
|
||||
|
||||
def testConstantTupleInitializer(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
shape = [3]
|
||||
x = variable_scope.get_variable(
|
||||
"x",
|
||||
shape=shape,
|
||||
dtype=dtypes.int32,
|
||||
initializer=init_ops.constant_initializer((10, 20, 30)))
|
||||
x.initializer.run()
|
||||
self.assertEqual(x.dtype.base_dtype, dtypes.int32)
|
||||
self.assertAllEqual(x.eval(), [10, 20, 30])
|
||||
|
||||
def _testNDimConstantInitializer(self, name, value, shape, expected):
|
||||
with self.test_session(use_gpu=True):
|
||||
init = init_ops.constant_initializer(value, dtype=dtypes.int32)
|
||||
|
@ -130,9 +130,9 @@ class Constant(Initializer):
|
||||
tensor shape, the initializer will raise a `ValueError`.
|
||||
|
||||
Args:
|
||||
value: A Python scalar, list of values, or a N-dimensional numpy array. All
|
||||
elements of the initialized variable will be set to the corresponding
|
||||
value in the `value` argument.
|
||||
value: A Python scalar, list or tuple of values, or a N-dimensional numpy
|
||||
array. All elements of the initialized variable will be set to the
|
||||
corresponding value in the `value` argument.
|
||||
dtype: The data type.
|
||||
verify_shape: Boolean that enables verification of the shape of `value`. If
|
||||
`True`, the initializer will throw an error if the shape of `value` is not
|
||||
@ -192,10 +192,10 @@ class Constant(Initializer):
|
||||
"""
|
||||
|
||||
def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False):
|
||||
if not (np.isscalar(value) or isinstance(value, (list, np.ndarray))):
|
||||
if not (np.isscalar(value) or isinstance(value, (list, tuple, np.ndarray))):
|
||||
raise TypeError(
|
||||
"Invalid type for initial value: %s (expected Python scalar, list of "
|
||||
"values, or numpy.ndarray)." % type(value))
|
||||
"Invalid type for initial value: %s (expected Python scalar, list or "
|
||||
"tuple of values, or numpy.ndarray)." % type(value))
|
||||
|
||||
self.value = value
|
||||
self.dtype = dtypes.as_dtype(dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user