diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index e13ab8f0b92..148df242e48 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -51,6 +51,7 @@ from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging as logging @@ -238,6 +239,9 @@ class Network(base_layer.Layer): outputs = outputs[0] self._nested_outputs = outputs self._nested_inputs = inputs + self._nested_inputs_are_flat_list = ( + isinstance(self._nested_inputs, (list, tuple)) and + not any(nest.is_sequence(t) for t in self._nested_inputs)) self.inputs = nest.flatten(inputs) self.outputs = nest.flatten(outputs) @@ -814,40 +818,18 @@ class Network(base_layer.Layer): # use masking, it does not interfere with regular behavior at all and you # can ignore it. - if isinstance(inputs, dict) and isinstance(self._nested_inputs, - (list, tuple)): - # Backwards compat: Allows passing a dict to a Model constructed with a - # list. Matches dict keys to input names. - inputs = [ - inputs[inp._keras_history.layer.name] for inp in self._nested_inputs - ] - else: - inputs = nest.flatten(inputs) - + inputs = self._flatten_to_reference_inputs(inputs) if mask is None: masks = [None for _ in range(len(inputs))] else: - masks = nest.flatten(mask) - + masks = self._flatten_to_reference_inputs(mask) for input_t, mask in zip(inputs, masks): input_t._keras_mask = mask # Dictionary mapping reference tensors to computed tensors. tensor_dict = {} - for x, y in zip(self.inputs, inputs): - # Set shape and dtype based on `keras.Input`s. 
- if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor): - try: - y.set_shape(y.shape.merge_with(x.shape)) - except ValueError: - logging.warning( - 'Model was constructed with shape {} for input {}, but it was ' - 're-called on a Tensor with incompatible shape {}.' - .format(x, x.shape, y.shape)) - if isinstance(x, (ops.Tensor, composite_tensor.CompositeTensor)): - y = math_ops.cast(y, dtype=x.dtype) - + y = self._conform_to_reference_input(y, ref_input=x) x_id = str(id(x)) tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id] @@ -925,6 +907,53 @@ class Network(base_layer.Layer): output_tensors = nest.pack_sequence_as(self._nested_outputs, output_tensors) return output_tensors + def _flatten_to_reference_inputs(self, tensors): + """Maps `tensors` to their respective `keras.Input`.""" + if self._nested_inputs_are_flat_list and isinstance(tensors, dict): + # Backwards compat: Allows passing a dict to a Model constructed with a + # list. Matches dict keys to input names. + tensors = [ + tensors[inp._keras_history.layer.name] for inp in self._nested_inputs + ] + else: + # Otherwise both self.inputs and tensors will be flattened in the same order. + tensors = nest.flatten(tensors) + return tensors + + def _conform_to_reference_input(self, tensor, ref_input): + """Set shape and dtype based on `keras.Input`s.""" + # Shape handling (only for non-CompositeTensors). + if isinstance(tensor, ops.Tensor) and isinstance(ref_input, ops.Tensor): + # Allow (None,) and (None, 1) Tensors to be passed interchangeably. Use the + # shape specified by the `keras.Input`. 
+ if tensor.shape.rank is not None and ref_input.shape.rank is not None: + should_squeeze_last_dim = ( + tensor.shape.rank == ref_input.shape.rank + 1 and + tensor.shape[-1] == 1) + should_expand_last_dim = ( + tensor.shape.rank == ref_input.shape.rank - 1 and + ref_input.shape[-1] == 1) + if should_squeeze_last_dim: + tensor = array_ops.squeeze_v2(tensor, axis=-1) + elif should_expand_last_dim: + tensor = array_ops.expand_dims_v2(tensor, axis=-1) + + # Add shape hints to Tensors that might have None shape dims but have + # shapes defined by the `keras.Input`. + try: + tensor.set_shape(tensor.shape.merge_with(ref_input.shape)) + except ValueError: + logging.warning( + 'Model was constructed with shape {} for input {}, but it was ' + 'called on an input with incompatible shape {}.'.format( + ref_input.shape, ref_input, tensor.shape)) + + # Dtype handling. + if isinstance(ref_input, (ops.Tensor, composite_tensor.CompositeTensor)): + tensor = math_ops.cast(tensor, dtype=ref_input.dtype) + + return tensor + def get_config(self): if not self._is_graph_network: raise NotImplementedError diff --git a/tensorflow/python/keras/engine/network_test.py b/tensorflow/python/keras/engine/network_test.py index 17f08889936..d890cc118ae 100644 --- a/tensorflow/python/keras/engine/network_test.py +++ b/tensorflow/python/keras/engine/network_test.py @@ -1879,6 +1879,24 @@ class CacheCorrectnessTest(keras_parameterized.TestCase): for i in range(999, 1024): self.assertEqual(network.compute_output_shape((1, i, 32)), (1, i, 2)) + def test_2d_inputs_squeezed_to_1d(self): + input_1d = input_layer_lib.Input(shape=()) + outputs = input_1d * 2. + net = network_lib.Network(input_1d, outputs) + + x = np.ones((10, 1)) + y = net(x) + self.assertEqual(y.shape.rank, 1) + + def test_1d_inputs_expanded_to_2d(self): + input_1d = input_layer_lib.Input(shape=(1,)) + outputs = input_1d * 2. 
+ net = network_lib.Network(input_1d, outputs) + + x = np.ones((10,)) + y = net(x) + self.assertEqual(y.shape.rank, 2) + if __name__ == '__main__': test.main()