Merge branch 'master' of sso://tensorflow/staging

TensorFlower Gardener 2017-10-20 15:39:00 -07:00
commit f5b14e496f
6 changed files with 262 additions and 110 deletions

View File

@ -87,6 +87,40 @@ the Dataset API is still strongly recommended. Try to avoid the following:
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
```
#### Fused decode and crop
If inputs are JPEG images that also require cropping, use the fused
@{tf.image.decode_and_crop_jpeg} op to speed up preprocessing.
`tf.image.decode_and_crop_jpeg` only decodes the part of
the image within the crop window. This significantly speeds up the process if
the crop window is much smaller than the full image. For ImageNet data, this
approach could speed up the input pipeline by up to 30%.
Example Usage:
```python
def _image_preprocess_fn(image_buffer):
  # image_buffer: 1-D string Tensor representing the raw JPEG image buffer.

  # Extract image shape from raw JPEG image buffer.
  image_shape = tf.image.extract_jpeg_shape(image_buffer)

  # Get a crop window with distorted bounding box.
  sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
      image_shape, ...)
  bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box

  # Decode and crop image.
  offset_y, offset_x, _ = tf.unstack(bbox_begin)
  target_height, target_width, _ = tf.unstack(bbox_size)
  crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
  cropped_image = tf.image.decode_and_crop_jpeg(image_buffer, crop_window)
  return cropped_image
```
`tf.image.decode_and_crop_jpeg` is available on all platforms, but there is no
speedup on Windows, which uses `libjpeg` rather than the faster `libjpeg-turbo`
used on other platforms.
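For reference, here is a minimal sketch of how such a preprocessing function
might be wired into a `tf.data` input pipeline; the `filenames` list, the
parallelism, and the batch size are illustrative placeholders, not prescribed
values:
```python
filenames = [...]  # list of JPEG file paths (placeholder)
dataset = tf.data.Dataset.from_tensor_slices(filenames)
# Read each file and apply the fused decode-and-crop preprocessing.
dataset = dataset.map(
    lambda filename: _image_preprocess_fn(tf.read_file(filename)),
    num_parallel_calls=4)
dataset = dataset.batch(32).prefetch(1)
```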
#### Use large files
Reading large numbers of small files significantly impacts I/O performance.
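One common remedy is to combine many small input files into a few larger
`TFRecord` files ahead of training. Below is a minimal sketch of that idea;
`image_files`, the feature key, and the output path are placeholders rather
than required names:
```python
import tensorflow as tf

def pack_into_tfrecord(image_files, output_path):
  # Write each small file as one record in a single large TFRecord file.
  with tf.python_io.TFRecordWriter(output_path) as writer:
    for filename in image_files:
      with open(filename, 'rb') as f:
        image_buffer = f.read()
      example = tf.train.Example(features=tf.train.Features(feature={
          'image/encoded': tf.train.Feature(
              bytes_list=tf.train.BytesList(value=[image_buffer]))
      }))
      writer.write(example.SerializeToString())
```
The resulting files can then be read sequentially with `tf.data.TFRecordDataset`,
avoiding the per-file open/close overhead of many small files.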

View File

@ -776,7 +776,7 @@ class Network(tf_base_layers.Network, Layer):
if cache_key in self._output_mask_cache:
return self._output_mask_cache[cache_key]
else:
_, output_masks, _ = self._run_internal_graph(inputs, masks)
_, output_masks = self._run_internal_graph(inputs, masks)
return output_masks
def get_config(self):

View File

@ -192,10 +192,12 @@ class KerasIntegrationTest(test.TestCase):
model.compile(loss='categorical_crossentropy',
optimizer='rmsprop',
metrics=['accuracy'])
self.assertEqual(len(model.losses), 2)
self.assertEqual(len(model.updates), 2)
history = model.fit(x_train, y_train, epochs=10, batch_size=16,
validation_data=(x_test, y_test),
verbose=2)
self.assertGreater(history.history['val_acc'][-1], 0.85)
self.assertGreater(history.history['val_acc'][-1], 0.84)
def test_vector_classification_shared_model(self):
# Test that functional models that feature internal updates

View File

@ -420,6 +420,8 @@ class Sequential(Model):
# Used by Layer base class.
self._dtype = None
self._activity_regularizer = None
self._per_input_losses = {}
self._per_input_updates = {}
# The following properties are not actually used by Keras;
# they exist for compatibility with TF's variable scoping mechanism.

View File

@ -508,6 +508,7 @@ class Layer(object):
input_list = nest.flatten(inputs)
in_graph_mode = context.in_graph_mode()
in_deferred_mode = isinstance(input_list[0], _DeferredTensor)
# Ensure the Layer, if being reused, is working with inputs from
# the same graph as where it was created.
if in_graph_mode:
@ -515,6 +516,7 @@ class Layer(object):
ops._get_graph_from_inputs(input_list, graph=self.graph) # pylint: disable=protected-access
except ValueError as e:
raise ValueError('Input graph and Layer graph are not the same: %s' % e)
if in_graph_mode or in_deferred_mode:
user_kwargs = copy.copy(kwargs)
# Handle Keras mask propagation from previous layer to current layer.
@ -553,6 +555,7 @@ class Layer(object):
raise ValueError('activity_regularizer currently unsupported in '
'Eager mode. Found an activity_regularizer in '
'%s(%s).' % (self.__class__.__name__, self))
if not in_graph_mode and not in_deferred_mode:
# TODO(agarwal): support _keras_history in Eager mode.
for x in input_list:
if hasattr(x, '_keras_history'):
@ -581,13 +584,26 @@ class Layer(object):
if call_has_scope_arg:
kwargs['scope'] = scope
# Check input assumptions set after layer building, e.g. input shape.
if in_graph_mode:
if in_graph_mode or in_deferred_mode:
self._assert_input_compatibility(inputs)
outputs = self.call(inputs, *args, **kwargs)
if not in_deferred_mode:
outputs = self.call(inputs, *args, **kwargs)
if outputs is None:
raise ValueError('A layer\'s `call` method should return a Tensor '
'or a list of Tensors, not None.')
else:
# Deferred mode behavior: use `_compute_output_shape` to
# infer the number of outputs of the layer and their shapes.
output_shapes = self._compute_output_shape(input_shapes)
output_shapes = nest.flatten(output_shapes)
outputs = [
# TODO(fchollet): name the deferred tensors?
_DeferredTensor(shape=shape, dtype=self._dtype)
for shape in output_shapes
]
if len(outputs) == 1:
outputs = outputs[0]
if in_graph_mode:
# Apply activity regularization.
@ -600,6 +616,8 @@ class Layer(object):
activity_regularization = self._activity_regularizer(output)
self.add_loss(activity_regularization)
if not in_deferred_mode:
# TODO(fchollet): consider how masking will work with deferred mode.
# Handle mask computation and propagation to the next layer.
if hasattr(self, 'compute_mask'):
output_mask = self.compute_mask(inputs, previous_mask)
@ -631,14 +649,16 @@ class Layer(object):
else:
outputs = output_ls_copy
# Update global default collections.
_add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
if in_deferred_mode or in_graph_mode:
if _have_all_keras_metadata(inputs):
# Add an inbound node to the layer, so it can keep track of this call.
# This updates the layer history of the output tensor(s).
self._add_inbound_node(
input_tensors=inputs, output_tensors=outputs, arguments=user_kwargs)
# Update global default collections.
_add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
self.built = True
return outputs
@ -692,7 +712,6 @@ class Layer(object):
arguments: dictionary of keyword arguments that were passed to the
`call` method of the layer at the call that created the node.
"""
assert context.in_graph_mode()
input_tensors = nest.flatten(input_tensors)
output_tensors = nest.flatten(output_tensors)
@ -1251,6 +1270,34 @@ class Node(object):
}
class _DeferredTensor(object):
"""Tensor-like object used to build graphs of layers in Eager mode.
When calling a layer on a DeferredTensor, the layer will not perform any
computation and will simply perform shape inference to return new
DeferredTensors with appropriate shape information. Thus DeferredTensor
behaves like a graph-mode Tensor when manipulated by layers.
"""
def __init__(self, shape, dtype, name=None):
self.shape = tensor_shape.TensorShape(shape)
self.dtype = dtypes.as_dtype(dtype)
self.name = name
def get_shape(self):
return self.shape
def __str__(self):
return "DeferredTensor('%s', shape=%s, dtype=%s)" % (self.name,
self.get_shape(),
self.dtype.name)
def __repr__(self):
return "<_DeferredTensor '%s' shape=%s dtype=%s>" % (self.name,
self.get_shape(),
self.dtype.name)
class InputLayer(Layer):
"""Layer to be used as an entry point into a Network (a graph of layers).
@ -1283,8 +1330,6 @@ class InputLayer(Layer):
input_tensor=None,
sparse=False,
name=None):
if context.in_eager_mode():
raise RuntimeError('InputLayer not supported in Eager mode.')
super(InputLayer, self).__init__(dtype=dtype, name=name)
self.built = True
self.sparse = sparse
@ -1299,6 +1344,14 @@ class InputLayer(Layer):
else:
batch_input_shape = None
if context.in_eager_mode():
# In eager mode, create a temporary placeholder to call the layer on.
input_tensor = _DeferredTensor(
shape=batch_input_shape,
dtype=dtype,
name=self.name)
else:
# In graph mode, create a graph placeholder to call the layer on.
if sparse:
input_tensor = array_ops.sparse_placeholder(
shape=batch_input_shape,
@ -1375,8 +1428,6 @@ def Input( # pylint: disable=invalid-name
Raises:
RuntimeError: If called in Eager mode.
"""
if context.in_eager_mode():
raise RuntimeError('Input not supported in Eager mode.')
input_layer = InputLayer(
input_shape=shape,
batch_size=batch_size,
@ -1440,9 +1491,10 @@ class Network(Layer):
"""
def __init__(self, inputs, outputs, name=None): # pylint: disable=super-init-not-called
# TODO(agarwal): Make Network work in Eager mode.
if context.in_eager_mode():
raise RuntimeError('Network not supported in Eager mode.')
# TODO(fchollet): check that all inputs and outputs are DeferredTensors.
pass
# Set layer name and scope
if isinstance(name, vs.VariableScope):
base_name = name.name
@ -1919,16 +1971,17 @@ class Network(Layer):
masks = [None for _ in range(len(inputs))]
else:
masks = nest.flatten(mask)
if context.in_graph_mode():
# Try to retrieve cached outputs if the layer has already been called
# on these exact inputs.
cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks)
if cache_key in self._output_tensor_cache:
# Cache hit.
return self._output_tensor_cache[cache_key]
else:
# Cache miss: actually apply the network graph to the new inputs.
output_tensors, _, _ = self._run_internal_graph(inputs, masks)
return output_tensors
# Actually apply the network graph to the new inputs.
outputs, _ = self._run_internal_graph(inputs, masks)
return outputs
def _compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
@ -2091,6 +2144,7 @@ class Network(Layer):
if 'mask' in estimator_util.fn_args(layer.call):
if 'mask' not in kwargs:
kwargs['mask'] = computed_mask
output_tensors = nest.flatten(
layer.call(computed_tensor, **kwargs))
if hasattr(layer, 'compute_mask'):
@ -2121,6 +2175,7 @@ class Network(Layer):
]
layer.add_loss(regularization_losses, computed_tensors)
if context.in_graph_mode():
# Update model updates and losses:
# Keep track of updates that depend on the inputs
# (e.g. BN updates).
@ -2149,31 +2204,26 @@ class Network(Layer):
output_tensors.append(tensor)
output_masks.append(mask)
if len(output_tensors) == 1:
output_tensors = output_tensors[0]
if output_shapes is not None:
output_shapes = output_shapes[0]
if output_masks is not None:
output_masks = output_masks[0]
if context.in_graph_mode():
# Update cache;
# keys are based on ids on input tensors and inputs masks.
cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks)
if len(output_tensors) == 1:
output_tensors = output_tensors[0]
self._output_tensor_cache[cache_key] = output_tensors
else:
self._output_tensor_cache[cache_key] = output_tensors
if len(output_masks) == 1:
output_masks = output_masks[0]
if output_masks is not None:
self._output_mask_cache[cache_key] = output_masks
else:
self._output_mask_cache[cache_key] = output_masks
if output_shapes is not None:
input_shapes = [_static_shape(x) for x in inputs]
cache_key = _object_list_uid(input_shapes)
if len(output_shapes) == 1:
output_shapes = output_shapes[0]
self._output_shape_cache[cache_key] = output_shapes
else:
self._output_shape_cache[cache_key] = output_shapes
return output_tensors, output_masks, output_shapes
return output_tensors, output_masks
def _is_tensor_or_tensor_list(v):

View File

@ -20,6 +20,8 @@ from __future__ import print_function
import copy
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -41,13 +43,13 @@ class BaseLayerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testLayerProperties(self):
layer = base_layers.Layer(name='my_layer')
self.assertListEqual(layer.variables, [])
self.assertListEqual(layer.trainable_variables, [])
self.assertListEqual(layer.non_trainable_variables, [])
self.assertEqual(layer.variables, [])
self.assertEqual(layer.trainable_variables, [])
self.assertEqual(layer.non_trainable_variables, [])
if context.in_graph_mode():
# updates, losses only supported in GRAPH mode
self.assertListEqual(layer.updates, [])
self.assertListEqual(layer.losses, [])
self.assertEqual(layer.updates, [])
self.assertEqual(layer.losses, [])
self.assertEqual(layer.built, False)
layer = base_layers.Layer(name='my_layer', trainable=False)
self.assertEqual(layer.trainable, False)
@ -60,11 +62,11 @@ class BaseLayerTest(test.TestCase):
variable = layer.add_variable(
'my_var', [2, 2], initializer=init_ops.zeros_initializer())
self.assertEqual(variable.name, 'my_layer/my_var:0')
self.assertListEqual(layer.variables, [variable])
self.assertListEqual(layer.trainable_variables, [variable])
self.assertListEqual(layer.non_trainable_variables, [])
self.assertEqual(layer.variables, [variable])
self.assertEqual(layer.trainable_variables, [variable])
self.assertEqual(layer.non_trainable_variables, [])
if context.in_graph_mode():
self.assertListEqual(
self.assertEqual(
layer.variables,
ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
@ -74,9 +76,9 @@ class BaseLayerTest(test.TestCase):
'non_trainable_var', [2, 2],
initializer=init_ops.zeros_initializer(),
trainable=False)
self.assertListEqual(layer.variables, [variable, variable_2])
self.assertListEqual(layer.trainable_variables, [variable])
self.assertListEqual(layer.non_trainable_variables, [variable_2])
self.assertEqual(layer.variables, [variable, variable_2])
self.assertEqual(layer.trainable_variables, [variable])
self.assertEqual(layer.non_trainable_variables, [variable_2])
if context.in_graph_mode():
self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
@ -105,7 +107,7 @@ class BaseLayerTest(test.TestCase):
inputs = random_ops.random_uniform((5,), seed=1)
layer.apply(inputs)
layer.apply(inputs)
self.assertListEqual([v.name for v in layer.variables],
self.assertEqual([v.name for v in layer.variables],
['my_layer/my_var:0'])
# Creating a layer with no scope leads to lazy construction of
@ -120,7 +122,7 @@ class BaseLayerTest(test.TestCase):
# The variables were created outside of the Layer, and
# reuse=True, so the Layer does not own them and they are not
# stored in its collection.
self.assertListEqual(lazy_layer.variables, [])
self.assertEqual(lazy_layer.variables, [])
self.assertEqual(lazy_layer._scope.name, 'new_scope/my_layer')
# Creating a layer with no scope leads to lazy construction of
@ -135,7 +137,7 @@ class BaseLayerTest(test.TestCase):
# The variables were created outside of the Layer, and
# reuse=True, so the Layer does not own them and they are not
# stored in its collection.
self.assertListEqual(lazy_layer.variables, [])
self.assertEqual(lazy_layer.variables, [])
self.assertEqual(lazy_layer._scope.name, 'new_scope')
# Checking for graph equality is only done in GRAPH mode.
@ -183,13 +185,13 @@ class BaseLayerTest(test.TestCase):
outputs = layer.apply(inputs)
self.assertEqual(layer.built, True)
self.assertEqual(outputs.op.name, 'my_layer/add')
self.assertListEqual([v.name
self.assertEqual([v.name
for v in layer.variables], ['my_layer/my_var:0'])
with self.assertRaisesRegexp(ValueError,
'my_layer/this_will_break_on_second_call'):
layer.apply(inputs)
# The list of variables hasn't changed.
self.assertListEqual([v.name
self.assertEqual([v.name
for v in layer.variables], ['my_layer/my_var:0'])
@test_util.run_in_graph_and_eager_modes()
@ -435,8 +437,8 @@ class BaseLayerTest(test.TestCase):
dense_layer.add_update(0, inputs=a)
dense_layer.add_update(1, inputs=None)
self.assertListEqual(dense_layer.get_updates_for(a), [0])
self.assertListEqual(dense_layer.get_updates_for(None), [1])
self.assertEqual(dense_layer.get_updates_for(a), [0])
self.assertEqual(dense_layer.get_updates_for(None), [1])
def test_get_losses_for(self):
a = base_layers.Input(shape=(2,))
@ -444,8 +446,8 @@ class BaseLayerTest(test.TestCase):
dense_layer.add_loss(0, inputs=a)
dense_layer.add_loss(1, inputs=None)
self.assertListEqual(dense_layer.get_losses_for(a), [0])
self.assertListEqual(dense_layer.get_losses_for(None), [1])
self.assertEqual(dense_layer.get_losses_for(a), [0])
self.assertEqual(dense_layer.get_losses_for(None), [1])
def testTopologicalAttributes(self):
# test layer attributes / methods related to cross-layer connectivity.
@ -612,7 +614,7 @@ class NetworkTest(test.TestCase):
a = base_layers.Input(shape=(32,), name='input_a')
b = base_layers.Input(shape=(32,), name='input_b')
self.assertListEqual(a.get_shape().as_list(), [None, 32])
self.assertEqual(a.get_shape().as_list(), [None, 32])
a_layer, a_node_index, a_tensor_index = a._keras_history
b_layer, _, _ = b._keras_history
self.assertEqual(len(a_layer._inbound_nodes), 1)
@ -620,11 +622,11 @@ class NetworkTest(test.TestCase):
node = a_layer._inbound_nodes[a_node_index]
self.assertEqual(node.outbound_layer, a_layer)
self.assertListEqual(node.inbound_layers, [])
self.assertListEqual(node.input_tensors, [a])
self.assertListEqual(node.input_shapes, [(None, 32)])
self.assertListEqual(node.output_tensors, [a])
self.assertListEqual(node.output_shapes, [(None, 32)])
self.assertEqual(node.inbound_layers, [])
self.assertEqual(node.input_tensors, [a])
self.assertEqual(node.input_shapes, [(None, 32)])
self.assertEqual(node.output_tensors, [a])
self.assertEqual(node.output_shapes, [(None, 32)])
dense = core_layers.Dense(16, name='dense_1')
dense(a)
@ -632,12 +634,12 @@ class NetworkTest(test.TestCase):
self.assertEqual(len(dense._inbound_nodes), 2)
self.assertEqual(len(dense._outbound_nodes), 0)
self.assertListEqual(dense._inbound_nodes[0].inbound_layers, [a_layer])
self.assertEqual(dense._inbound_nodes[0].inbound_layers, [a_layer])
self.assertEqual(dense._inbound_nodes[0].outbound_layer, dense)
self.assertListEqual(dense._inbound_nodes[1].inbound_layers, [b_layer])
self.assertEqual(dense._inbound_nodes[1].inbound_layers, [b_layer])
self.assertEqual(dense._inbound_nodes[1].outbound_layer, dense)
self.assertListEqual(dense._inbound_nodes[0].input_tensors, [a])
self.assertListEqual(dense._inbound_nodes[1].input_tensors, [b])
self.assertEqual(dense._inbound_nodes[0].input_tensors, [a])
self.assertEqual(dense._inbound_nodes[1].input_tensors, [b])
# Test config
config_0 = dense._inbound_nodes[0].get_config()
@ -889,5 +891,67 @@ class NetworkTest(test.TestCase):
self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b))
class DeferredModeTest(test.TestCase):
def testDeferredTensorAttributes(self):
x = base_layers._DeferredTensor(shape=(None, 2), dtype='float32', name='x')
self.assertEqual(str(x),
'DeferredTensor(\'x\', shape=(?, 2), dtype=float32)')
self.assertEqual(repr(x),
'<_DeferredTensor \'x\' shape=(?, 2) dtype=float32>')
@test_util.run_in_graph_and_eager_modes()
def testSimpleNetworkBuilding(self):
inputs = base_layers.Input(shape=(32,))
if context.in_eager_mode():
self.assertIsInstance(inputs, base_layers._DeferredTensor)
self.assertEqual(inputs.dtype.name, 'float32')
self.assertEqual(inputs.shape.as_list(), [None, 32])
x = core_layers.Dense(2)(inputs)
if context.in_eager_mode():
self.assertIsInstance(x, base_layers._DeferredTensor)
self.assertEqual(x.dtype.name, 'float32')
self.assertEqual(x.shape.as_list(), [None, 2])
outputs = core_layers.Dense(4)(x)
network = base_layers.Network(inputs, outputs)
self.assertIsInstance(network, base_layers.Network)
if context.in_eager_mode():
# It should be possible to call such a network on EagerTensors.
inputs = constant_op.constant(
np.random.random((10, 32)).astype('float32'))
outputs = network(inputs)
self.assertEqual(outputs.shape.as_list(), [10, 4])
@test_util.run_in_graph_and_eager_modes()
def testMultiIONetworkbuilding(self):
input_a = base_layers.Input(shape=(32,))
input_b = base_layers.Input(shape=(16,))
a = core_layers.Dense(16)(input_a)
class AddLayer(base_layers.Layer):
def call(self, inputs):
return inputs[0] + inputs[1]
def _compute_output_shape(self, input_shape):
return input_shape[0]
c = AddLayer()([a, input_b]) # pylint: disable=not-callable
c = core_layers.Dense(2)(c)
network = base_layers.Network([input_a, input_b], [a, c])
if context.in_eager_mode():
a_val = constant_op.constant(
np.random.random((10, 32)).astype('float32'))
b_val = constant_op.constant(
np.random.random((10, 16)).astype('float32'))
outputs = network([a_val, b_val])
self.assertEqual(len(outputs), 2)
self.assertEqual(outputs[0].shape.as_list(), [10, 16])
self.assertEqual(outputs[1].shape.as_list(), [10, 2])
if __name__ == '__main__':
test.main()