Minor fixes in tf.keras codebase in preparation for Keras 2.2.0 API support.

PiperOrigin-RevId: 200276422
Francois Chollet 2018-06-12 14:03:39 -07:00 committed by TensorFlower Gardener
parent 9c7ba75034
commit abfdf45dcd
12 changed files with 168 additions and 51 deletions

View File

@@ -32,7 +32,7 @@ def softmax(x, axis=-1):
"""Softmax activation function.
Arguments:
x : Tensor.
x : Input tensor.
axis: Integer, axis along which the softmax normalization is applied.
Returns:
@@ -49,23 +49,45 @@ def softmax(x, axis=-1):
s = math_ops.reduce_sum(e, axis=axis, keepdims=True)
return e / s
else:
raise ValueError('Cannot apply softmax to a tensor that is 1D')
raise ValueError('Cannot apply softmax to a tensor that is 1D. '
'Received input: %s' % (x,))
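For reference, a minimal NumPy sketch of what this function computes (the max-subtraction for numerical stability happens just above the lines shown; names here are illustrative):

```python
import numpy as np

def softmax_ref(x, axis=-1):
    # Subtract the per-axis max before exponentiating for numerical
    # stability, then normalize so the values sum to 1 along `axis`.
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    s = np.sum(e, axis=axis, keepdims=True)
    return e / s

print(softmax_ref(np.array([[1.0, 2.0, 3.0]])))
# [[0.09003057 0.24472847 0.66524096]]
```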
@tf_export('keras.activations.elu')
def elu(x, alpha=1.0):
"""Exponential linear unit.
Arguments:
x: Input tensor.
alpha: A scalar, slope of negative section.
Returns:
The exponential linear activation: `x` if `x > 0` and
`alpha * (exp(x)-1)` if `x < 0`.
Reference:
- [Fast and Accurate Deep Network Learning by Exponential
Linear Units (ELUs)](https://arxiv.org/abs/1511.07289)
"""
return K.elu(x, alpha)
@tf_export('keras.activations.selu')
def selu(x):
"""Scaled Exponential Linear Unit. (Klambauer et al., 2017).
"""Scaled Exponential Linear Unit (SELU).
SELU is equal to: `scale * elu(x, alpha)`, where alpha and scale
are pre-defined constants. The values of `alpha` and `scale` are
chosen so that the mean and variance of the inputs are preserved
between two consecutive layers as long as the weights are initialized
correctly (see `lecun_normal` initialization) and the number of inputs
is "large enough" (see references for more information).
Arguments:
x: A tensor or variable to compute the activation function for.
Returns:
Tensor with the same shape and dtype as `x`.
The scaled exponential unit activation: `scale * elu(x, alpha)`.
# Note
- To be used together with the initialization "lecun_normal".
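For reference, the `alpha` and `scale` constants mentioned above are the published values from Klambauer et al. (2017). A NumPy sketch of `scale * elu(x, alpha)` (constants copied from the paper; the backend may store them at different precision):

```python
import numpy as np

# Published SELU constants (Klambauer et al., 2017).
_ALPHA = 1.6732632423543772
_SCALE = 1.0507009873554805

def selu_ref(x):
    # scale * elu(x, alpha): identity above zero, scaled exponential below.
    return _SCALE * np.where(x > 0, x, _ALPHA * (np.exp(x) - 1.0))
```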
@@ -79,16 +101,44 @@ def selu(x):
@tf_export('keras.activations.softplus')
def softplus(x):
"""Softplus activation function.
Arguments:
x: Input tensor.
Returns:
The softplus activation: `log(exp(x) + 1)`.
"""
return nn.softplus(x)
@tf_export('keras.activations.softsign')
def softsign(x):
"""Softsign activation function.
Arguments:
x: Input tensor.
Returns:
The softsign activation: `x / (abs(x) + 1)`.
"""
return nn.softsign(x)
@tf_export('keras.activations.relu')
def relu(x, alpha=0., max_value=None):
"""Rectified Linear Unit.
Arguments:
x: Input tensor.
alpha: Slope of the negative part. Defaults to zero.
max_value: Maximum value for the output.
Returns:
The (leaky) rectified linear unit activation: `x` if `x > 0`,
`alpha * x` if `x < 0`. If `max_value` is defined, the result
is truncated to this value.
"""
return K.relu(x, alpha=alpha, max_value=max_value)
@@ -104,6 +154,19 @@ def sigmoid(x):
@tf_export('keras.activations.hard_sigmoid')
def hard_sigmoid(x):
"""Hard sigmoid activation function.
Faster to compute than sigmoid activation.
Arguments:
x: Input tensor.
Returns:
Hard sigmoid activation:
- `0` if `x < -2.5`
- `1` if `x > 2.5`
- `0.2 * x + 0.5` if `-2.5 <= x <= 2.5`.
"""
return K.hard_sigmoid(x)
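The three cases above collapse to one clipped line; a NumPy sketch:

```python
import numpy as np

def hard_sigmoid_ref(x):
    # Piecewise-linear sigmoid approximation: clip 0.2 * x + 0.5
    # into [0, 1], which reproduces the three cases listed above.
    return np.clip(0.2 * x + 0.5, 0.0, 1.0)

print(hard_sigmoid_ref(np.array([-3.0, 0.0, 3.0])))  # [0.  0.5 1. ]
```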

View File

@@ -2973,30 +2973,29 @@ def rnn(step_function,
Arguments:
step_function: RNN step function.
Parameters:
input: tensor with shape `(samples, ...)` (no time dimension),
Args:
input: Tensor with shape `(samples, ...)` (no time dimension),
representing input for the batch of samples at a certain
time step.
states: list of tensors.
states: List of tensors.
Returns:
output: tensor with shape `(samples, output_dim)`
output: Tensor with shape `(samples, output_dim)`
(no time dimension).
new_states: list of tensors, same length and shapes
new_states: List of tensors, same length and shapes
as 'states'. The first state in the list must be the
output tensor at the previous timestep.
inputs: tensor of temporal data of shape `(samples, time, ...)`
inputs: Tensor of temporal data of shape `(samples, time, ...)`
(at least 3D).
initial_states: tensor with shape (samples, output_dim)
initial_states: Tensor with shape `(samples, output_dim)`
(no time dimension),
containing the initial values for the states used in
the step function.
go_backwards: boolean. If True, do the iteration over the time
go_backwards: Boolean. If True, do the iteration over the time
dimension in reverse order and return the reversed sequence.
mask: binary tensor with shape `(samples, time, 1)`,
mask: Binary tensor with shape `(samples, time, 1)`,
with a zero for every element that is masked.
constants: a list of constant values passed at each step.
unroll: whether to unroll the RNN or to use a symbolic loop
(`while_loop` or `scan` depending on backend).
constants: List of constant values passed at each step.
unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
input_length: If specified, assume time dimension is of this length.
Returns:
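A minimal sketch of the step-function contract documented above, using the public `tf.keras.backend.rnn` wrapper (shapes and the toy recurrence are illustrative):

```python
import tensorflow as tf

def step(cell_input, states):
    # Receives the input at one time step plus the state list, and
    # returns (output, new_states) with the output first in new_states.
    prev_output = states[0]
    output = tf.tanh(cell_input + prev_output)
    return output, [output]

inputs = tf.zeros((4, 10, 8))        # (samples, time, features)
initial_states = [tf.zeros((4, 8))]  # (samples, output_dim), no time axis
last_output, outputs, new_states = tf.keras.backend.rnn(
    step, inputs, initial_states)
```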
@@ -3637,12 +3636,12 @@ def _preprocess_conv1d_input(x, data_format):
Returns:
A tensor.
"""
tf_data_format = 'NHWC' # to pass TF Conv2dNative operations
tf_data_format = 'NWC' # to pass TF Conv2dNative operations
if data_format == 'channels_first':
if not _has_nchw_support():
x = array_ops.transpose(x, (0, 2, 1)) # NCW -> NWC
else:
tf_data_format = 'NCHW'
tf_data_format = 'NCW'
return x, tf_data_format
@@ -3741,10 +3740,8 @@ def conv1d(x,
x = temporal_padding(x, (left_pad, 0))
padding = 'valid'
padding = _preprocess_padding(padding)
if data_format == 'channels_last':
tf_data_format = 'NWC'
else:
tf_data_format = 'NCW'
x, tf_data_format = _preprocess_conv1d_input(x, data_format)
x = nn.convolution(
input=x,
filter=kernel,
@@ -3752,6 +3749,8 @@ def conv1d(x,
strides=(strides,),
padding=padding,
data_format=tf_data_format)
if data_format == 'channels_first' and tf_data_format == 'NWC':
x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW
return x
@@ -3892,11 +3891,16 @@ def separable_conv1d(x,
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
if isinstance(strides, int):
strides = (strides,)
if isinstance(dilation_rate, int):
dilation_rate = (dilation_rate,)
x, tf_data_format = _preprocess_conv1d_input(x, data_format)
padding = _preprocess_padding(padding)
if not isinstance(strides, tuple):
strides = tuple(strides)
if tf_data_format == 'NHWC':
if tf_data_format == 'NWC':
spatial_start_dim = 1
strides = (1,) + strides * 2 + (1,)
else:
@@ -3918,7 +3922,7 @@ def separable_conv1d(x,
x = array_ops.squeeze(x, [spatial_start_dim])
if data_format == 'channels_first' and tf_data_format == 'NHWC':
if data_format == 'channels_first' and tf_data_format == 'NWC':
x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW
return x
@@ -4717,8 +4721,13 @@ def foldr(fn, elems, initializer=None, name=None):
# Load Keras default configuration from config file if present.
_keras_base_dir = os.path.expanduser('~')
_keras_dir = os.path.join(_keras_base_dir, '.keras')
# Set Keras base dir path given KERAS_HOME env variable, if applicable.
# Otherwise either ~/.keras or /tmp.
if 'KERAS_HOME' in os.environ:
_keras_dir = os.environ.get('KERAS_HOME')
else:
_keras_base_dir = os.path.expanduser('~')
_keras_dir = os.path.join(_keras_base_dir, '.keras')
_config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
if os.path.exists(_config_path):
try:
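The new lookup makes the config directory overridable; a short sketch of the resolution order (the override path is hypothetical):

```python
import os

# KERAS_HOME wins when set; otherwise fall back to ~/.keras.
os.environ['KERAS_HOME'] = '/tmp/my_keras'  # hypothetical override
keras_dir = os.environ.get(
    'KERAS_HOME', os.path.join(os.path.expanduser('~'), '.keras'))
print(os.path.join(keras_dir, 'keras.json'))  # /tmp/my_keras/keras.json
```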

View File

@@ -635,7 +635,11 @@ class LearningRateScheduler(Callback):
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'lr'):
raise ValueError('Optimizer must have a "lr" attribute.')
lr = self.schedule(epoch)
try: # new API
lr = float(K.get_value(self.model.optimizer.lr))
lr = self.schedule(epoch, lr)
except TypeError: # Support for old API for backward compatibility
lr = self.schedule(epoch)
if not isinstance(lr, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function '
'should be float.')
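With this change, a schedule callable may take the current learning rate as a second argument; single-argument schedules still work through the `TypeError` fallback. A sketch of the new-style usage (the schedule itself is illustrative):

```python
from tensorflow import keras

def halve_every_10(epoch, lr):
    # New API: receives the optimizer's current lr and returns the
    # value to use for this epoch.
    return lr * 0.5 if epoch > 0 and epoch % 10 == 0 else lr

scheduler = keras.callbacks.LearningRateScheduler(halve_every_10)
# model.fit(x, y, callbacks=[scheduler])  # hypothetical model and data
```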

View File

@@ -321,8 +321,26 @@ class KerasCallbacksTest(test.TestCase):
callbacks=cbks,
epochs=5,
verbose=0)
assert (float(keras.backend.get_value(model.optimizer.lr)) - 0.2
) < keras.backend.epsilon()
assert (
float(keras.backend.get_value(
model.optimizer.lr)) - 0.2) < keras.backend.epsilon()
cbks = [keras.callbacks.LearningRateScheduler(lambda x, lr: lr / 2)]
model.compile(
loss='categorical_crossentropy',
optimizer='sgd',
metrics=['accuracy'])
model.fit(
x_train,
y_train,
batch_size=BATCH_SIZE,
validation_data=(x_test, y_test),
callbacks=cbks,
epochs=2,
verbose=0)
assert (
float(keras.backend.get_value(
model.optimizer.lr)) - 0.01 / 4) < keras.backend.epsilon()
def test_ReduceLROnPlateau(self):
with self.test_session():

View File

@@ -185,6 +185,7 @@ def fit_loop(model,
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
if steps_per_epoch is not None:
# Step-wise fit loop.
for step_index in range(steps_per_epoch):
batch_logs = {}
batch_logs['batch'] = step_index
@@ -215,7 +216,6 @@ def fit_loop(model,
val_inputs,
val_targets,
sample_weights=val_sample_weights,
batch_size=batch_size,
steps=validation_steps,
verbose=0)
if not isinstance(val_outs, list):
@@ -224,6 +224,7 @@ def fit_loop(model,
for l, o in zip(out_labels, val_outs):
epoch_logs['val_' + l] = o
else:
# Sample-wise fit loop.
if shuffle == 'batch':
index_array = training_utils.batch_shuffle(index_array, batch_size)
elif shuffle:

View File

@@ -382,11 +382,11 @@ class Conv2D(Conv):
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of 2 integers, specifying the
width and height of the 2D convolution window.
height and width of the 2D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 2 integers,
specifying the strides of the convolution along the width and height.
specifying the strides of the convolution along the height and width.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
@@ -613,11 +613,11 @@ class Conv2DTranspose(Conv2D):
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of 2 integers, specifying the
width and height of the 2D convolution window.
height and width of the 2D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 2 integers,
specifying the strides of the convolution along the width and height.
specifying the strides of the convolution along the height and width.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
@@ -1452,11 +1452,11 @@ class SeparableConv2D(SeparableConv):
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of 2 integers, specifying the
width and height of the 2D convolution window.
height and width of the 2D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 2 integers,
specifying the strides of the convolution along the width and height.
specifying the strides of the convolution along the height and width.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
@@ -1596,11 +1596,11 @@ class DepthwiseConv2D(Conv2D):
Arguments:
kernel_size: An integer or tuple/list of 2 integers, specifying the
width and height of the 2D convolution window.
height and width of the 2D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 2 integers,
specifying the strides of the convolution along the width and height.
specifying the strides of the convolution along the height and width.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
@@ -2007,7 +2007,7 @@ class ZeroPadding2D(Layer):
Arguments:
padding: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
- If int: the same symmetric padding
is applied to width and height.
is applied to height and width.
- If tuple of 2 ints:
interpreted as two different
symmetric padding values for height and width:
@@ -2106,7 +2106,7 @@ class ZeroPadding3D(Layer):
Arguments:
padding: int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.
- If int: the same symmetric padding
is applied to width and height.
is applied to height and width.
- If tuple of 3 ints:
interpreted as two different
symmetric padding values for height and width:
@@ -2266,12 +2266,12 @@ class Cropping1D(Layer):
class Cropping2D(Layer):
"""Cropping layer for 2D input (e.g. picture).
It crops along spatial dimensions, i.e. width and height.
It crops along spatial dimensions, i.e. height and width.
Arguments:
cropping: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
- If int: the same symmetric cropping
is applied to width and height.
is applied to height and width.
- If tuple of 2 ints:
interpreted as two different
symmetric cropping values for height and width:
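All of these docstring fixes swap the tuple order to (height, width), matching what the layers actually do. A quick shape check with a non-square kernel makes the ordering visible (a sketch):

```python
from tensorflow import keras

# kernel_size and strides are (height, width): a (1, 3) kernel is
# 1 pixel tall and 3 wide, and strides=(1, 2) only strides the width.
model = keras.Sequential([
    keras.layers.Conv2D(8, kernel_size=(1, 3), strides=(1, 2),
                        input_shape=(32, 64, 3)),
])
print(model.output_shape)  # (None, 32, 31, 8): height kept, width strided
```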

View File

@@ -446,8 +446,8 @@ class Concatenate(_Merge):
class Dot(_Merge):
"""Layer that computes a dot product between samples in two tensors.
E.g. if applied to two tensors `a` and `b` of shape `(batch_size, n)`,
the output will be a tensor of shape `(batch_size, 1)`
E.g. if applied to a list of two tensors `a` and `b` of shape
`(batch_size, n)`, the output will be a tensor of shape `(batch_size, 1)`
where each entry `i` will be the dot product between
`a[i]` and `b[i]`.
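A short sketch of the corrected description (a list of two `(batch_size, n)` tensors in, a `(batch_size, 1)` tensor out):

```python
import numpy as np
from tensorflow import keras

a = keras.layers.Input(shape=(4,))
b = keras.layers.Input(shape=(4,))
# Row-wise dot product over axis 1: output shape (batch_size, 1).
dotted = keras.layers.Dot(axes=1)([a, b])
model = keras.models.Model([a, b], dotted)

x = np.ones((2, 4))
print(model.predict([x, 2 * x]))  # [[8.], [8.]]: each row is 1 * 2 * 4
```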

View File

@@ -324,12 +324,12 @@ def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
class Sequence(object):
"""Base object for fitting to a sequence of data, such as a dataset.
Every `Sequence` must implements the `__getitem__` and the `__len__` methods.
Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
If you want to modify your dataset between epochs you may implement
`on_epoch_end`.
The method `__getitem__` should return a complete batch.
# Notes
Notes:
A `Sequence` is a safer way to do multiprocessing. This structure guarantees
that the network will only train once
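A minimal subclass sketch satisfying the contract described above (`__len__` returns the number of batches, `__getitem__` returns one complete batch):

```python
import numpy as np
from tensorflow import keras

class BatchedData(keras.utils.Sequence):

    def __init__(self, x, y, batch_size):
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        # Number of batches per epoch.
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        # Return one complete batch.
        sl = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[sl], self.y[sl]
```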

View File

@@ -102,13 +102,12 @@ class HDF5Matrix(object):
idx = (self.start + key).tolist()
else:
raise IndexError
elif isinstance(key, list):
else:
# Assume list/iterable
if max(key) + self.start < self.end:
idx = [x + self.start for x in key]
else:
raise IndexError
else:
raise IndexError
if self.normalizer is not None:
return self.normalizer(self.data[idx])
else:
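The rewritten branch treats any remaining key type as a list-like of indices instead of raising. A small end-to-end sketch (file path and dataset name are made up):

```python
import h5py
import numpy as np
from tensorflow import keras

# Create a throwaway HDF5 dataset to index into.
with h5py.File('/tmp/demo.h5', 'w') as f:
    f.create_dataset('my_data', data=np.arange(20.0).reshape(10, 2))

x = keras.utils.HDF5Matrix('/tmp/demo.h5', 'my_data', start=0, end=10)
print(x[3])          # int indexing
print(x[[1, 4, 7]])  # list indexing, handled by the fallback branch
```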

View File

@@ -22,6 +22,7 @@ import os
import shutil
import numpy as np
import six
from tensorflow.python import keras
from tensorflow.python.platform import test
@@ -95,6 +96,29 @@ class TestIOUtils(test.TestCase):
self.assertEqual(out_eval.shape, ())
self.assertGreater(out_eval, 0)
# test slicing for shortened array
self.assertEqual(len(x_train[0:]), len(x_train))
# test __getitem__ invalid use cases
with self.assertRaises(IndexError):
_ = x_train[1000]
with self.assertRaises(IndexError):
_ = x_train[1000: 1001]
with self.assertRaises(IndexError):
_ = x_train[[1000, 1001]]
with self.assertRaises(IndexError):
_ = x_train[six.moves.range(1000, 1001)]
with self.assertRaises(IndexError):
_ = x_train[np.array([1000])]
with self.assertRaises(TypeError):
_ = x_train[None]
# test normalizer
normalizer = lambda x: x + 1
normalized_x_train = keras.utils.io_utils.HDF5Matrix(
h5_path, 'my_data', start=0, end=150, normalizer=normalizer)
self.assertAllClose(normalized_x_train[0][0], x_train[0][0] + 1)
if __name__ == '__main__':
test.main()

View File

@@ -196,7 +196,7 @@ def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False):
batch_size = shape[:1]
input_shape = shape[1:]
step = batch_size // parts
if i == num_gpus - 1:
if i == parts - 1:
size = batch_size - step * i
else:
size = step
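The fix indexes the last slice by `parts` (the count actually used to compute `step`) rather than `num_gpus`. The slicing arithmetic in isolation (a sketch):

```python
def slice_sizes(batch_size, parts):
    # Split batch_size samples into `parts` contiguous slices; the
    # last slice absorbs any remainder, as in the fixed branch above.
    step = batch_size // parts
    return [batch_size - step * (parts - 1) if i == parts - 1 else step
            for i in range(parts)]

print(slice_sizes(10, 3))  # [3, 3, 4]
```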

View File

@@ -77,7 +77,6 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'):
if isinstance(model, Sequential):
if not model.built:
model.build()
model = model.model
layers = model.layers
# Create graph nodes.