Backport fixes and improvements from external Keras.
Change: 152198296
parent 8f74d595ef
commit 9477900946
@@ -37,4 +37,4 @@ from tensorflow.contrib.keras.python.keras import utils
 from tensorflow.contrib.keras.python.keras import wrappers
 
 
-__version__ = '2.0.0-tf'
+__version__ = '2.0.2-tf'
@@ -24,18 +24,28 @@ from tensorflow.contrib.keras.python.keras import backend as K
 from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
 
 
-def softmax(x):
+def softmax(x, axis=-1):
+  """Softmax activation function.
+
+  Arguments:
+      x : Tensor.
+      axis: Integer, axis along which the softmax normalization is applied.
+
+  Returns:
+      Tensor, output of softmax transformation.
+
+  Raises:
+      ValueError: In case `dim(x) == 1`.
+  """
   ndim = K.ndim(x)
   if ndim == 2:
     return K.softmax(x)
-  elif ndim == 3:
-    e = K.exp(x - K.max(x, axis=-1, keepdims=True))
-    s = K.sum(e, axis=-1, keepdims=True)
+  elif ndim > 2:
+    e = K.exp(x - K.max(x, axis=axis, keepdims=True))
+    s = K.sum(e, axis=axis, keepdims=True)
     return e / s
   else:
-    raise ValueError('Cannot apply softmax to a tensor '
-                     'that is not 2D or 3D. '
-                     'Here, ndim=' + str(ndim))
+    raise ValueError('Cannot apply softmax to a tensor that is 1D')
 
 
 def elu(x, alpha=1.0):
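The backported `softmax` now normalizes over a configurable `axis` and accepts any rank above 1. A minimal numpy sketch of the same max-shift / exponentiate / normalize computation (illustrative standalone code, not part of the commit):

```
import numpy as np

def softmax(x, axis=-1):
  # Same numerically stable formulation as the K.exp / K.sum code above:
  # subtract the per-axis max, exponentiate, then normalize.
  e = np.exp(x - np.max(x, axis=axis, keepdims=True))
  return e / np.sum(e, axis=axis, keepdims=True)

x = np.random.rand(2, 3, 4)
print(softmax(x, axis=1).sum(axis=1))  # ~1.0 everywhere: normalized over axis 1
print(softmax(x).sum(axis=-1))         # ~1.0 everywhere: default last axis
```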
@@ -163,8 +163,8 @@ def ResNet50(include_top=True,
       specified in your Keras config file.
 
   Arguments:
-      include_top: whether to include the 3 fully-connected
-          layers at the top of the network.
+      include_top: whether to include the fully-connected
+          layer at the top of the network.
       weights: one of `None` (random initialization)
           or "imagenet" (pre-training on ImageNet).
       input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
@@ -22,7 +22,6 @@ from __future__ import division
 from __future__ import print_function
 
 from collections import defaultdict
-import errno
 import json
 import os
 import warnings
@@ -270,6 +269,7 @@ def clear_session():
   reset_uids()
   _SESSION = None
   phase = array_ops.placeholder(dtype='bool', name='keras_learning_phase')
+  _GRAPH_LEARNING_PHASES = {}
   _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = phase
 
 
@@ -1257,6 +1257,34 @@ def prod(x, axis=None, keepdims=False):
   return math_ops.reduce_prod(x, reduction_indices=axis, keep_dims=keepdims)
 
 
+def cumsum(x, axis=0):
+  """Cumulative sum of the values in a tensor, alongside the specified axis.
+
+  Arguments:
+      x: A tensor or variable.
+      axis: An integer, the axis to compute the sum.
+
+  Returns:
+      A tensor of the cumulative sum of values of `x` along `axis`.
+  """
+  axis = _normalize_axis(axis, ndim(x))
+  return math_ops.cumsum(x, axis=axis)
+
+
+def cumprod(x, axis=0):
+  """Cumulative product of the values in a tensor, alongside the specified axis.
+
+  Arguments:
+      x: A tensor or variable.
+      axis: An integer, the axis to compute the product.
+
+  Returns:
+      A tensor of the cumulative product of values of `x` along `axis`.
+  """
+  axis = _normalize_axis(axis, ndim(x))
+  return math_ops.cumprod(x, axis=axis)
+
+
 def var(x, axis=None, keepdims=False):
   """Variance of a tensor, alongside the specified axis.
 
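The two new backend ops wrap `math_ops.cumsum` / `math_ops.cumprod` after axis normalization. What they compute, shown with the numpy equivalents (illustrative only):

```
import numpy as np

x = np.array([[1., 2., 3.],
              [4., 5., 6.]])
print(np.cumsum(x, axis=0))   # [[1, 2, 3], [5, 7, 9]]   running sum down columns
print(np.cumprod(x, axis=1))  # [[1, 2, 6], [4, 20, 120]] running product along rows
```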
@@ -1330,8 +1358,7 @@ def any(x, axis=None, keepdims=False):
   """
   axis = _normalize_axis(axis, ndim(x))
   x = math_ops.cast(x, dtypes_module.bool)
-  x = math_ops.reduce_any(x, reduction_indices=axis, keep_dims=keepdims)
-  return math_ops.cast(x, dtypes_module.uint8)
+  return math_ops.reduce_any(x, reduction_indices=axis, keep_dims=keepdims)
 
 
 def all(x, axis=None, keepdims=False):
@@ -1347,8 +1374,7 @@ def all(x, axis=None, keepdims=False):
   """
   axis = _normalize_axis(axis, ndim(x))
   x = math_ops.cast(x, dtypes_module.bool)
-  x = math_ops.reduce_all(x, reduction_indices=axis, keep_dims=keepdims)
-  return math_ops.cast(x, dtypes_module.uint8)
+  return math_ops.reduce_all(x, reduction_indices=axis, keep_dims=keepdims)
 
 
 def argmax(x, axis=-1):
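Behavior change worth noting: `any` and `all` previously cast the boolean reduction back to uint8; they now return the bool tensor directly. A numpy illustration of the before/after dtypes (sketch, not from the commit):

```
import numpy as np

x = np.array([[0, 1, 0],
              [0, 0, 0]])
old_style = np.any(x, axis=1).astype(np.uint8)  # array([1, 0], dtype=uint8)
new_style = np.any(x, axis=1)                   # array([ True, False])
print(old_style.dtype, new_style.dtype)         # uint8 bool
```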
@@ -1645,7 +1671,7 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
   """
   mean, var = nn.moments(
       x, reduction_axes, shift=None, name=None, keep_dims=False)
-  if sorted(reduction_axes) == range(ndim(x))[:-1]:
+  if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
     normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
   else:
     # need broadcasting
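The `list(range(...))` wrapper is a Python 3 compatibility fix: `range()` no longer returns a list, so comparing it against a sorted list is always False. A small self-contained demonstration:

```
reduction_axes = [0, 1, 2]
ndim = 4
# Python 2: range() returned a list, so this comparison could be True.
print(sorted(reduction_axes) == range(ndim)[:-1])        # False on Python 3
# The fix materializes the range first:
print(sorted(reduction_axes) == list(range(ndim))[:-1])  # True
```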
@@ -2324,8 +2350,8 @@ def rnn(step_function,
           (no time dimension),
           containing the initial values for the states used in
           the step function.
-      go_backwards: boolean. If True, do the iteration over
-          the time dimension in reverse order.
+      go_backwards: boolean. If True, do the iteration over the time
+          dimension in reverse order and return the reversed sequence.
       mask: binary tensor with shape `(samples, time, 1)`,
           with a zero for every element that is masked.
       constants: a list of constant values passed at each step.
@@ -2414,9 +2440,9 @@ def rnn(step_function,
           states = return_states
-          successive_outputs.append(output)
-          successive_states.append(states)
+        successive_outputs.append(output)
+        successive_states.append(states)
       last_output = successive_outputs[-1]
       new_states = successive_states[-1]
       outputs = array_ops.stack(successive_outputs)
     else:
       for inp in input_list:
         output, states = step_function(inp, states + constants)
@@ -3534,19 +3560,19 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
 # HIGH ORDER FUNCTIONS
 
 
-def map_fn(fn, elems, name=None):
+def map_fn(fn, elems, name=None, dtype=None):
   """Map the function fn over the elements elems and return the outputs.
 
   Arguments:
       fn: Callable that will be called upon each element in elems
      elems: tensor
       name: A string name for the map node in the graph
+      dtype: Output data type.
 
   Returns:
-      Tensor with first dimension equal to the elems and second depending on
-      fn
+      Tensor with dtype `dtype`.
   """
-  return functional_ops.map_fn(fn, elems, name=name)
+  return functional_ops.map_fn(fn, elems, name=name, dtype=dtype)
 
 
 def foldl(fn, elems, initializer=None, name=None):
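The new `dtype` argument lets `fn` return a different dtype than `elems`; without it, the underlying `tf.map_fn` assumes the output dtype matches the input and fails otherwise. A hedged sketch (newer TF versions spell this argument `fn_output_signature`):

```
import tensorflow as tf

x = tf.constant([1.3, 2.7, 3.1])
# fn changes float32 -> int32, so the output dtype must be declared:
y = tf.map_fn(lambda t: tf.cast(tf.round(t), tf.int32), x, dtype=tf.int32)
```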
@@ -3560,7 +3586,7 @@ def foldl(fn, elems, initializer=None, name=None):
       name: A string name for the foldl node in the graph
 
   Returns:
-      Same type and shape as initializer
+      Tensor with same type and shape as `initializer`.
   """
   return functional_ops.foldl(fn, elems, initializer=initializer, name=name)
 
@@ -3583,27 +3609,39 @@ def foldr(fn, elems, initializer=None, name=None):
 
 # Load Keras default configuration from config file if present.
 _keras_base_dir = os.path.expanduser('~')
-if not os.access(_keras_base_dir, os.W_OK):
-  _keras_base_dir = '/tmp'
 _keras_dir = os.path.join(_keras_base_dir, '.keras')
-if not os.path.exists(_keras_dir):
-  try:
-    os.makedirs(_keras_dir)
-  except OSError as e:
-    if e.errno == errno.EEXIST:
-      pass
-    else:
-      raise
 _config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
 if os.path.exists(_config_path):
-  _config = json.load(open(_config_path))
+  try:
+    _config = json.load(open(_config_path))
+  except json.decoder.JSONDecodeError:
+    _config = {}
   _floatx = _config.get('floatx', floatx())
   assert _floatx in {'float16', 'float32', 'float64'}
   _epsilon = _config.get('epsilon', epsilon())
   assert isinstance(_epsilon, float)
-  _backend = backend()
   _image_data_format = _config.get('image_data_format', image_data_format())
   assert _image_data_format in {'channels_last', 'channels_first'}
   set_floatx(_floatx)
   set_epsilon(_epsilon)
   set_image_data_format(_image_data_format)
+
+# Save config file.
+if os.access(_keras_base_dir, os.W_OK):
+  if not os.path.exists(_keras_dir):
+    try:
+      os.makedirs(_keras_dir)
+    except OSError:
+      # Except potential race conditions
+      # in multi-threaded environments.
+      pass
+
+  if not os.path.exists(_config_path):
+    _config = {
+        'floatx': floatx(),
+        'epsilon': epsilon(),
+        'backend': 'tensorflow',
+        'image_data_format': image_data_format()
+    }
+    with open(_config_path, 'w') as f:
+      f.write(json.dumps(_config, indent=4))
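The save path above writes defaults only when no config exists yet; the old version created `~/.keras` unconditionally at import time and silently fell back to `/tmp`. What the persisted file would contain with stock defaults (illustrative values):

```
import json

print(json.dumps({
    'floatx': 'float32',
    'epsilon': 1e-07,
    'backend': 'tensorflow',
    'image_data_format': 'channels_last',
}, indent=4))
```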
@@ -295,8 +295,14 @@ class Layer(object):
     # are only applicable to input layers: do not pass these keywords
     # to non-input layers.
     allowed_kwargs = {
-        'input_shape', 'batch_input_shape', 'batch_size', 'dtype', 'name',
-        'trainable', 'weights'
+        'input_shape',
+        'batch_input_shape',
+        'batch_size',
+        'dtype',
+        'name',
+        'trainable',
+        'weights',
+        'input_dtype',  # legacy
     }
     for kwarg in kwargs:
       if kwarg not in allowed_kwargs:
@@ -320,8 +326,15 @@ class Layer(object):
           batch_size = None
         batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
       self.batch_input_shape = batch_input_shape
-      dtype = kwargs.get('dtype', K.floatx())
+
+      # Set dtype.
+      dtype = kwargs.get('dtype')
+      if dtype is None:
+        dtype = kwargs.get('input_dtype')
+      if dtype is None:
+        dtype = K.floatx()
       self.dtype = dtype
+
     if 'weights' in kwargs:
       self._initial_weights = kwargs['weights']
     else:
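The replacement spells out a precedence chain instead of a single `kwargs.get` default: an explicit `dtype` wins, then the legacy `input_dtype` spelling, then the backend float type. A standalone sketch of that resolution order:

```
def resolve_dtype(kwargs, floatx='float32'):
  # Mirrors the backported logic: dtype > input_dtype (legacy) > floatx.
  dtype = kwargs.get('dtype')
  if dtype is None:
    dtype = kwargs.get('input_dtype')
  if dtype is None:
    dtype = floatx
  return dtype

assert resolve_dtype({'dtype': 'float64'}) == 'float64'
assert resolve_dtype({'input_dtype': 'float16'}) == 'float16'
assert resolve_dtype({}) == 'float32'
```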
@@ -485,11 +498,12 @@ class Layer(object):
                            ': expected shape=' + str(spec.shape) +
                            ', found shape=' + str(x_shape))
 
-  def call(self, inputs):
+  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
     """This is where the layer's logic lives.
 
     Arguments:
-        inputs: input tensor, or list/tuple of input tensors.
+        inputs: Input tensor, or list/tuple of input tensors.
+        **kwargs: Additional keyword arguments.
 
     Returns:
         A tensor or list/tuple of tensors.
@@ -518,6 +532,8 @@ class Layer(object):
         ValueError: in case the layer is missing shape information
             for its `build` call.
     """
+    if isinstance(inputs, list):
+      inputs = inputs[:]
     with K.name_scope(self.name):
       # Handle laying building (weight creating, input spec locking).
       if not self.built:
@@ -1417,7 +1433,7 @@ class Container(Layer):
       get_weights
       set_weights
       get_config
-      get_output_shape_for
+      compute_output_shape
 
   # Class Methods
       from_config
@@ -2029,7 +2045,7 @@ class Container(Layer):
       for i in range(len(input_shapes)):
         layer = self.input_layers[i]
         input_shape = input_shapes[i]
-        # It's an input layer: get_output_shape_for is identity,
+        # It's an input layer: compute_output_shape is identity,
         # and there is only one node and one tensor output.
         shape_key = layer.name + '_0_0'
         layers_to_output_shapes[shape_key] = input_shape
@@ -733,11 +733,12 @@ class Model(Container):
       loss_functions = []
       for name in self.output_names:
         if name not in loss:
-          warnings.warn('Output "' + name + '" missing from loss dictionary. '
-                        'We assume this was done on purpose, '
-                        'and we will not be expecting '
-                        'any data to be passed to "' + name +
-                        '" during training.')
+          warnings.warn(
+              'Output "' + name + '" missing from loss dictionary. '
+              'We assume this was done on purpose, '
+              'and we will not be expecting '
+              'any data to be passed to "' + name + '" during training.',
+              stacklevel=2)
         loss_functions.append(losses.get(loss.get(name)))
     elif isinstance(loss, list):
       if len(loss) != len(self.outputs):
@@ -1202,7 +1203,7 @@ class Model(Container):
       if batch_index == 0:
         for batch_out in batch_outs:
           shape = (samples,) + batch_out.shape[1:]
-          outs.append(np.zeros(shape, dtype=K.floatx()))
+          outs.append(np.zeros(shape, dtype=batch_out.dtype))
 
       for i, batch_out in enumerate(batch_outs):
         outs[i][batch_start:batch_end] = batch_out
@@ -1718,7 +1719,7 @@ class Model(Container):
             - a tuple (inputs, targets, sample_weights).
             All arrays should contain the same number of samples.
             The generator is expected to loop over its data
-            indefinitely. An epoch finishes when `samples_per_epoch`
+            indefinitely. An epoch finishes when `steps_per_epoch`
             samples have been seen by the model.
         steps_per_epoch: Total number of steps (batches of samples)
             to yield from `generator` before declaring one epoch
@@ -1767,7 +1768,7 @@ class Model(Container):
                     f.close()
 
         model.fit_generator(generate_arrays_from_file('/my_file.txt'),
-                            samples_per_epoch=10000, epochs=10)
+                            steps_per_epoch=10000, epochs=10)
         ```
 
     Raises:
@@ -2028,7 +2029,8 @@ class Model(Container):
                         steps,
                         max_q_size=10,
                         workers=1,
-                        pickle_safe=False):
+                        pickle_safe=False,
+                        verbose=0):
     """Generates predictions for the input samples from a data generator.
 
     The generator should return the same kind of data as accepted by
@@ -2048,6 +2050,7 @@ class Model(Container):
             non picklable arguments to the generator
             as they can't be passed
             easily to children processes.
+        verbose: verbosity mode, 0 or 1.
 
     Returns:
         Numpy array(s) of predictions.
@@ -2067,6 +2070,9 @@ class Model(Container):
       enqueuer = GeneratorEnqueuer(generator, pickle_safe=pickle_safe)
       enqueuer.start(workers=workers, max_q_size=max_q_size)
 
+      if verbose == 1:
+        progbar = Progbar(target=steps)
+
       while steps_done < steps:
         generator_output = None
         while enqueuer.is_running():
@@ -2103,6 +2109,8 @@ class Model(Container):
         for i, out in enumerate(outs):
           all_outs[i].append(out)
         steps_done += 1
+        if verbose == 1:
+          progbar.update(steps_done)
 
     finally:
       if enqueuer is not None:
@@ -45,14 +45,16 @@ class Initializer(object):
 
 
 class Zeros(Initializer):
-  """Initializer that generates tensors initialized to 0."""
+  """Initializer that generates tensors initialized to 0.
+  """
 
   def __call__(self, shape, dtype=None):
     return K.constant(0, shape=shape, dtype=dtype)
 
 
 class Ones(Initializer):
-  """Initializer that generates tensors initialized to 1."""
+  """Initializer that generates tensors initialized to 1.
+  """
 
   def __call__(self, shape, dtype=None):
     return K.constant(1, shape=shape, dtype=dtype)
@@ -130,7 +132,7 @@ class RandomUniform(Initializer):
 class TruncatedNormal(Initializer):
   """Initializer that generates a truncated normal distribution.
 
-  These values are similar to values from a `random_normal_initializer`
+  These values are similar to values from a `RandomNormal`
   except that values more than two standard deviations from the mean
   are discarded and re-drawn. This is the recommended initializer for
   neural network weights and filters.
@@ -161,6 +163,7 @@ class VarianceScaling(Initializer):
 
   With `distribution="normal"`, samples are drawn from a truncated normal
   distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
+
   - number of input units in the weight tensor, if mode = "fan_in"
   - number of output units, if mode = "fan_out"
   - average of the numbers of input and output units, if mode = "fan_avg"
@@ -244,7 +244,7 @@ class _Conv(Layer):
         'kernel_initializer':
             initializers.serialize(self.kernel_initializer),
         'bias_initializer':
-            initializers.serialize(self.kernel_initializer),
+            initializers.serialize(self.bias_initializer),
         'kernel_regularizer':
             regularizers.serialize(self.kernel_regularizer),
         'bias_regularizer':
@@ -289,7 +289,7 @@ class Conv1D(_Conv):
           any `dilation_rate` value != 1.
       padding: One of `"valid"`, `"causal"` or `"same"` (case-insensitive).
           `"causal"` results in causal (dilated) convolutions, e.g. output[t]
-          depends solely on input[:t-1]. Useful when modeling temporal data
+          does not depend on input[t+1:]. Useful when modeling temporal data
           where the model should not violate the temporal order.
           See [WaveNet: A Generative Model for Raw Audio, section
             2.1](https://arxiv.org/abs/1609.03499).
@@ -395,9 +395,9 @@ class Conv2D(_Conv):
           one of `channels_last` (default) or `channels_first`.
           The ordering of the dimensions in the inputs.
           `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
           corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
           It defaults to the `image_data_format` value found in your
           Keras config file at `~/.keras/keras.json`.
           If you never set it, then it will be "channels_last".
@@ -621,7 +621,7 @@ class Conv2DTranspose(Conv2D):
 
   Arguments:
       filters: Integer, the dimensionality of the output space
-          (i.e. the number output of filters in the convolution).
+          (i.e. the number of output filters in the convolution).
       kernel_size: An integer or tuple/list of 2 integers, specifying the
           width and height of the 2D convolution window.
          Can be a single integer to specify the same value for
@@ -637,9 +637,9 @@ class Conv2DTranspose(Conv2D):
           one of `channels_last` (default) or `channels_first`.
           The ordering of the dimensions in the inputs.
           `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
           corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
           It defaults to the `image_data_format` value found in your
           Keras config file at `~/.keras/keras.json`.
           If you never set it, then it will be "channels_last".
@@ -688,7 +688,7 @@ class Conv2DTranspose(Conv2D):
                kernel_size,
                strides=(1, 1),
                padding='valid',
-               data_format='channels_last',
+               data_format=None,
                activation=None,
                use_bias=True,
                kernel_initializer='glorot_uniform',
@@ -845,9 +845,9 @@ class SeparableConv2D(Conv2D):
           one of `channels_last` (default) or `channels_first`.
           The ordering of the dimensions in the inputs.
           `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
           corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
           It defaults to the `image_data_format` value found in your
           Keras config file at `~/.keras/keras.json`.
           If you never set it, then it will be "channels_last".
@@ -1079,9 +1079,9 @@ class UpSampling2D(Layer):
           one of `channels_last` (default) or `channels_first`.
           The ordering of the dimensions in the inputs.
           `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
           corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
           It defaults to the `image_data_format` value found in your
           Keras config file at `~/.keras/keras.json`.
           If you never set it, then it will be "channels_last".
@@ -1257,7 +1257,7 @@ class ZeroPadding2D(Layer):
           - If tuple of 2 ints:
               interpreted as two different
               symmetric padding values for height and width:
-              `(symmetric_height_pad, symmetrc_width_pad)`.
+              `(symmetric_height_pad, symmetric_width_pad)`.
           - If tuple of 2 tuples of 2 ints:
               interpreted as
               `((top_pad, bottom_pad), (left_pad, right_pad))`
@@ -1265,9 +1265,9 @@ class ZeroPadding2D(Layer):
           one of `channels_last` (default) or `channels_first`.
           The ordering of the dimensions in the inputs.
           `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
           corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
           It defaults to the `image_data_format` value found in your
           Keras config file at `~/.keras/keras.json`.
           If you never set it, then it will be "channels_last".
@@ -1498,7 +1498,7 @@ class Cropping2D(Layer):
           - If tuple of 2 ints:
               interpreted as two different
               symmetric cropping values for height and width:
-              `(symmetric_height_crop, symmetrc_width_crop)`.
+              `(symmetric_height_crop, symmetric_width_crop)`.
           - If tuple of 2 tuples of 2 ints:
               interpreted as
               `((top_crop, bottom_crop), (left_crop, right_crop))`
@@ -1506,9 +1506,9 @@ class Cropping2D(Layer):
           one of `channels_last` (default) or `channels_first`.
           The ordering of the dimensions in the inputs.
           `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
           corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
           It defaults to the `image_data_format` value found in your
           Keras config file at `~/.keras/keras.json`.
           If you never set it, then it will be "channels_last".
@@ -357,7 +357,7 @@ class ConvLSTM2D(ConvRecurrent2D):
     self.states = [None, None]
 
     if self.data_format == 'channels_first':
-      channel_axis = 1
+      channel_axis = 2
     else:
       channel_axis = -1
     if input_shape[channel_axis] is None:
@@ -88,7 +88,7 @@ class Dropout(Layer):
   """Applies Dropout to the input.
 
   Dropout consists in randomly setting
-  a fraction `p` of input units to 0 at each update during training time,
+  a fraction `rate` of input units to 0 at each update during training time,
   which helps prevent overfitting.
 
   Arguments:
@@ -140,7 +140,7 @@ class SpatialDropout1D(Dropout):
       between feature maps and should be used instead.
 
   Arguments:
-      p: float between 0 and 1. Fraction of the input units to drop.
+      rate: float between 0 and 1. Fraction of the input units to drop.
 
   Input shape:
       3D tensor with shape:
@@ -775,7 +775,7 @@ class Dense(Layer):
         'kernel_initializer':
             initializers.serialize(self.kernel_initializer),
         'bias_initializer':
-            initializers.serialize(self.kernel_initializer),
+            initializers.serialize(self.bias_initializer),
         'kernel_regularizer':
             regularizers.serialize(self.kernel_regularizer),
         'bias_regularizer':
@@ -59,7 +59,8 @@ class LocallyConnected1D(Layer):
           specifying the stride length of the convolution.
           Specifying any stride value != 1 is incompatible with specifying
           any `dilation_rate` value != 1.
-      padding: One of `"valid"` or `"same"` (case-insensitive).
+      padding: Currently only supports `"valid"` (case-insensitive).
+          `"same"` may be supported in the future.
       activation: Activation function to use.
           If you don't specify anything, no activation is applied
           (ie. "linear" activation: `a(x) = x`).
@@ -188,7 +189,7 @@ class LocallyConnected1D(Layer):
         'kernel_initializer':
             initializers.serialize(self.kernel_initializer),
         'bias_initializer':
-            initializers.serialize(self.kernel_initializer),
+            initializers.serialize(self.bias_initializer),
         'kernel_regularizer':
             regularizers.serialize(self.kernel_regularizer),
         'bias_regularizer':
@@ -239,16 +240,15 @@ class LocallyConnected2D(Layer):
           specifying the strides of the convolution along the width and height.
           Can be a single integer to specify the same value for
           all spatial dimensions.
-          Specifying any stride value != 1 is incompatible with specifying
-          any `dilation_rate` value != 1.
-      padding: one of `"valid"` or `"same"` (case-insensitive).
+      padding: Currently only support `"valid"` (case-insensitive).
+          `"same"` will be supported in future.
       data_format: A string,
           one of `channels_last` (default) or `channels_first`.
           The ordering of the dimensions in the inputs.
           `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
          corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
           It defaults to the `image_data_format` value found in your
           Keras config file at `~/.keras/keras.json`.
           If you never set it, then it will be "channels_last".
@@ -460,7 +460,7 @@ class LocallyConnected2D(Layer):
         'kernel_initializer':
             initializers.serialize(self.kernel_initializer),
         'bias_initializer':
-            initializers.serialize(self.kernel_initializer),
+            initializers.serialize(self.bias_initializer),
         'kernel_regularizer':
             regularizers.serialize(self.kernel_regularizer),
         'bias_regularizer':
@@ -41,6 +41,44 @@ class _Merge(Layer):
   def _merge_function(self, inputs):
     raise NotImplementedError
 
+  def _compute_elemwise_op_output_shape(self, shape1, shape2):
+    """Computes the shape of the resultant of an elementwise operation.
+
+    Arguments:
+        shape1: tuple or None. Shape of the first tensor
+        shape2: tuple or None. Shape of the second tensor
+
+    Returns:
+        expected output shape when an element-wise operation is
+        carried out on 2 tensors with shapes shape1 and shape2.
+        tuple or None.
+
+    Raises:
+        ValueError: if shape1 and shape2 are not compatible for
+            element-wise operations.
+    """
+    if None in [shape1, shape2]:
+      return None
+    elif len(shape1) < len(shape2):
+      return self._compute_elemwise_op_output_shape(shape2, shape1)
+    elif not shape2:
+      return shape1
+    output_shape = list(shape1[:-len(shape2)])
+    for i, j in zip(shape1[-len(shape2):], shape2):
+      if i is None or j is None:
+        output_shape.append(None)
+      elif i == 1:
+        output_shape.append(j)
+      elif j == 1:
+        output_shape.append(i)
+      else:
+        if i != j:
+          raise ValueError('Operands could not be broadcast '
+                           'together with shapes ' + str(shape1) + ' ' + str(
+                               shape2))
+        output_shape.append(i)
+    return tuple(output_shape)
+
   def build(self, input_shape):
     # Used purely for shape validation.
     if not isinstance(input_shape, list):
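The helper implements numpy-style broadcasting over the non-batch dimensions, with `None` standing for unknown sizes. A standalone copy you can run to see the rule (dropping the `self` plumbing):

```
def elemwise_output_shape(shape1, shape2):
  if None in [shape1, shape2]:
    return None
  if len(shape1) < len(shape2):
    return elemwise_output_shape(shape2, shape1)
  if not shape2:
    return shape1
  output_shape = list(shape1[:-len(shape2)])
  for i, j in zip(shape1[-len(shape2):], shape2):
    if i is None or j is None:
      output_shape.append(None)   # unknown dims stay unknown
    elif i == 1:
      output_shape.append(j)      # size-1 dims broadcast
    elif j == 1:
      output_shape.append(i)
    elif i != j:
      raise ValueError('Operands could not be broadcast together')
    else:
      output_shape.append(i)
  return tuple(output_shape)

assert elemwise_output_shape((5, 4, 3), (4, 3)) == (5, 4, 3)
assert elemwise_output_shape((5, 1, 3), (5, 4, 1)) == (5, 4, 3)
assert elemwise_output_shape((5, None, 3), (5, 4, 3)) == (5, None, 3)
```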
@@ -49,23 +87,107 @@ class _Merge(Layer):
       raise ValueError('A merge layer should be called '
                        'on a list of at least 2 inputs. '
                        'Got ' + str(len(input_shape)) + ' inputs.')
-    if all([shape is None for shape in input_shape]):
-      return
-    input_shapes = [
-        tuple(tensor_shape.TensorShape(shape).as_list())
-        for shape in input_shape
-    ]
-    # TODO(fchollet): handle shapes with None entries.
-    input_shapes_set = set(input_shapes)
-    if None in input_shapes_set:
-      input_shapes_set.remove(None)
-    if len(input_shapes_set) > 1:
-      raise ValueError('Only tensors of same shape can '
-                       'be merged by layer' + self.name +
-                       ' Got input shapes: %s' % input_shapes)
+    batch_sizes = [s[0] for s in input_shape if s is not None]
+    batch_sizes = set(batch_sizes)
+    batch_sizes -= set([None])
+    if len(batch_sizes) > 1:
+      raise ValueError('Can not merge tensors with different '
+                       'batch sizes. Got tensors with shapes : ' + str(
+                           input_shape))
+    if input_shape[0] is None:
+      output_shape = None
+    else:
+      output_shape = input_shape[0][1:]
+    for i in range(1, len(input_shape)):
+      if input_shape[i] is None:
+        shape = None
+      else:
+        shape = input_shape[i][1:]
+      output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
+    # If the inputs have different ranks, we have to reshape them
+    # to make them broadcastable.
+    if None not in input_shape and len(set(map(len, input_shape))) == 1:
+      self._reshape_required = False
+    else:
+      self._reshape_required = True
 
   def call(self, inputs):
-    return self._merge_function(inputs)
+    if self._reshape_required:
+      reshaped_inputs = []
+      input_ndims = list(map(K.ndim, inputs))
+      if None not in input_ndims:
+        # If ranks of all inputs are available,
+        # we simply expand each of them at axis=1
+        # until all of them have the same rank.
+        max_ndim = max(input_ndims)
+        for x in inputs:
+          x_ndim = K.ndim(x)
+          for _ in range(max_ndim - x_ndim):
+            x = K.expand_dims(x, 1)
+          reshaped_inputs.append(x)
+        return self._merge_function(reshaped_inputs)
+      else:
+        # Transpose all inputs so that batch size is the last dimension.
+        # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
+        transposed = False
+        for x in inputs:
+          x_ndim = K.ndim(x)
+          if x_ndim is None:
+            x_shape = K.shape(x)
+            batch_size = x_shape[0]
+            new_shape = K.concatenate([x_shape[1:], K.expand_dims(batch_size)])
+            x_transposed = K.reshape(x,
+                                     K.stack([batch_size, K.prod(x_shape[1:])]))
+            x_transposed = K.permute_dimensions(x_transposed, (1, 0))
+            x_transposed = K.reshape(x_transposed, new_shape)
+            reshaped_inputs.append(x_transposed)
+            transposed = True
+          elif x_ndim > 1:
+            dims = list(range(1, x_ndim)) + [0]
+            reshaped_inputs.append(K.permute_dimensions(x, dims))
+            transposed = True
+          else:
+            # We don't transpose inputs if they are 1D vectors or scalars.
+            reshaped_inputs.append(x)
+        y = self._merge_function(reshaped_inputs)
+        y_ndim = K.ndim(y)
+        if transposed:
+          # If inputs have been transposed, we have to transpose the output too.
+          if y_ndim is None:
+            y_shape = K.shape(y)
+            y_ndim = K.shape(y_shape)[0]
+            batch_size = y_shape[y_ndim - 1]
+            new_shape = K.concatenate(
+                [K.expand_dims(batch_size), y_shape[:y_ndim - 1]])
+            y = K.reshape(y, (-1, batch_size))
+            y = K.permute_dimensions(y, (1, 0))
+            y = K.reshape(y, new_shape)
+          elif y_ndim > 1:
+            dims = [y_ndim - 1] + list(range(y_ndim - 1))
+            y = K.permute_dimensions(y, dims)
+        return y
+    else:
+      return self._merge_function(inputs)
+
+  def compute_output_shape(self, input_shape):
+    if input_shape[0] is None:
+      output_shape = None
+    else:
+      output_shape = input_shape[0][1:]
+    for i in range(1, len(input_shape)):
+      if input_shape[i] is None:
+        shape = None
+      else:
+        shape = input_shape[i][1:]
+      output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
+    batch_sizes = [s[0] for s in input_shape if s is not None]
+    batch_sizes = set(batch_sizes)
+    batch_sizes -= set([None])
+    if len(batch_sizes) == 1:
+      output_shape = (list(batch_sizes)[0],) + output_shape
+    else:
+      output_shape = (None,) + output_shape
+    return output_shape
 
   def compute_mask(self, inputs, mask=None):
     if mask is None:
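When static ranks are known, the new `call` simply expands lower-rank inputs at axis 1 until ranks match, then merges. A numpy sketch of that alignment step (illustrative, not the Keras code path):

```
import numpy as np

def align_ranks(inputs):
  # Expand each input at axis 1 until all inputs share the max rank.
  max_ndim = max(x.ndim for x in inputs)
  out = []
  for x in inputs:
    while x.ndim < max_ndim:
      x = np.expand_dims(x, 1)
    out.append(x)
  return out

a = np.ones((2, 3, 4))            # rank 3
b = np.ones((2, 4))               # rank 2 -> (2, 1, 4)
a2, b2 = align_ranks([a, b])
print((a2 + b2).shape)            # (2, 3, 4): broadcastable after alignment
```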
@@ -219,8 +341,8 @@ class Concatenate(_Merge):
     for input_i, mask_i in zip(inputs, mask):
       if mask_i is None:
         # Input is unmasked. Append all 1s to masks,
-        # but cast it to uint8 first
-        masks.append(K.cast(K.ones_like(input_i), 'uint8'))
+        # but cast it to bool first
+        masks.append(K.cast(K.ones_like(input_i), 'bool'))
       elif K.ndim(mask_i) < K.ndim(input_i):
         # Mask is smaller than the input, expand it
         masks.append(K.expand_dims(mask_i))
@@ -154,7 +154,7 @@ class BatchNormalization(Layer):
     broadcast_shape[self.axis] = input_shape[self.axis]
 
     # Determines whether broadcasting is needed.
-    needs_broadcasting = (sorted(reduction_axes) != range(ndim)[:-1])
+    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])
 
     normed, mean, variance = K.normalize_batch_in_training(
         inputs, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon)
@@ -199,9 +199,9 @@ class MaxPooling2D(_Pooling2D):
           one of `channels_last` (default) or `channels_first`.
           The ordering of the dimensions in the inputs.
           `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
           corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
           It defaults to the `image_data_format` value found in your
           Keras config file at `~/.keras/keras.json`.
           If you never set it, then it will be "channels_last".
@@ -255,9 +255,9 @@ class AveragePooling2D(_Pooling2D):
           one of `channels_last` (default) or `channels_first`.
           The ordering of the dimensions in the inputs.
           `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
           corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
           It defaults to the `image_data_format` value found in your
           Keras config file at `~/.keras/keras.json`.
           If you never set it, then it will be "channels_last".
@@ -542,9 +542,9 @@ class GlobalAveragePooling2D(_GlobalPooling2D):
           one of `channels_last` (default) or `channels_first`.
           The ordering of the dimensions in the inputs.
           `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
           corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
           It defaults to the `image_data_format` value found in your
           Keras config file at `~/.keras/keras.json`.
           If you never set it, then it will be "channels_last".
@@ -577,9 +577,9 @@ class GlobalMaxPooling2D(_GlobalPooling2D):
           one of `channels_last` (default) or `channels_first`.
           The ordering of the dimensions in the inputs.
           `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
           corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
           It defaults to the `image_data_format` value found in your
           Keras config file at `~/.keras/keras.json`.
           If you never set it, then it will be "channels_last".
|
@ -105,8 +105,16 @@ class Recurrent(Layer):
|
|||||||
# now model.output_shape == (None, 32)
|
# now model.output_shape == (None, 32)
|
||||||
# note: `None` is the batch dimension.
|
# note: `None` is the batch dimension.
|
||||||
|
|
||||||
# for subsequent layers, not need to specify the input size:
|
# for subsequent layers, no need to specify the input size:
|
||||||
model.add(LSTM(16))
|
model.add(LSTM(16))
|
||||||
|
|
||||||
|
# to stack recurrent layers, you must use return_sequences=True
|
||||||
|
# on any recurrent layer that feeds into another recurrent layer.
|
||||||
|
# note that you only need to specify the input size on the first layer.
|
||||||
|
model = Sequential()
|
||||||
|
model.add(LSTM(64, input_dim=64, input_length=10, return_sequences=True))
|
||||||
|
model.add(LSTM(32, return_sequences=True))
|
||||||
|
model.add(LSTM(10))
|
||||||
```
|
```
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
@@ -116,7 +124,8 @@ class Recurrent(Layer):
       return_sequences: Boolean. Whether to return the last output
           in the output sequence, or the full sequence.
       go_backwards: Boolean (default False).
-          If True, process the input sequence backwards.
+          If True, process the input sequence backwards and return the
+          reversed sequence.
       stateful: Boolean (default False). If True, the last state
           for each sample at index i in a batch will be used as initial
           state for the sample of index i in the following batch.
@@ -398,6 +407,7 @@ class SimpleRNN(Recurrent):
       units: Positive integer, dimensionality of the output space.
       activation: Activation function to use.
           If you don't specify anything, no activation is applied
+          If you pass None, no activation is applied
           (ie. "linear" activation: `a(x) = x`).
       use_bias: Boolean, whether the layer uses a bias vector.
       kernel_initializer: Initializer for the `kernel` weights matrix,
@@ -547,7 +557,7 @@ class SimpleRNN(Recurrent):
 
   def get_constants(self, inputs, training=None):
     constants = []
-    if self.implementation == 0 and 0 < self.dropout < 1:
+    if self.implementation != 0 and 0 < self.dropout < 1:
       input_shape = K.int_shape(inputs)
       input_dim = input_shape[-1]
       ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
@@ -619,7 +629,7 @@ class GRU(Recurrent):
   Arguments:
       units: Positive integer, dimensionality of the output space.
       activation: Activation function to use.
-          If you don't specify anything, no activation is applied
+          If you pass None, no activation is applied
           (ie. "linear" activation: `a(x) = x`).
       recurrent_activation: Activation function to use
           for the recurrent step.
@@ -792,7 +802,7 @@ class GRU(Recurrent):
 
   def get_constants(self, inputs, training=None):
     constants = []
-    if self.implementation == 0 and 0 < self.dropout < 1:
+    if self.implementation != 0 and 0 < self.dropout < 1:
       input_shape = K.int_shape(inputs)
       input_dim = input_shape[-1]
       ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
@@ -861,7 +871,7 @@ class GRU(Recurrent):
       if self.use_bias:
         x_z = K.bias_add(x_z, self.bias_z)
         x_r = K.bias_add(x_r, self.bias_r)
-        x_h = K.bias_add(x_r, self.bias_h)
+        x_h = K.bias_add(x_h, self.bias_h)
     else:
       raise ValueError('Unknown `implementation` mode.')
     z = self.recurrent_activation(x_z + K.dot(h_tm1 * rec_dp_mask[0],
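The one-character fix above repairs a real bug: the candidate-state projection `x_h` was being overwritten with the reset-gate projection `x_r` plus `bias_h`, corrupting the GRU candidate computation. For orientation, a hedged sketch of the per-gate input projections this branch computes (the per-gate kernels, biases, and `dp_mask` are defined earlier in the class):

```python
x_z = K.dot(inputs * dp_mask[0], self.kernel_z)  # update-gate input
x_r = K.dot(inputs * dp_mask[1], self.kernel_r)  # reset-gate input
x_h = K.dot(inputs * dp_mask[2], self.kernel_h)  # candidate-state input
if self.use_bias:
  x_z = K.bias_add(x_z, self.bias_z)
  x_r = K.bias_add(x_r, self.bias_r)
  x_h = K.bias_add(x_h, self.bias_h)  # fixed: was K.bias_add(x_r, self.bias_h)
```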
@@ -924,7 +934,7 @@ class LSTM(Recurrent):
   Arguments:
       units: Positive integer, dimensionality of the output space.
       activation: Activation function to use.
-          If you don't specify anything, no activation is applied
+          If you pass None, no activation is applied
           (ie. "linear" activation: `a(x) = x`).
       recurrent_activation: Activation function to use
           for the recurrent step.
@@ -1127,7 +1137,7 @@ class LSTM(Recurrent):

   def get_constants(self, inputs, training=None):
     constants = []
-    if self.implementation == 0 and 0 < self.dropout < 1:
+    if self.implementation != 0 and 0 < self.dropout < 1:
       input_shape = K.int_shape(inputs)
       input_dim = input_shape[-1]
       ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+# pylint: disable=protected-access
 """Wrapper layers: layers that augment the functionality of another layer.
 """
 from __future__ import absolute_import
@@ -19,6 +20,7 @@ from __future__ import division
 from __future__ import print_function

 import copy
+import inspect

 from tensorflow.contrib.keras.python.keras import backend as K
 from tensorflow.contrib.keras.python.keras.engine import InputSpec
@@ -70,9 +72,10 @@ class Wrapper(Layer):
     return dict(list(base_config.items()) + list(config.items()))

   @classmethod
-  def from_config(cls, config):
+  def from_config(cls, config, custom_objects=None):
     from tensorflow.contrib.keras.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
-    layer = deserialize_layer(config.pop('layer'))
+    layer = deserialize_layer(
+        config.pop('layer'), custom_objects=custom_objects)
     return cls(layer, **config)
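Accepting `custom_objects` here lets a wrapper rebuild an inner layer whose class is user-defined rather than built-in. A hedged usage sketch of the round trip this enables (`MyCell` is a hypothetical custom layer; `TimeDistributed` is one of the `Wrapper` subclasses this serves):

```python
from tensorflow.contrib.keras.python.keras.layers import TimeDistributed

config = wrapped.get_config()  # `wrapped` is a TimeDistributed(MyCell(...))
clone = TimeDistributed.from_config(
    config, custom_objects={'MyCell': MyCell})
```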
@@ -188,12 +191,15 @@ class Bidirectional(Wrapper):
           If None, the outputs will not be combined,
           they will be returned as a list.

+  Raises:
+      ValueError: In case of invalid `merge_mode` argument.
+
   Examples:

   ```python
   model = Sequential()
   model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5,
   10)))
   model.add(Bidirectional(LSTM(10)))
   model.add(Dense(5))
   model.add(Activation('softmax'))
@@ -242,29 +248,47 @@ class Bidirectional(Wrapper):
     shape = self.forward_layer._compute_output_shape(input_shape)  # pylint: disable=protected-access
     return [shape, copy.copy(shape)]

-  def call(self, inputs, mask=None):
-    y = self.forward_layer.call(inputs, mask)
-    y_rev = self.backward_layer.call(inputs, mask)
+  def call(self, inputs, training=None, mask=None):
+    kwargs = {}
+    func_args = inspect.getargspec(self.layer.call).args
+    if 'training' in func_args:
+      kwargs['training'] = training
+    if 'mask' in func_args:
+      kwargs['mask'] = mask
+
+    y = self.forward_layer.call(inputs, **kwargs)
+    y_rev = self.backward_layer.call(inputs, **kwargs)
     if self.return_sequences:
       y_rev = K.reverse(y_rev, 1)
     if self.merge_mode == 'concat':
-      return K.concatenate([y, y_rev])
+      output = K.concatenate([y, y_rev])
     elif self.merge_mode == 'sum':
-      return y + y_rev
+      output = y + y_rev
     elif self.merge_mode == 'ave':
-      return (y + y_rev) / 2
+      output = (y + y_rev) / 2
     elif self.merge_mode == 'mul':
-      return y * y_rev
+      output = y * y_rev
     elif self.merge_mode is None:
-      return [y, y_rev]
+      output = [y, y_rev]
+
+    # Properly set learning phase
+    if 0 < self.layer.dropout + self.layer.recurrent_dropout:
+      if self.merge_mode is None:
+        for out in output:
+          out._uses_learning_phase = True
+      else:
+        output._uses_learning_phase = True
+    return output

   def reset_states(self):
     self.forward_layer.reset_states()
     self.backward_layer.reset_states()

   def build(self, input_shape):
-    self.forward_layer.build(input_shape)
-    self.backward_layer.build(input_shape)
+    with K.name_scope(self.forward_layer.name):
+      self.forward_layer.build(input_shape)
+    with K.name_scope(self.backward_layer.name):
+      self.backward_layer.build(input_shape)
     self.built = True

   def compute_mask(self, inputs, mask):
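The rewritten `call` forwards `training` and `mask` only when the wrapped layer's own `call` declares them, so layers with older signatures keep working. The introspection pattern in isolation, as a self-contained sketch (the helper name is hypothetical; `getargspec` is the Python 2-era API this file imports, superseded by `getfullargspec` on modern Python):

```python
import inspect

def forwardable_kwargs(fn, **candidates):
  # Keep only the keyword arguments that `fn` declares by name.
  declared = inspect.getargspec(fn).args
  return dict((k, v) for k, v in candidates.items() if k in declared)

def call(inputs, training=None):
  return inputs, training

print(forwardable_kwargs(call, training=True, mask=None))
# {'training': True} -- 'mask' is dropped because `call` does not accept it.
```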
@@ -43,12 +43,15 @@ def binary_accuracy(y_true, y_pred):


 def categorical_accuracy(y_true, y_pred):
-  return K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1))
+  return K.cast(
+      K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)), K.floatx())


 def sparse_categorical_accuracy(y_true, y_pred):
-  return K.equal(
-      K.max(y_true, axis=-1), K.cast(K.argmax(y_pred, axis=-1), K.floatx()))
+  return K.cast(
+      K.equal(
+          K.max(y_true, axis=-1), K.cast(K.argmax(y_pred, axis=-1),
+                                         K.floatx())), K.floatx())


 def top_k_categorical_accuracy(y_true, y_pred, k=5):
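Both fixes wrap the boolean comparison in `K.cast(..., K.floatx())` because metric values are averaged downstream, and a mean over a boolean tensor either fails or misbehaves depending on the backend. The same arithmetic in plain numpy, with made-up values:

```python
import numpy as np

y_true = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
y_pred = np.array([[0.1, 0.8, 0.1], [0.2, 0.7, 0.1], [0.6, 0.3, 0.1]])
matches = np.equal(y_true.argmax(-1), y_pred.argmax(-1))  # [True False False]
accuracy = matches.astype('float32').mean()               # 0.333... after cast
```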
@@ -207,7 +207,7 @@ def load_model(filepath, custom_objects=None):
     ValueError: In case of an invalid savefile.
   """
   if h5py is None:
-    raise ImportError('`save_model` requires h5py.')
+    raise ImportError('`load_model` requires h5py.')

   if not custom_objects:
     custom_objects = {}
@@ -1006,7 +1006,7 @@ class Sequential(Model):
       steps_per_epoch: Total number of steps (batches of samples)
           to yield from `generator` before declaring one epoch
           finished and starting the next epoch. It should typically
-          be equal to the number of unique samples if your dataset
+          be equal to the number of unique samples of your dataset
           divided by the batch size.
       epochs: Integer, total number of iterations on the data.
       verbose: Verbosity mode, 0, 1, or 2.
@@ -1017,8 +1017,10 @@ class Sequential(Model):
           - A tuple (inputs, targets, sample_weights).
       validation_steps: Only relevant if `validation_data`
           is a generator.
-          Number of samples to use from validation generator
-          at the end of every epoch.
+          Number of steps to yield from validation generator
+          at the end of every epoch. It should typically
+          be equal to the number of unique samples of your
+          validation dataset divided by the batch size.
       class_weight: Dictionary mapping class indices to a weight
           for the class.
       max_q_size: Maximum size for the generator queue
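Both corrected docstrings state the same rule of thumb; as illustrative arithmetic only (the dataset sizes are invented):

```python
n_train, n_val, batch_size = 10000, 1000, 50
steps_per_epoch = n_train // batch_size   # 200 generator batches per epoch
validation_steps = n_val // batch_size    # 20 validation batches per epoch
```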
@@ -1050,7 +1052,7 @@ class Sequential(Model):
             # and labels, from each line in the file
             x, y = process_line(line)
             yield (x, y)
         f.close()

     model.fit_generator(generate_arrays_from_file('/my_file.txt'),
                         samples_per_epoch=10000, epochs=10)
@@ -1119,7 +1121,8 @@ class Sequential(Model):
                         steps,
                         max_q_size=10,
                         workers=1,
-                        pickle_safe=False):
+                        pickle_safe=False,
+                        verbose=0):
     """Generates predictions for the input samples from a data generator.

     The generator should return the same kind of data as accepted by
@@ -1136,6 +1139,7 @@ class Sequential(Model):
           relies on multiprocessing, you should not pass
           non picklable arguments to the generator
           as they can't be passed easily to children processes.
+      verbose: verbosity mode, 0 or 1.

     Returns:
       A Numpy array of predictions.
@@ -1147,7 +1151,8 @@ class Sequential(Model):
         steps,
         max_q_size=max_q_size,
         workers=workers,
-        pickle_safe=pickle_safe)
+        pickle_safe=pickle_safe,
+        verbose=verbose)

   def get_config(self):
     config = []
@@ -1159,9 +1164,9 @@ class Sequential(Model):
     return copy.deepcopy(config)

   @classmethod
-  def from_config(cls, config):
+  def from_config(cls, config, custom_objects=None):
     model = cls()
     for conf in config:
-      layer = layer_module.deserialize(conf)
+      layer = layer_module.deserialize(conf, custom_objects=custom_objects)
       model.add(layer)
     return model
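With `custom_objects` threaded through `Sequential.from_config` as well, reloading a saved model that references user-defined objects works end to end. A hedged usage sketch (`my_metric` and the file name are hypothetical):

```python
from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.models import load_model

def my_metric(y_true, y_pred):  # hypothetical custom object saved with the model
  return K.mean(K.abs(y_true - y_pred))

model = load_model('my_model.h5', custom_objects={'my_metric': my_metric})
```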
@@ -785,7 +785,7 @@ class Iterator(object):
           index_array = np.random.permutation(n)

       current_index = (self.batch_index * batch_size) % n
-      if n >= current_index + batch_size:
+      if n > current_index + batch_size:
         current_batch_size = batch_size
         self.batch_index += 1
       else:
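The strict inequality fixes a boundary case: when a batch ends exactly on the last sample, the iterator must treat it as the final batch of the epoch (taking the `else` branch, which resets `batch_index` to 0) rather than keep advancing. Worked through with concrete numbers:

```python
n, batch_size, batch_index = 10, 5, 1            # second batch covers samples 5..9
current_index = (batch_index * batch_size) % n   # 5
# Old test: n >= current_index + batch_size -> 10 >= 10 is True, so
# batch_index keeps growing and never returns to 0 when batch_size divides n,
# skipping the start-of-epoch bookkeeping (such as reshuffling) keyed on it.
# New test: n > current_index + batch_size -> 10 > 10 is False, so this batch
# is recognized as the last of the epoch and batch_index resets to 0.
```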
@@ -172,7 +172,8 @@ def deserialize_keras_object(identifier,
   else:
     fn = module_objects.get(function_name)
     if fn is None:
-      raise ValueError('Unknown ' + printable_module_name, ':' + class_name)
+      raise ValueError('Unknown ' + printable_module_name,
+                       ':' + function_name)
     return fn
   else:
     raise ValueError('Could not interpret serialized ' + printable_module_name +
@@ -215,6 +216,8 @@ def func_load(code, defaults=None, closure=None, globs=None):
   """
   if isinstance(code, (tuple, list)):  # unpack previous dump
     code, defaults, closure = code
+    if isinstance(defaults, list):
+      defaults = tuple(defaults)
   code = marshal.loads(code.encode('raw_unicode_escape'))
   if globs is None:
     globs = globals()
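The tuple conversion matters because function defaults survive a JSON round trip as a list, while `types.FunctionType` requires `argdefs` to be a tuple (or None). A quick demonstration:

```python
import json
import types

defaults = (1, 2)
restored = json.loads(json.dumps(defaults))   # [1, 2] -- JSON has no tuples
code = (lambda a=1, b=2: a + b).__code__
fn = types.FunctionType(code, globals(), 'fn', tuple(restored))
print(fn())  # 3; passing `restored` itself (a list) raises TypeError
```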
@@ -171,7 +171,7 @@ def count_total_params(layers, layer_set=None):
         [K.count_params(p) for p in layer.trainable_weights])
     non_trainable_count += np.sum(
         [K.count_params(p) for p in layer.non_trainable_weights])
-  return trainable_count, non_trainable_count
+  return int(trainable_count), int(non_trainable_count)


 def convert_all_kernels_in_model(model):
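`np.sum` returns a numpy scalar, and a float `0.0` when the list is empty, so the explicit `int()` normalizes both counts to plain Python integers. Illustration:

```python
import numpy as np

print(type(np.sum([1, 2, 3])))  # a numpy scalar, e.g. <class 'numpy.int64'>
print(np.sum([]))               # 0.0 -- a float when a layer has no weights
print(int(np.sum([])))          # 0  -- a plain Python int either way
```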
@@ -194,6 +194,36 @@ class KerasClassifier(BaseWrapper):
   """Implementation of the scikit-learn classifier API for Keras.
   """

+  def fit(self, x, y, **kwargs):
+    """Constructs a new model with `build_fn` & fit the model to `(x, y)`.
+
+    Arguments:
+      x : array-like, shape `(n_samples, n_features)`
+        Training samples where n_samples is the number of samples
+        and n_features is the number of features.
+      y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
+        True labels for X.
+      **kwargs: dictionary arguments
+        Legal arguments are the arguments of `Sequential.fit`
+
+    Returns:
+      history : object
+        details about the training history at each epoch.
+
+    Raises:
+      ValueError: In case of invalid shape for `y` argument.
+    """
+    y = np.array(y)
+    if len(y.shape) == 2 and y.shape[1] > 1:
+      self.classes_ = np.arange(y.shape[1])
+    elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1:
+      self.classes_ = np.unique(y)
+      y = np.searchsorted(self.classes_, y)
+    else:
+      raise ValueError('Invalid shape for y: ' + str(y.shape))
+    self.n_classes_ = len(self.classes_)
+    return super(KerasClassifier, self).fit(x, y, **kwargs)
+
   def predict(self, x, **kwargs):
     """Returns the class predictions for the given test data.
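The label handling in the new `fit` canonicalizes classes: `np.unique` returns the sorted distinct labels and `np.searchsorted` maps each label to its index in that sorted array, which is what the underlying network is trained against; `predict` (next hunk) maps indices back. A small illustration:

```python
import numpy as np

labels = np.array(['cat', 'dog', 'cat', 'bird'])
classes_ = np.unique(labels)                  # ['bird' 'cat' 'dog'], sorted
encoded = np.searchsorted(classes_, labels)   # [1 2 1 0]
decoded = classes_[encoded]                   # back to the original labels
```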
@@ -210,7 +240,8 @@ class KerasClassifier(BaseWrapper):
       Class predictions.
     """
     kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs)
-    return self.model.predict_classes(x, **kwargs)
+    classes = self.model.predict_classes(x, **kwargs)
+    return self.classes_[classes]

   def predict_proba(self, x, **kwargs):
     """Returns class probability estimates for the given test data.
@@ -261,6 +292,7 @@ class KerasClassifier(BaseWrapper):
         compute accuracy. You should pass `metrics=["accuracy"]` to
         the `.compile()` method of the model.
     """
+    y = np.searchsorted(self.classes_, y)
     kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)

     loss_name = self.model.loss