Backport fixes and improvements from external Keras.

Change: 152198296
This commit is contained in:
Francois Chollet 2017-04-04 15:51:13 -08:00 committed by TensorFlower Gardener
parent 8f74d595ef
commit 9477900946
22 changed files with 424 additions and 150 deletions

View File

@ -37,4 +37,4 @@ from tensorflow.contrib.keras.python.keras import utils
from tensorflow.contrib.keras.python.keras import wrappers
__version__ = '2.0.0-tf'
__version__ = '2.0.2-tf'

View File

@ -24,18 +24,28 @@ from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
def softmax(x):
def softmax(x, axis=-1):
"""Softmax activation function.
Arguments:
x : Tensor.
axis: Integer, axis along which the softmax normalization is applied.
Returns:
Tensor, output of softmax transformation.
Raises:
ValueError: In case `dim(x) == 1`.
"""
ndim = K.ndim(x)
if ndim == 2:
return K.softmax(x)
elif ndim == 3:
e = K.exp(x - K.max(x, axis=-1, keepdims=True))
s = K.sum(e, axis=-1, keepdims=True)
elif ndim > 2:
e = K.exp(x - K.max(x, axis=axis, keepdims=True))
s = K.sum(e, axis=axis, keepdims=True)
return e / s
else:
raise ValueError('Cannot apply softmax to a tensor '
'that is not 2D or 3D. '
'Here, ndim=' + str(ndim))
raise ValueError('Cannot apply softmax to a tensor that is 1D')
def elu(x, alpha=1.0):

View File

@ -163,8 +163,8 @@ def ResNet50(include_top=True,
specified in your Keras config file.
Arguments:
include_top: whether to include the 3 fully-connected
layers at the top of the network.
include_top: whether to include the fully-connected
layer at the top of the network.
weights: one of `None` (random initialization)
or "imagenet" (pre-training on ImageNet).
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)

View File

