Let tf.constant_initializer accept tuples as intiail values
following CL/180416775 PiperOrigin-RevId: 180561696
This commit is contained in:
parent
dfac1ba6e6
commit
135db42bc2
@ -147,6 +147,18 @@ class ConstantInitializersTest(test.TestCase):
|
|||||||
self.assertEqual(x.dtype.base_dtype, dtypes.int32)
|
self.assertEqual(x.dtype.base_dtype, dtypes.int32)
|
||||||
self.assertAllEqual(x.eval(), 7 * np.ones(shape, dtype=np.int32))
|
self.assertAllEqual(x.eval(), 7 * np.ones(shape, dtype=np.int32))
|
||||||
|
|
||||||
|
def testConstantTupleInitializer(self):
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
shape = [3]
|
||||||
|
x = variable_scope.get_variable(
|
||||||
|
"x",
|
||||||
|
shape=shape,
|
||||||
|
dtype=dtypes.int32,
|
||||||
|
initializer=init_ops.constant_initializer((10, 20, 30)))
|
||||||
|
x.initializer.run()
|
||||||
|
self.assertEqual(x.dtype.base_dtype, dtypes.int32)
|
||||||
|
self.assertAllEqual(x.eval(), [10, 20, 30])
|
||||||
|
|
||||||
def _testNDimConstantInitializer(self, name, value, shape, expected):
|
def _testNDimConstantInitializer(self, name, value, shape, expected):
|
||||||
with self.test_session(use_gpu=True):
|
with self.test_session(use_gpu=True):
|
||||||
init = init_ops.constant_initializer(value, dtype=dtypes.int32)
|
init = init_ops.constant_initializer(value, dtype=dtypes.int32)
|
||||||
|
@ -130,9 +130,9 @@ class Constant(Initializer):
|
|||||||
tensor shape, the initializer will raise a `ValueError`.
|
tensor shape, the initializer will raise a `ValueError`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
value: A Python scalar, list of values, or a N-dimensional numpy array. All
|
value: A Python scalar, list or tuple of values, or a N-dimensional numpy
|
||||||
elements of the initialized variable will be set to the corresponding
|
array. All elements of the initialized variable will be set to the
|
||||||
value in the `value` argument.
|
corresponding value in the `value` argument.
|
||||||
dtype: The data type.
|
dtype: The data type.
|
||||||
verify_shape: Boolean that enables verification of the shape of `value`. If
|
verify_shape: Boolean that enables verification of the shape of `value`. If
|
||||||
`True`, the initializer will throw an error if the shape of `value` is not
|
`True`, the initializer will throw an error if the shape of `value` is not
|
||||||
@ -192,10 +192,10 @@ class Constant(Initializer):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False):
|
def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False):
|
||||||
if not (np.isscalar(value) or isinstance(value, (list, np.ndarray))):
|
if not (np.isscalar(value) or isinstance(value, (list, tuple, np.ndarray))):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Invalid type for initial value: %s (expected Python scalar, list of "
|
"Invalid type for initial value: %s (expected Python scalar, list or "
|
||||||
"values, or numpy.ndarray)." % type(value))
|
"tuple of values, or numpy.ndarray)." % type(value))
|
||||||
|
|
||||||
self.value = value
|
self.value = value
|
||||||
self.dtype = dtypes.as_dtype(dtype)
|
self.dtype = dtypes.as_dtype(dtype)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user