Add Flatten to core layers.

PiperOrigin-RevId: 168254118
This commit is contained in:
Francois Chollet 2017-09-11 11:00:26 -07:00 committed by TensorFlower Gardener
parent a6223c01a6
commit 80ed8afc02
10 changed files with 249 additions and 47 deletions

View File

@ -45,7 +45,7 @@ class ConditioningUtilsTest(test.TestCase):
array_ops.placeholder(dtypes.float32, (5, None)),
array_ops.placeholder(dtypes.float32, (5, 1)))
with self.assertRaisesRegexp(ValueError, 'must have a least 2 dimensions.'):
with self.assertRaisesRegexp(ValueError, 'expected min_ndim=2'):
conditioning_utils.condition_tensor(
array_ops.placeholder(dtypes.float32, (5, 2)),
array_ops.placeholder(dtypes.float32, (5)))

View File

@ -1435,30 +1435,7 @@ def flatten(inputs,
"""
with ops.name_scope(scope, 'Flatten', [inputs]) as sc:
inputs = ops.convert_to_tensor(inputs)
inputs_rank = inputs.get_shape().ndims
if (inputs_rank is None) or (inputs_rank < 2):
raise ValueError('Inputs must have a least 2 dimensions.')
inputs_shape = array_ops.shape(inputs)
batch_dim = array_ops.slice(inputs_shape, [0], [1])
spatial_dims = array_ops.slice(inputs_shape, [1], [inputs_rank - 1])
flat_spatial_dim = math_ops.reduce_prod(spatial_dims)
flat_spatial_dim = array_ops.expand_dims(flat_spatial_dim, 0)
flat_shape = array_ops.concat([batch_dim, flat_spatial_dim], 0)
outputs = array_ops.reshape(inputs, flat_shape)
# Attempt to propagate shape information, if it is defined.
input_shape = inputs.get_shape().as_list()
batch_dim, spatial_dims = input_shape[0], input_shape[1:]
if all(spatial_dims):
outputs.set_shape([batch_dim,
functools.reduce(lambda x, y: x * y, spatial_dims)])
else:
outputs.set_shape([batch_dim, None])
outputs = core_layers.flatten(inputs)
return utils.collect_named_outputs(outputs_collections, sc, outputs)

View File

@ -1399,7 +1399,7 @@ class FlattenTest(test.TestCase):
inputs = array_ops.placeholder(dtype=dtypes.float32)
inputs.set_shape(tensor_shape.TensorShape((5,)))
with self.assertRaisesRegexp(ValueError,
'must have a least 2 dimensions'):
'incompatible with the layer'):
_layers.flatten(inputs)
def testUnknownLastDim(self):

View File

@ -456,7 +456,7 @@ class Permute(Layer):
return dict(list(base_config.items()) + list(config.items()))
class Flatten(Layer):
class Flatten(tf_core_layers.Flatten, Layer):
"""Flattens the input. Does not affect the batch size.
Example:
@ -472,26 +472,7 @@ class Flatten(Layer):
# now: model.output_shape == (None, 65536)
```
"""
def __init__(self, **kwargs):
super(Flatten, self).__init__(**kwargs)
self.input_spec = InputSpec(min_ndim=3)
def _compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
if not all(input_shape[1:]):
raise ValueError('The shape of the input to "Flatten" '
'is not fully defined '
'(got ' + str(input_shape[1:]) + '. '
'Make sure to pass a complete "input_shape" '
'or "batch_input_shape" argument to the first '
'layer in your model.')
return tensor_shape.TensorShape([input_shape[0], np.prod(input_shape[1:])])
def call(self, inputs):
outputs = K.batch_flatten(inputs)
outputs.set_shape(self._compute_output_shape(inputs.get_shape()))
return outputs
pass
class RepeatVector(Layer):

View File

@ -31,6 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import variable_scope as vs
@ -337,6 +338,67 @@ def dropout(inputs,
return layer.apply(inputs, training=training)
class Flatten(base.Layer):
"""Flattens an input tensor while preserving the batch axis (axis 0).
Examples:
```
x = tf.placeholder(shape=(None, 4, 4), dtype='float32')
y = Flatten()(x)
# now `y` has shape `(None, 16)`
x = tf.placeholder(shape=(None, 3, None), dtype='float32')
y = Flatten()(x)
# now `y` has shape `(None, None)`
```
"""
def __init__(self, **kwargs):
super(Flatten, self).__init__(**kwargs)
self.input_spec = base.InputSpec(min_ndim=2)
def call(self, inputs):
outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1))
outputs.set_shape(self._compute_output_shape(inputs.get_shape()))
return outputs
def _compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
output_shape = [input_shape[0]]
if all(input_shape[1:]):
output_shape += [np.prod(input_shape[1:])]
else:
output_shape += [None]
return tensor_shape.TensorShape(output_shape)
def flatten(inputs, name=None):
"""Flattens an input tensor while preserving the batch axis (axis 0).
Arguments:
inputs: Tensor input.
name: The name of the layer (string).
Returns:
Reshaped tensor.
Examples:
```
x = tf.placeholder(shape=(None, 4, 4), dtype='float32')
y = flatten(x)
# now `y` has shape `(None, 16)`
x = tf.placeholder(shape=(None, 3, None), dtype='float32')
y = flatten(x)
# now `y` has shape `(None, None)`
```
"""
layer = Flatten(name=name)
return layer.apply(inputs)
# Aliases
FullyConnected = Dense

View File

@ -391,5 +391,56 @@ class DropoutTest(test.TestCase):
self.assertAllClose(np.ones((5, 5)), np_output)
class FlattenTest(test.TestCase):
def testCreateFlatten(self):
with self.test_session() as sess:
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
y = core_layers.Flatten()(x)
np_output = sess.run(y, feed_dict={x: np.zeros((3, 2, 3))})
self.assertEqual(list(np_output.shape), [3, 6])
self.assertEqual(y.get_shape().as_list(), [None, 6])
x = array_ops.placeholder(shape=(1, 2, 3, 2), dtype='float32')
y = core_layers.Flatten()(x)
np_output = sess.run(y, feed_dict={x: np.zeros((1, 2, 3, 2))})
self.assertEqual(list(np_output.shape), [1, 12])
self.assertEqual(y.get_shape().as_list(), [1, 12])
def testComputeShape(self):
shape = core_layers.Flatten()._compute_output_shape((1, 2, 3, 2))
self.assertEqual(shape.as_list(), [1, 12])
shape = core_layers.Flatten()._compute_output_shape((None, 3, 2))
self.assertEqual(shape.as_list(), [None, 6])
shape = core_layers.Flatten()._compute_output_shape((None, 3, None))
self.assertEqual(shape.as_list(), [None, None])
def testFunctionalFlatten(self):
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
y = core_layers.flatten(x, name='flatten')
self.assertEqual(y.get_shape().as_list(), [None, 6])
def testFlattenValueError(self):
x = array_ops.placeholder(shape=(None,), dtype='float32')
with self.assertRaises(ValueError):
core_layers.Flatten()(x)
def testFlattenUnknownAxes(self):
with self.test_session() as sess:
x = array_ops.placeholder(shape=(5, None, None), dtype='float32')
y = core_layers.Flatten()(x)
np_output = sess.run(y, feed_dict={x: np.zeros((5, 2, 3))})
self.assertEqual(list(np_output.shape), [5, 6])
self.assertEqual(y.get_shape().as_list(), [5, None])
x = array_ops.placeholder(shape=(5, None, 2), dtype='float32')
y = core_layers.Flatten()(x)
np_output = sess.run(y, feed_dict={x: np.zeros((5, 3, 2))})
self.assertEqual(list(np_output.shape), [5, 6])
self.assertEqual(y.get_shape().as_list(), [5, None])
if __name__ == '__main__':
test.main()

View File

@ -18,6 +18,7 @@
@@Dense
@@Dropout
@@Flatten
@@Conv1D
@@Conv2D
@@Conv3D
@ -39,6 +40,7 @@
@@dense
@@dropout
@@flatten
@@conv1d
@@conv2d
@@conv3d
@ -71,9 +73,11 @@ from tensorflow.python.layers.base import InputSpec
# Core layers.
from tensorflow.python.layers.core import Dense
from tensorflow.python.layers.core import Dropout
from tensorflow.python.layers.core import Flatten
from tensorflow.python.layers.core import dense
from tensorflow.python.layers.core import dropout
from tensorflow.python.layers.core import flatten
# Convolutional layers.
from tensorflow.python.layers.convolutional import SeparableConv2D

View File

@ -1,6 +1,7 @@
path: "tensorflow.keras.layers.Flatten"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Flatten\'>"
is_instance: "<class \'tensorflow.python.layers.core.Flatten\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"

View File

@ -0,0 +1,118 @@
path: "tensorflow.layers.Flatten"
tf_class {
is_instance: "<class \'tensorflow.python.layers.core.Flatten\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
name: "graph"
mtype: "<type \'property\'>"
}
member {
name: "input"
mtype: "<type \'property\'>"
}
member {
name: "input_shape"
mtype: "<type \'property\'>"
}
member {
name: "losses"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "output"
mtype: "<type \'property\'>"
}
member {
name: "output_shape"
mtype: "<type \'property\'>"
}
member {
name: "scope_name"
mtype: "<type \'property\'>"
}
member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "updates"
mtype: "<type \'property\'>"
}
member {
name: "variables"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "add_loss"
argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "add_update"
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "add_variable"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
}
member_method {
name: "apply"
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "build"
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_losses_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -44,6 +44,10 @@ tf_module {
name: "Dropout"
mtype: "<type \'type\'>"
}
member {
name: "Flatten"
mtype: "<type \'type\'>"
}
member {
name: "InputSpec"
mtype: "<type \'type\'>"
@ -120,6 +124,10 @@ tf_module {
name: "dropout"
argspec: "args=[\'inputs\', \'rate\', \'noise_shape\', \'seed\', \'training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "flatten"
argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "max_pooling1d"
argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], "