@ -22,7 +22,6 @@ from __future__ import division
from __future__ import print_function
from collections import defaultdict
import errno
import json
import os
import warnings
@ -270,6 +269,7 @@ def clear_session():
reset_uids()
_SESSION = None
phase = array_ops.placeholder(dtype='bool', name='keras_learning_phase')
_GRAPH_LEARNING_PHASES = {}
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = phase
@ -1257,6 +1257,34 @@ def prod(x, axis=None, keepdims=False):
return math_ops.reduce_prod(x, reduction_indices=axis, keep_dims=keepdims)
def cumsum(x, axis=0):
"""Cumulative sum of the values in a tensor, alongside the specified axis.
Arguments:
x: A tensor or variable.
axis: An integer, the axis to compute the sum.
Returns:
A tensor of the cumulative sum of values of `x` along `axis`.
"""
axis = _normalize_axis(axis, ndim(x))
return math_ops.cumsum(x, axis=axis)
def cumprod(x, axis=0):
"""Cumulative product of the values in a tensor, alongside the specified axis.
Arguments:
x: A tensor or variable.
axis: An integer, the axis to compute the product.
Returns:
A tensor of the cumulative product of values of `x` along `axis`.
"""
axis = _normalize_axis(axis, ndim(x))
return math_ops.cumprod(x, axis=axis)
def var(x, axis=None, keepdims=False):
"""Variance of a tensor, alongside the specified axis.
@ -1330,8 +1358,7 @@ def any(x, axis=None, keepdims=False):
"""
axis = _normalize_axis(axis, ndim(x))
x = math_ops.cast(x, dtypes_module.bool)
x = math_ops.reduce_any(x, reduction_indices=axis, keep_dims=keepdims)
return math_ops.cast(x, dtypes_module.uint8)
return math_ops.reduce_any(x, reduction_indices=axis, keep_dims=keepdims)
def all(x, axis=None, keepdims=False):
@ -1347,8 +1374,7 @@ def all(x, axis=None, keepdims=False):
"""
axis = _normalize_axis(axis, ndim(x))
x = math_ops.cast(x, dtypes_module.bool)
x = math_ops.reduce_all(x, reduction_indices=axis, keep_dims=keepdims)
return math_ops.cast(x, dtypes_module.uint8)
return math_ops.reduce_all(x, reduction_indices=axis, keep_dims=keepdims)
def argmax(x, axis=-1):
@ -1645,7 +1671,7 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
"""
mean, var = nn.moments(
x, reduction_axes, shift=None, name=None, keep_dims=False)
if sorted(reduction_axes) == range(ndim(x))[:-1]:
if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
else:
# need broadcasting
@ -2324,8 +2350,8 @@ def rnn(step_function,
(no time dimension),
containing the initial values for the states used in
the step function.
go_backwards: boolean. If True, do the iteration over
the time dimension in reverse order.
go_backwards: boolean. If True, do the iteration over the time
dimension in reverse order and return the reversed sequence.
mask: binary tensor with shape `(samples, time, 1)`,
with a zero for every element that is masked.
constants: a list of constant values passed at each step.
@ -2414,9 +2440,9 @@ def rnn(step_function,
states = return_states
successive_outputs.append(output)
successive_states.append(states)
last_output = successive_outputs[-1]
new_states = successive_states[-1]
outputs = array_ops.stack(successive_outputs)
last_output = successive_outputs[-1]
new_states = successive_states[-1]
outputs = array_ops.stack(successive_outputs)
else:
for inp in input_list:
output, states = step_function(inp, states + constants)
@ -3534,19 +3560,19 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
# HIGH ORDER FUNCTIONS
def map_fn(fn, elems, name=None):
def map_fn(fn, elems, name=None, dtype=None):
"""Map the function fn over the elements elems and return the outputs.
Arguments:
fn: Callable that will be called upon each element in elems
elems: tensor
name: A string name for the map node in the graph
dtype: Output data type.
Returns:
Tensor with first dimension equal to the elems and second depending on
fn
Tensor with dtype `dtype`.
"""
return functional_ops.map_fn(fn, elems, name=name)
return functional_ops.map_fn(fn, elems, name=name, dtype=dtype)
def foldl(fn, elems, initializer=None, name=None):
@ -3560,7 +3586,7 @@ def foldl(fn, elems, initializer=None, name=None):
name: A string name for the foldl node in the graph
Returns:
Same type and shape as initializer
Tensor with same type and shape as `initializer`.
"""
return functional_ops.foldl(fn, elems, initializer=initializer, name=name)
@ -3583,27 +3609,39 @@ def foldr(fn, elems, initializer=None, name=None):
# Load Keras default configuration from config file if present.
_keras_base_dir = os.path.expanduser('~')
if not os.access(_keras_base_dir, os.W_OK):
_keras_base_dir = '/tmp'
_keras_dir = os.path.join(_keras_base_dir, '.keras')
if not os.path.exists(_keras_dir):
try:
os.makedirs(_keras_dir)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
_config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
if os.path.exists(_config_path):
_config = json.load(open(_config_path))
try:
_config = json.load(open(_config_path))
except json.decoder.JSONDecodeError:
_config = {}
_floatx = _config.get('floatx', floatx())
assert _floatx in {'float16', 'float32', 'float64'}
_epsilon = _config.get('epsilon', epsilon())
assert isinstance(_epsilon, float)
_backend = backend()
_image_data_format = _config.get('image_data_format', image_data_format())
assert _image_data_format in {'channels_last', 'channels_first'}
set_floatx(_floatx)
set_epsilon(_epsilon)
set_image_data_format(_image_data_format)
# Save config file.
if os.access(_keras_base_dir, os.W_OK):
if not os.path.exists(_keras_dir):
try:
os.makedirs(_keras_dir)
except OSError:
# Except potential race conditions
# in multi-threaded environments.
pass
if not os.path.exists(_config_path):
_config = {
'floatx': floatx(),
'epsilon': epsilon(),
'backend': 'tensorflow',
'image_data_format': image_data_format()
}
with open(_config_path, 'w') as f:
f.write(json.dumps(_config, indent=4))

View File

@ -295,8 +295,14 @@ class Layer(object):
# are only applicable to input layers: do not pass these keywords
# to non-input layers.
allowed_kwargs = {
'input_shape', 'batch_input_shape', 'batch_size', 'dtype', 'name',
'trainable', 'weights'
'input_shape',
'batch_input_shape',
'batch_size',
'dtype',
'name',
'trainable',
'weights',
'input_dtype', # legacy
}
for kwarg in kwargs:
if kwarg not in allowed_kwargs:
@ -320,8 +326,15 @@ class Layer(object):
batch_size = None
batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
self.batch_input_shape = batch_input_shape
dtype = kwargs.get('dtype', K.floatx())
# Set dtype.
dtype = kwargs.get('dtype')
if dtype is None:
dtype = kwargs.get('input_dtype')
if dtype is None:
dtype = K.floatx()
self.dtype = dtype
if 'weights' in kwargs:
self._initial_weights = kwargs['weights']
else:
@ -485,11 +498,12 @@ class Layer(object):
': expected shape=' + str(spec.shape) +
', found shape=' + str(x_shape))
def call(self, inputs):
def call(self, inputs, **kwargs): # pylint: disable=unused-argument
"""This is where the layer's logic lives.
Arguments:
inputs: input tensor, or list/tuple of input tensors.
inputs: Input tensor, or list/tuple of input tensors.
**kwargs: Additional keyword arguments.
Returns:
A tensor or list/tuple of tensors.
@ -518,6 +532,8 @@ class Layer(object):
ValueError: in case the layer is missing shape information
for its `build` call.
"""
if isinstance(inputs, list):
inputs = inputs[:]
with K.name_scope(self.name):
# Handle laying building (weight creating, input spec locking).
if not self.built:
@ -1417,7 +1433,7 @@ class Container(Layer):
get_weights
set_weights
get_config
get_output_shape_for
compute_output_shape
# Class Methods
from_config
@ -2029,7 +2045,7 @@ class Container(Layer):
for i in range(len(input_shapes)):
layer = self.input_layers[i]
input_shape = input_shapes[i]
# It's an input layer: get_output_shape_for is identity,
# It's an input layer: compute_output_shape is identity,
# and there is only one node and one tensor output.
shape_key = layer.name + '_0_0'
layers_to_output_shapes[shape_key] = input_shape

View File

@ -733,11 +733,12 @@ class Model(Container):
loss_functions = []
for name in self.output_names:
if name not in loss:
warnings.warn('Output "' + name + '" missing from loss dictionary. '
'We assume this was done on purpose, '
'and we will not be expecting '
'any data to be passed to "' + name +
'" during training.')
warnings.warn(
'Output "' + name + '" missing from loss dictionary. '
'We assume this was done on purpose, '
'and we will not be expecting '
'any data to be passed to "' + name + '" during training.',
stacklevel=2)
loss_functions.append(losses.get(loss.get(name)))
elif isinstance(loss, list):
if len(loss) != len(self.outputs):
@ -1202,7 +1203,7 @@ class Model(Container):
if batch_index == 0:
for batch_out in batch_outs:
shape = (samples,) + batch_out.shape[1:]
outs.append(np.zeros(shape, dtype=K.floatx()))
outs.append(np.zeros(shape, dtype=batch_out.dtype))
for i, batch_out in enumerate(batch_outs):
outs[i][batch_start:batch_end] = batch_out
@ -1718,7 +1719,7 @@ class Model(Container):
- a tuple (inputs, targets, sample_weights).
All arrays should contain the same number of samples.
The generator is expected to loop over its data
indefinitely. An epoch finishes when `samples_per_epoch`
indefinitely. An epoch finishes when `steps_per_epoch`
samples have been seen by the model.
steps_per_epoch: Total number of steps (batches of samples)
to yield from `generator` before declaring one epoch
@ -1767,7 +1768,7 @@ class Model(Container):
f.close()
model.fit_generator(generate_arrays_from_file('/my_file.txt'),
samples_per_epoch=10000, epochs=10)
steps_per_epoch=10000, epochs=10)
```
Raises:
@ -2028,7 +2029,8 @@ class Model(Container):
steps,
max_q_size=10,
workers=1,
pickle_safe=False):
pickle_safe=False,
verbose=0):
"""Generates predictions for the input samples from a data generator.
The generator should return the same kind of data as accepted by
@ -2048,6 +2050,7 @@ class Model(Container):
non picklable arguments to the generator
as they can't be passed
easily to children processes.
verbose: verbosity mode, 0 or 1.
Returns:
Numpy array(s) of predictions.
@ -2067,6 +2070,9 @@ class Model(Container):
enqueuer = GeneratorEnqueuer(generator, pickle_safe=pickle_safe)
enqueuer.start(workers=workers, max_q_size=max_q_size)
if verbose == 1:
progbar = Progbar(target=steps)
while steps_done < steps:
generator_output = None
while enqueuer.is_running():
@ -2103,6 +2109,8 @@ class Model(Container):
for i, out in enumerate(outs):
all_outs[i].append(out)
steps_done += 1
if verbose == 1:
progbar.update(steps_done)
finally:
if enqueuer is not None:

View File

@ -45,14 +45,16 @@ class Initializer(object):
class Zeros(Initializer):
"""Initializer that generates tensors initialized to 0."""
"""Initializer that generates tensors initialized to 0.
"""
def __call__(self, shape, dtype=None):
return K.constant(0, shape=shape, dtype=dtype)
class Ones(Initializer):
"""Initializer that generates tensors initialized to 1."""
"""Initializer that generates tensors initialized to 1.
"""
def __call__(self, shape, dtype=None):
return K.constant(1, shape=shape, dtype=dtype)
@ -130,7 +132,7 @@ class RandomUniform(Initializer):
class TruncatedNormal(Initializer):
"""Initializer that generates a truncated normal distribution.
These values are similar to values from a `random_normal_initializer`
These values are similar to values from a `RandomNormal`
except that values more than two standard deviations from the mean
are discarded and re-drawn. This is the recommended initializer for
neural network weights and filters.
@ -161,6 +163,7 @@ class VarianceScaling(Initializer):
With `distribution="normal"`, samples are drawn from a truncated normal
distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
- number of input units in the weight tensor, if mode = "fan_in"
- number of output units, if mode = "fan_out"
- average of the numbers of input and output units, if mode = "fan_avg"

View File

@ -244,7 +244,7 @@ class _Conv(Layer):
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'bias_initializer':
initializers.serialize(self.kernel_initializer),
initializers.serialize(self.bias_initializer),
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'bias_regularizer':
@ -289,7 +289,7 @@ class Conv1D(_Conv):
any `dilation_rate` value != 1.
padding: One of `"valid"`, `"causal"` or `"same"` (case-insensitive).
`"causal"` results in causal (dilated) convolutions, e.g. output[t]
depends solely on input[:t-1]. Useful when modeling temporal data
does not depend on input[t+1:]. Useful when modeling temporal data
where the model should not violate the temporal order.
See [WaveNet: A Generative Model for Raw Audio, section
2.1](https://arxiv.org/abs/1609.03499).
@ -395,9 +395,9 @@ class Conv2D(_Conv):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, width, height, channels)` while `channels_first`
`(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
`(batch, channels, width, height)`.
`(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@ -621,7 +621,7 @@ class Conv2DTranspose(Conv2D):
Arguments:
filters: Integer, the dimensionality of the output space
(i.e. the number output of filters in the convolution).
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of 2 integers, specifying the
width and height of the 2D convolution window.
Can be a single integer to specify the same value for
@ -637,9 +637,9 @@ class Conv2DTranspose(Conv2D):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, width, height, channels)` while `channels_first`
`(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
`(batch, channels, width, height)`.
`(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@ -688,7 +688,7 @@ class Conv2DTranspose(Conv2D):
kernel_size,
strides=(1, 1),
padding='valid',
data_format='channels_last',
data_format=None,
activation=None,
use_bias=True,
kernel_initializer='glorot_uniform',
@ -845,9 +845,9 @@ class SeparableConv2D(Conv2D):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, width, height, channels)` while `channels_first`
`(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
`(batch, channels, width, height)`.
`(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@ -1079,9 +1079,9 @@ class UpSampling2D(Layer):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, width, height, channels)` while `channels_first`
`(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
`(batch, channels, width, height)`.
`(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@ -1257,7 +1257,7 @@ class ZeroPadding2D(Layer):
- If tuple of 2 ints:
interpreted as two different
symmetric padding values for height and width:
`(symmetric_height_pad, symmetrc_width_pad)`.
`(symmetric_height_pad, symmetric_width_pad)`.
- If tuple of 2 tuples of 2 ints:
interpreted as
`((top_pad, bottom_pad), (left_pad, right_pad))`
@ -1265,9 +1265,9 @@ class ZeroPadding2D(Layer):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, width, height, channels)` while `channels_first`
`(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
`(batch, channels, width, height)`.
`(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@ -1498,7 +1498,7 @@ class Cropping2D(Layer):
- If tuple of 2 ints:
interpreted as two different
symmetric cropping values for height and width:
`(symmetric_height_crop, symmetrc_width_crop)`.
`(symmetric_height_crop, symmetric_width_crop)`.
- If tuple of 2 tuples of 2 ints:
interpreted as
`((top_crop, bottom_crop), (left_crop, right_crop))`
@ -1506,9 +1506,9 @@ class Cropping2D(Layer):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, width, height, channels)` while `channels_first`
`(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
`(batch, channels, width, height)`.
`(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".

View File

@ -357,7 +357,7 @@ class ConvLSTM2D(ConvRecurrent2D):
self.states = [None, None]
if self.data_format == 'channels_first':
channel_axis = 1
channel_axis = 2
else:
channel_axis = -1
if input_shape[channel_axis] is None:

View File

@ -88,7 +88,7 @@ class Dropout(Layer):
"""Applies Dropout to the input.
Dropout consists in randomly setting
a fraction `p` of input units to 0 at each update during training time,
a fraction `rate` of input units to 0 at each update during training time,
which helps prevent overfitting.
Arguments:
@ -140,7 +140,7 @@ class SpatialDropout1D(Dropout):
between feature maps and should be used instead.
Arguments:
p: float between 0 and 1. Fraction of the input units to drop.
rate: float between 0 and 1. Fraction of the input units to drop.
Input shape:
3D tensor with shape:
@ -775,7 +775,7 @@ class Dense(Layer):
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'bias_initializer':
initializers.serialize(self.kernel_initializer),
initializers.serialize(self.bias_initializer),
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'bias_regularizer':

View File

@ -59,7 +59,8 @@ class LocallyConnected1D(Layer):
specifying the stride length of the convolution.
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
padding: One of `"valid"` or `"same"` (case-insensitive).
padding: Currently only supports `"valid"` (case-insensitive).
`"same"` may be supported in the future.
activation: Activation function to use.
If you don't specify anything, no activation is applied
(ie. "linear" activation: `a(x) = x`).
@ -188,7 +189,7 @@ class LocallyConnected1D(Layer):
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'bias_initializer':
initializers.serialize(self.kernel_initializer),
initializers.serialize(self.bias_initializer),
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'bias_regularizer':
@ -239,16 +240,15 @@ class LocallyConnected2D(Layer):
specifying the strides of the convolution along the width and height.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
padding: one of `"valid"` or `"same"` (case-insensitive).
padding: Currently only support `"valid"` (case-insensitive).
`"same"` will be supported in future.
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, width, height, channels)` while `channels_first`
`(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
`(batch, channels, width, height)`.
`(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@ -460,7 +460,7 @@ class LocallyConnected2D(Layer):
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'bias_initializer':
initializers.serialize(self.kernel_initializer),
initializers.serialize(self.bias_initializer),
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'bias_regularizer':

View File

@ -41,6 +41,44 @@ class _Merge(Layer):
def _merge_function(self, inputs):
raise NotImplementedError
def _compute_elemwise_op_output_shape(self, shape1, shape2):
"""Computes the shape of the resultant of an elementwise operation.
Arguments:
shape1: tuple or None. Shape of the first tensor
shape2: tuple or None. Shape of the second tensor
Returns:
expected output shape when an element-wise operation is
carried out on 2 tensors with shapes shape1 and shape2.
tuple or None.
Raises:
ValueError: if shape1 and shape2 are not compatible for
element-wise operations.
"""
if None in [shape1, shape2]:
return None
elif len(shape1) < len(shape2):
return self._compute_elemwise_op_output_shape(shape2, shape1)
elif not shape2:
return shape1
output_shape = list(shape1[:-len(shape2)])
for i, j in zip(shape1[-len(shape2):], shape2):
if i is None or j is None:
output_shape.append(None)
elif i == 1:
output_shape.append(j)
elif j == 1:
output_shape.append(i)
else:
if i != j:
raise ValueError('Operands could not be broadcast '
'together with shapes ' + str(shape1) + ' ' + str(
shape2))
output_shape.append(i)
return tuple(output_shape)
def build(self, input_shape):
# Used purely for shape validation.
if not isinstance(input_shape, list):
@ -49,23 +87,107 @@ class _Merge(Layer):
raise ValueError('A merge layer should be called '
'on a list of at least 2 inputs. '
'Got ' + str(len(input_shape)) + ' inputs.')
if all([shape is None for shape in input_shape]):
return
input_shapes = [
tuple(tensor_shape.TensorShape(shape).as_list())
for shape in input_shape
]
# TODO(fchollet): handle shapes with None entries.
input_shapes_set = set(input_shapes)
if None in input_shapes_set:
input_shapes_set.remove(None)
if len(input_shapes_set) > 1:
raise ValueError('Only tensors of same shape can '
'be merged by layer' + self.name +
' Got input shapes: %s' % input_shapes)
batch_sizes = [s[0] for s in input_shape if s is not None]
batch_sizes = set(batch_sizes)
batch_sizes -= set([None])
if len(batch_sizes) > 1:
raise ValueError('Can not merge tensors with different '
'batch sizes. Got tensors with shapes : ' + str(
input_shape))
if input_shape[0] is None:
output_shape = None
else:
output_shape = input_shape[0][1:]
for i in range(1, len(input_shape)):
if input_shape[i] is None:
shape = None
else:
shape = input_shape[i][1:]
output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
# If the inputs have different ranks, we have to reshape them
# to make them broadcastable.
if None not in input_shape and len(set(map(len, input_shape))) == 1:
self._reshape_required = False
else:
self._reshape_required = True
def call(self, inputs):
return self._merge_function(inputs)
if self._reshape_required:
reshaped_inputs = []
input_ndims = list(map(K.ndim, inputs))
if None not in input_ndims:
# If ranks of all inputs are available,
# we simply expand each of them at axis=1
# until all of them have the same rank.
max_ndim = max(input_ndims)
for x in inputs:
x_ndim = K.ndim(x)
for _ in range(max_ndim - x_ndim):
x = K.expand_dims(x, 1)
reshaped_inputs.append(x)
return self._merge_function(reshaped_inputs)
else:
# Transpose all inputs so that batch size is the last dimension.
# (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
transposed = False
for x in inputs:
x_ndim = K.ndim(x)
if x_ndim is None:
x_shape = K.shape(x)
batch_size = x_shape[0]
new_shape = K.concatenate([x_shape[1:], K.expand_dims(batch_size)])
x_transposed = K.reshape(x,
K.stack([batch_size, K.prod(x_shape[1:])]))
x_transposed = K.permute_dimensions(x_transposed, (1, 0))
x_transposed = K.reshape(x_transposed, new_shape)
reshaped_inputs.append(x_transposed)
transposed = True
elif x_ndim > 1:
dims = list(range(1, x_ndim)) + [0]
reshaped_inputs.append(K.permute_dimensions(x, dims))
transposed = True
else:
# We don't transpose inputs if they are 1D vectors or scalars.
reshaped_inputs.append(x)
y = self._merge_function(reshaped_inputs)
y_ndim = K.ndim(y)
if transposed:
# If inputs have been transposed, we have to transpose the output too.
if y_ndim is None:
y_shape = K.shape(y)
y_ndim = K.shape(y_shape)[0]
batch_size = y_shape[y_ndim - 1]
new_shape = K.concatenate(
[K.expand_dims(batch_size), y_shape[:y_ndim - 1]])
y = K.reshape(y, (-1, batch_size))
y = K.permute_dimensions(y, (1, 0))
y = K.reshape(y, new_shape)
elif y_ndim > 1:
dims = [y_ndim - 1] + list(range(y_ndim - 1))
y = K.permute_dimensions(y, dims)
return y
else:
return self._merge_function(inputs)
def compute_output_shape(self, input_shape):
if input_shape[0] is None:
output_shape = None
else:
output_shape = input_shape[0][1:]
for i in range(1, len(input_shape)):
if input_shape[i] is None:
shape = None
else:
shape = input_shape[i][1:]
output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
batch_sizes = [s[0] for s in input_shape if s is not None]
batch_sizes = set(batch_sizes)
batch_sizes -= set([None])
if len(batch_sizes) == 1:
output_shape = (list(batch_sizes)[0],) + output_shape
else:
output_shape = (None,) + output_shape
return output_shape
def compute_mask(self, inputs, mask=None):
if mask is None:
@ -219,8 +341,8 @@ class Concatenate(_Merge):
for input_i, mask_i in zip(inputs, mask):
if mask_i is None:
# Input is unmasked. Append all 1s to masks,
# but cast it to uint8 first
masks.append(K.cast(K.ones_like(input_i), 'uint8'))
# but cast it to bool first
masks.append(K.cast(K.ones_like(input_i), 'bool'))
elif K.ndim(mask_i) < K.ndim(input_i):
# Mask is smaller than the input, expand it
masks.append(K.expand_dims(mask_i))

View File

@ -154,7 +154,7 @@ class BatchNormalization(Layer):
broadcast_shape[self.axis] = input_shape[self.axis]
# Determines whether broadcasting is needed.
needs_broadcasting = (sorted(reduction_axes) != range(ndim)[:-1])
needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])
normed, mean, variance = K.normalize_batch_in_training(
inputs, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon)

View File

@ -199,9 +199,9 @@ class MaxPooling2D(_Pooling2D):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, width, height, channels)` while `channels_first`
`(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
`(batch, channels, width, height)`.
`(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@ -255,9 +255,9 @@ class AveragePooling2D(_Pooling2D):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, width, height, channels)` while `channels_first`
`(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
`(batch, channels, width, height)`.
`(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@ -542,9 +542,9 @@ class GlobalAveragePooling2D(_GlobalPooling2D):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, width, height, channels)` while `channels_first`
`(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
`(batch, channels, width, height)`.
`(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@ -577,9 +577,9 @@ class GlobalMaxPooling2D(_GlobalPooling2D):
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, width, height, channels)` while `channels_first`
`(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
`(batch, channels, width, height)`.
`(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".

View File

@ -105,8 +105,16 @@ class Recurrent(Layer):
# now model.output_shape == (None, 32)
# note: `None` is the batch dimension.
# for subsequent layers, not need to specify the input size:
# for subsequent layers, no need to specify the input size:
model.add(LSTM(16))
# to stack recurrent layers, you must use return_sequences=True
# on any recurrent layer that feeds into another recurrent layer.
# note that you only need to specify the input size on the first layer.
model = Sequential()
model.add(LSTM(64, input_dim=64, input_length=10, return_sequences=True))
model.add(LSTM(32, return_sequences=True))
model.add(LSTM(10))
```
Arguments:
@ -116,7 +124,8 @@ class Recurrent(Layer):
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
go_backwards: Boolean (default False).
If True, process the input sequence backwards.
If True, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default False). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
@ -398,6 +407,7 @@ class SimpleRNN(Recurrent):
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
If you don't specify anything, no activation is applied
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
@ -547,7 +557,7 @@ class SimpleRNN(Recurrent):
def get_constants(self, inputs, training=None):
constants = []
if self.implementation == 0 and 0 < self.dropout < 1:
if self.implementation != 0 and 0 < self.dropout < 1:
input_shape = K.int_shape(inputs)
input_dim = input_shape[-1]
ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
@ -619,7 +629,7 @@ class GRU(Recurrent):
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
If you don't specify anything, no activation is applied
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
for the recurrent step.
@ -792,7 +802,7 @@ class GRU(Recurrent):
def get_constants(self, inputs, training=None):
constants = []
if self.implementation == 0 and 0 < self.dropout < 1:
if self.implementation != 0 and 0 < self.dropout < 1:
input_shape = K.int_shape(inputs)
input_dim = input_shape[-1]
ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
@ -861,7 +871,7 @@ class GRU(Recurrent):
if self.use_bias:
x_z = K.bias_add(x_z, self.bias_z)
x_r = K.bias_add(x_r, self.bias_r)
x_h = K.bias_add(x_r, self.bias_h)
x_h = K.bias_add(x_h, self.bias_h)
else:
raise ValueError('Unknown `implementation` mode.')
z = self.recurrent_activation(x_z + K.dot(h_tm1 * rec_dp_mask[0],
@ -924,7 +934,7 @@ class LSTM(Recurrent):
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
If you don't specify anything, no activation is applied
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
for the recurrent step.
@ -1127,7 +1137,7 @@ class LSTM(Recurrent):
def get_constants(self, inputs, training=None):
constants = []
if self.implementation == 0 and 0 < self.dropout < 1:
if self.implementation != 0 and 0 < self.dropout < 1:
input_shape = K.int_shape(inputs)
input_dim = input_shape[-1]
ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Wrapper layers: layers that augment the functionality of another layer.
"""
from __future__ import absolute_import
@ -19,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import copy
import inspect
from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.engine import InputSpec
@ -70,9 +72,10 @@ class Wrapper(Layer):
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config):
def from_config(cls, config, custom_objects=None):
from tensorflow.contrib.keras.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
layer = deserialize_layer(config.pop('layer'))
layer = deserialize_layer(
config.pop('layer'), custom_objects=custom_objects)
return cls(layer, **config)
@ -188,12 +191,15 @@ class Bidirectional(Wrapper):
If None, the outputs will not be combined,
they will be returned as a list.
Raises:
ValueError: In case of invalid `merge_mode` argument.
Examples:
```python
model = Sequential()
model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5,
10)))
10)))
model.add(Bidirectional(LSTM(10)))
model.add(Dense(5))
model.add(Activation('softmax'))
@ -242,29 +248,47 @@ class Bidirectional(Wrapper):
shape = self.forward_layer._compute_output_shape(input_shape) # pylint: disable=protected-access
return [shape, copy.copy(shape)]
def call(self, inputs, mask=None):
y = self.forward_layer.call(inputs, mask)
y_rev = self.backward_layer.call(inputs, mask)
def call(self, inputs, training=None, mask=None):
kwargs = {}
func_args = inspect.getargspec(self.layer.call).args
if 'training' in func_args:
kwargs['training'] = training
if 'mask' in func_args:
kwargs['mask'] = mask
y = self.forward_layer.call(inputs, **kwargs)
y_rev = self.backward_layer.call(inputs, **kwargs)
if self.return_sequences:
y_rev = K.reverse(y_rev, 1)
if self.merge_mode == 'concat':
return K.concatenate([y, y_rev])
output = K.concatenate([y, y_rev])
elif self.merge_mode == 'sum':
return y + y_rev
output = y + y_rev
elif self.merge_mode == 'ave':
return (y + y_rev) / 2
output = (y + y_rev) / 2
elif self.merge_mode == 'mul':
return y * y_rev
output = y * y_rev
elif self.merge_mode is None:
return [y, y_rev]
output = [y, y_rev]
# Properly set learning phase
if 0 < self.layer.dropout + self.layer.recurrent_dropout:
if self.merge_mode is None:
for out in output:
out._uses_learning_phase = True
else:
output._uses_learning_phase = True
return output
def reset_states(self):
self.forward_layer.reset_states()
self.backward_layer.reset_states()
def build(self, input_shape):
self.forward_layer.build(input_shape)
self.backward_layer.build(input_shape)
with K.name_scope(self.forward_layer.name):
self.forward_layer.build(input_shape)
with K.name_scope(self.backward_layer.name):
self.backward_layer.build(input_shape)
self.built = True
def compute_mask(self, inputs, mask):

View File

@ -43,12 +43,15 @@ def binary_accuracy(y_true, y_pred):
def categorical_accuracy(y_true, y_pred):
return K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1))
return K.cast(
K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)), K.floatx())
def sparse_categorical_accuracy(y_true, y_pred):
return K.equal(
K.max(y_true, axis=-1), K.cast(K.argmax(y_pred, axis=-1), K.floatx()))
return K.cast(
K.equal(
K.max(y_true, axis=-1), K.cast(K.argmax(y_pred, axis=-1),
K.floatx())), K.floatx())
def top_k_categorical_accuracy(y_true, y_pred, k=5):

View File

@ -207,7 +207,7 @@ def load_model(filepath, custom_objects=None):
ValueError: In case of an invalid savefile.
"""
if h5py is None:
raise ImportError('`save_model` requires h5py.')
raise ImportError('`load_model` requires h5py.')
if not custom_objects:
custom_objects = {}
@ -1006,7 +1006,7 @@ class Sequential(Model):
steps_per_epoch: Total number of steps (batches of samples)
to yield from `generator` before declaring one epoch
finished and starting the next epoch. It should typically
be equal to the number of unique samples if your dataset
be equal to the number of unique samples of your dataset
divided by the batch size.
epochs: Integer, total number of iterations on the data.
verbose: Verbosity mode, 0, 1, or 2.
@ -1017,8 +1017,10 @@ class Sequential(Model):
- A tuple (inputs, targets, sample_weights).
validation_steps: Only relevant if `validation_data`
is a generator.
Number of samples to use from validation generator
at the end of every epoch.
Number of steps to yield from validation generator
at the end of every epoch. It should typically
be equal to the number of unique samples of your
validation dataset divided by the batch size.
class_weight: Dictionary mapping class indices to a weight
for the class.
max_q_size: Maximum size for the generator queue
@ -1050,7 +1052,7 @@ class Sequential(Model):
# and labels, from each line in the file
x, y = process_line(line)
yield (x, y)
f.close()
f.close()
model.fit_generator(generate_arrays_from_file('/my_file.txt'),
samples_per_epoch=10000, epochs=10)
@ -1119,7 +1121,8 @@ class Sequential(Model):
steps,
max_q_size=10,
workers=1,
pickle_safe=False):
pickle_safe=False,
verbose=0):
"""Generates predictions for the input samples from a data generator.
The generator should return the same kind of data as accepted by
@ -1136,6 +1139,7 @@ class Sequential(Model):
relies on multiprocessing, you should not pass
non picklable arguments to the generator
as they can't be passed easily to children processes.
verbose: verbosity mode, 0 or 1.
Returns:
A Numpy array of predictions.
@ -1147,7 +1151,8 @@ class Sequential(Model):
steps,
max_q_size=max_q_size,
workers=workers,
pickle_safe=pickle_safe)
pickle_safe=pickle_safe,
verbose=verbose)
def get_config(self):
config = []
@ -1159,9 +1164,9 @@ class Sequential(Model):
return copy.deepcopy(config)
@classmethod
def from_config(cls, config):
def from_config(cls, config, custom_objects=None):
model = cls()
for conf in config:
layer = layer_module.deserialize(conf)
layer = layer_module.deserialize(conf, custom_objects=custom_objects)
model.add(layer)
return model

View File

@ -785,7 +785,7 @@ class Iterator(object):
index_array = np.random.permutation(n)
current_index = (self.batch_index * batch_size) % n
if n >= current_index + batch_size:
if n > current_index + batch_size:
current_batch_size = batch_size
self.batch_index += 1
else:

View File

@ -172,7 +172,8 @@ def deserialize_keras_object(identifier,
else:
fn = module_objects.get(function_name)
if fn is None:
raise ValueError('Unknown ' + printable_module_name, ':' + class_name)
raise ValueError('Unknown ' + printable_module_name,
':' + function_name)
return fn
else:
raise ValueError('Could not interpret serialized ' + printable_module_name +
@ -215,6 +216,8 @@ def func_load(code, defaults=None, closure=None, globs=None):
"""
if isinstance(code, (tuple, list)): # unpack previous dump
code, defaults, closure = code
if isinstance(defaults, list):
defaults = tuple(defaults)
code = marshal.loads(code.encode('raw_unicode_escape'))
if globs is None:
globs = globals()

View File

@ -171,7 +171,7 @@ def count_total_params(layers, layer_set=None):
[K.count_params(p) for p in layer.trainable_weights])
non_trainable_count += np.sum(
[K.count_params(p) for p in layer.non_trainable_weights])
return trainable_count, non_trainable_count
return int(trainable_count), int(non_trainable_count)
def convert_all_kernels_in_model(model):

View File

@ -194,6 +194,36 @@ class KerasClassifier(BaseWrapper):
"""Implementation of the scikit-learn classifier API for Keras.
"""
def fit(self, x, y, **kwargs):
"""Constructs a new model with `build_fn` & fit the model to `(x, y)`.
Arguments:
x : array-like, shape `(n_samples, n_features)`
Training samples where n_samples in the number of samples
and n_features is the number of features.
y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
True labels for X.
**kwargs: dictionary arguments
Legal arguments are the arguments of `Sequential.fit`
Returns:
history : object
details about the training history at each epoch.
Raises:
ValueError: In case of invalid shape for `y` argument.
"""
y = np.array(y)
if len(y.shape) == 2 and y.shape[1] > 1:
self.classes_ = np.arange(y.shape[1])
elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1:
self.classes_ = np.unique(y)
y = np.searchsorted(self.classes_, y)
else:
raise ValueError('Invalid shape for y: ' + str(y.shape))
self.n_classes_ = len(self.classes_)
return super(KerasClassifier, self).fit(x, y, **kwargs)
def predict(self, x, **kwargs):
"""Returns the class predictions for the given test data.
@ -210,7 +240,8 @@ class KerasClassifier(BaseWrapper):
Class predictions.
"""
kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs)
return self.model.predict_classes(x, **kwargs)
classes = self.model.predict_classes(x, **kwargs)
return self.classes_[classes]
def predict_proba(self, x, **kwargs):
"""Returns class probability estimates for the given test data.
@ -261,6 +292,7 @@ class KerasClassifier(BaseWrapper):
compute accuracy. You should pass `metrics=["accuracy"]` to
the `.compile()` method of the model.
"""
y = np.searchsorted(self.classes_, y)
kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)
loss_name = self.model.loss