Add Flatten to core layers.
PiperOrigin-RevId: 168254118
This commit is contained in:
parent
a6223c01a6
commit
80ed8afc02
tensorflow
contrib
python
tools/api/golden
@ -45,7 +45,7 @@ class ConditioningUtilsTest(test.TestCase):
|
||||
array_ops.placeholder(dtypes.float32, (5, None)),
|
||||
array_ops.placeholder(dtypes.float32, (5, 1)))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'must have a least 2 dimensions.'):
|
||||
with self.assertRaisesRegexp(ValueError, 'expected min_ndim=2'):
|
||||
conditioning_utils.condition_tensor(
|
||||
array_ops.placeholder(dtypes.float32, (5, 2)),
|
||||
array_ops.placeholder(dtypes.float32, (5)))
|
||||
|
@ -1435,30 +1435,7 @@ def flatten(inputs,
|
||||
"""
|
||||
with ops.name_scope(scope, 'Flatten', [inputs]) as sc:
|
||||
inputs = ops.convert_to_tensor(inputs)
|
||||
inputs_rank = inputs.get_shape().ndims
|
||||
if (inputs_rank is None) or (inputs_rank < 2):
|
||||
raise ValueError('Inputs must have a least 2 dimensions.')
|
||||
|
||||
inputs_shape = array_ops.shape(inputs)
|
||||
|
||||
batch_dim = array_ops.slice(inputs_shape, [0], [1])
|
||||
spatial_dims = array_ops.slice(inputs_shape, [1], [inputs_rank - 1])
|
||||
|
||||
flat_spatial_dim = math_ops.reduce_prod(spatial_dims)
|
||||
flat_spatial_dim = array_ops.expand_dims(flat_spatial_dim, 0)
|
||||
flat_shape = array_ops.concat([batch_dim, flat_spatial_dim], 0)
|
||||
|
||||
outputs = array_ops.reshape(inputs, flat_shape)
|
||||
|
||||
# Attempt to propagate shape information, if it is defined.
|
||||
input_shape = inputs.get_shape().as_list()
|
||||
batch_dim, spatial_dims = input_shape[0], input_shape[1:]
|
||||
if all(spatial_dims):
|
||||
outputs.set_shape([batch_dim,
|
||||
functools.reduce(lambda x, y: x * y, spatial_dims)])
|
||||
else:
|
||||
outputs.set_shape([batch_dim, None])
|
||||
|
||||
outputs = core_layers.flatten(inputs)
|
||||
return utils.collect_named_outputs(outputs_collections, sc, outputs)
|
||||
|
||||
|
||||
|
@ -1399,7 +1399,7 @@ class FlattenTest(test.TestCase):
|
||||
inputs = array_ops.placeholder(dtype=dtypes.float32)
|
||||
inputs.set_shape(tensor_shape.TensorShape((5,)))
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'must have a least 2 dimensions'):
|
||||
'incompatible with the layer'):
|
||||
_layers.flatten(inputs)
|
||||
|
||||
def testUnknownLastDim(self):
|
||||
|
@ -456,7 +456,7 @@ class Permute(Layer):
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class Flatten(Layer):
|
||||
class Flatten(tf_core_layers.Flatten, Layer):
|
||||
"""Flattens the input. Does not affect the batch size.
|
||||
|
||||
Example:
|
||||
@ -472,26 +472,7 @@ class Flatten(Layer):
|
||||
# now: model.output_shape == (None, 65536)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(Flatten, self).__init__(**kwargs)
|
||||
self.input_spec = InputSpec(min_ndim=3)
|
||||
|
||||
def _compute_output_shape(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
if not all(input_shape[1:]):
|
||||
raise ValueError('The shape of the input to "Flatten" '
|
||||
'is not fully defined '
|
||||
'(got ' + str(input_shape[1:]) + '. '
|
||||
'Make sure to pass a complete "input_shape" '
|
||||
'or "batch_input_shape" argument to the first '
|
||||
'layer in your model.')
|
||||
return tensor_shape.TensorShape([input_shape[0], np.prod(input_shape[1:])])
|
||||
|
||||
def call(self, inputs):
|
||||
outputs = K.batch_flatten(inputs)
|
||||
outputs.set_shape(self._compute_output_shape(inputs.get_shape()))
|
||||
return outputs
|
||||
pass
|
||||
|
||||
|
||||
class RepeatVector(Layer):
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import standard_ops
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
@ -337,6 +338,67 @@ def dropout(inputs,
|
||||
return layer.apply(inputs, training=training)
|
||||
|
||||
|
||||
class Flatten(base.Layer):
|
||||
"""Flattens an input tensor while preserving the batch axis (axis 0).
|
||||
|
||||
Examples:
|
||||
|
||||
```
|
||||
x = tf.placeholder(shape=(None, 4, 4), dtype='float32')
|
||||
y = Flatten()(x)
|
||||
# now `y` has shape `(None, 16)`
|
||||
|
||||
x = tf.placeholder(shape=(None, 3, None), dtype='float32')
|
||||
y = Flatten()(x)
|
||||
# now `y` has shape `(None, None)`
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(Flatten, self).__init__(**kwargs)
|
||||
self.input_spec = base.InputSpec(min_ndim=2)
|
||||
|
||||
def call(self, inputs):
|
||||
outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1))
|
||||
outputs.set_shape(self._compute_output_shape(inputs.get_shape()))
|
||||
return outputs
|
||||
|
||||
def _compute_output_shape(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
output_shape = [input_shape[0]]
|
||||
if all(input_shape[1:]):
|
||||
output_shape += [np.prod(input_shape[1:])]
|
||||
else:
|
||||
output_shape += [None]
|
||||
return tensor_shape.TensorShape(output_shape)
|
||||
|
||||
|
||||
def flatten(inputs, name=None):
|
||||
"""Flattens an input tensor while preserving the batch axis (axis 0).
|
||||
|
||||
Arguments:
|
||||
inputs: Tensor input.
|
||||
name: The name of the layer (string).
|
||||
|
||||
Returns:
|
||||
Reshaped tensor.
|
||||
|
||||
Examples:
|
||||
|
||||
```
|
||||
x = tf.placeholder(shape=(None, 4, 4), dtype='float32')
|
||||
y = flatten(x)
|
||||
# now `y` has shape `(None, 16)`
|
||||
|
||||
x = tf.placeholder(shape=(None, 3, None), dtype='float32')
|
||||
y = flatten(x)
|
||||
# now `y` has shape `(None, None)`
|
||||
```
|
||||
"""
|
||||
layer = Flatten(name=name)
|
||||
return layer.apply(inputs)
|
||||
|
||||
|
||||
# Aliases
|
||||
|
||||
FullyConnected = Dense
|
||||
|
@ -391,5 +391,56 @@ class DropoutTest(test.TestCase):
|
||||
self.assertAllClose(np.ones((5, 5)), np_output)
|
||||
|
||||
|
||||
class FlattenTest(test.TestCase):
|
||||
|
||||
def testCreateFlatten(self):
|
||||
with self.test_session() as sess:
|
||||
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
|
||||
y = core_layers.Flatten()(x)
|
||||
np_output = sess.run(y, feed_dict={x: np.zeros((3, 2, 3))})
|
||||
self.assertEqual(list(np_output.shape), [3, 6])
|
||||
self.assertEqual(y.get_shape().as_list(), [None, 6])
|
||||
|
||||
x = array_ops.placeholder(shape=(1, 2, 3, 2), dtype='float32')
|
||||
y = core_layers.Flatten()(x)
|
||||
np_output = sess.run(y, feed_dict={x: np.zeros((1, 2, 3, 2))})
|
||||
self.assertEqual(list(np_output.shape), [1, 12])
|
||||
self.assertEqual(y.get_shape().as_list(), [1, 12])
|
||||
|
||||
def testComputeShape(self):
|
||||
shape = core_layers.Flatten()._compute_output_shape((1, 2, 3, 2))
|
||||
self.assertEqual(shape.as_list(), [1, 12])
|
||||
|
||||
shape = core_layers.Flatten()._compute_output_shape((None, 3, 2))
|
||||
self.assertEqual(shape.as_list(), [None, 6])
|
||||
|
||||
shape = core_layers.Flatten()._compute_output_shape((None, 3, None))
|
||||
self.assertEqual(shape.as_list(), [None, None])
|
||||
|
||||
def testFunctionalFlatten(self):
|
||||
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
|
||||
y = core_layers.flatten(x, name='flatten')
|
||||
self.assertEqual(y.get_shape().as_list(), [None, 6])
|
||||
|
||||
def testFlattenValueError(self):
|
||||
x = array_ops.placeholder(shape=(None,), dtype='float32')
|
||||
with self.assertRaises(ValueError):
|
||||
core_layers.Flatten()(x)
|
||||
|
||||
def testFlattenUnknownAxes(self):
|
||||
with self.test_session() as sess:
|
||||
x = array_ops.placeholder(shape=(5, None, None), dtype='float32')
|
||||
y = core_layers.Flatten()(x)
|
||||
np_output = sess.run(y, feed_dict={x: np.zeros((5, 2, 3))})
|
||||
self.assertEqual(list(np_output.shape), [5, 6])
|
||||
self.assertEqual(y.get_shape().as_list(), [5, None])
|
||||
|
||||
x = array_ops.placeholder(shape=(5, None, 2), dtype='float32')
|
||||
y = core_layers.Flatten()(x)
|
||||
np_output = sess.run(y, feed_dict={x: np.zeros((5, 3, 2))})
|
||||
self.assertEqual(list(np_output.shape), [5, 6])
|
||||
self.assertEqual(y.get_shape().as_list(), [5, None])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -18,6 +18,7 @@
|
||||
|
||||
@@Dense
|
||||
@@Dropout
|
||||
@@Flatten
|
||||
@@Conv1D
|
||||
@@Conv2D
|
||||
@@Conv3D
|
||||
@ -39,6 +40,7 @@
|
||||
|
||||
@@dense
|
||||
@@dropout
|
||||
@@flatten
|
||||
@@conv1d
|
||||
@@conv2d
|
||||
@@conv3d
|
||||
@ -71,9 +73,11 @@ from tensorflow.python.layers.base import InputSpec
|
||||
# Core layers.
|
||||
from tensorflow.python.layers.core import Dense
|
||||
from tensorflow.python.layers.core import Dropout
|
||||
from tensorflow.python.layers.core import Flatten
|
||||
|
||||
from tensorflow.python.layers.core import dense
|
||||
from tensorflow.python.layers.core import dropout
|
||||
from tensorflow.python.layers.core import flatten
|
||||
|
||||
# Convolutional layers.
|
||||
from tensorflow.python.layers.convolutional import SeparableConv2D
|
||||
|
@ -1,6 +1,7 @@
|
||||
path: "tensorflow.keras.layers.Flatten"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Flatten\'>"
|
||||
is_instance: "<class \'tensorflow.python.layers.core.Flatten\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
|
||||
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
|
118
tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt
Normal file
118
tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt
Normal file
@ -0,0 +1,118 @@
|
||||
path: "tensorflow.layers.Flatten"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.layers.core.Flatten\'>"
|
||||
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "graph"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "scope_name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "count_params"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_shape_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_shape_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -44,6 +44,10 @@ tf_module {
|
||||
name: "Dropout"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "Flatten"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "InputSpec"
|
||||
mtype: "<type \'type\'>"
|
||||
@ -120,6 +124,10 @@ tf_module {
|
||||
name: "dropout"
|
||||
argspec: "args=[\'inputs\', \'rate\', \'noise_shape\', \'seed\', \'training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'None\', \'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "flatten"
|
||||
argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "max_pooling1d"
|
||||
argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user