Roll forward of more robust Network Input conformance.

PiperOrigin-RevId: 297470882
Change-Id: Ic78cbd2448cf3c28282ebadcb4491096f0d38c3f
This commit is contained in:
Thomas O'Malley 2020-02-26 16:39:13 -08:00 committed by TensorFlower Gardener
parent 2bc72ad541
commit cddf574279
2 changed files with 72 additions and 25 deletions

View File

@ -51,6 +51,7 @@ from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
@ -238,6 +239,9 @@ class Network(base_layer.Layer):
outputs = outputs[0]
self._nested_outputs = outputs
self._nested_inputs = inputs
self._nested_inputs_are_flat_list = (
isinstance(self._nested_inputs, (list, tuple)) and
not any(nest.is_sequence(t) for t in self._nested_inputs))
self.inputs = nest.flatten(inputs)
self.outputs = nest.flatten(outputs)
@ -814,40 +818,18 @@ class Network(base_layer.Layer):
# use masking, it does not interfere with regular behavior at all and you
# can ignore it.
if isinstance(inputs, dict) and isinstance(self._nested_inputs,
(list, tuple)):
# Backwards compat: Allows passing a dict to a Model constructed with a
# list. Matches dict keys to input names.
inputs = [
inputs[inp._keras_history.layer.name] for inp in self._nested_inputs
]
else:
inputs = nest.flatten(inputs)
inputs = self._flatten_to_reference_inputs(inputs)
if mask is None:
masks = [None for _ in range(len(inputs))]
else:
masks = nest.flatten(mask)
masks = self._flatten_to_reference_inputs(mask)
for input_t, mask in zip(inputs, masks):
input_t._keras_mask = mask
# Dictionary mapping reference tensors to computed tensors.
tensor_dict = {}
for x, y in zip(self.inputs, inputs):
# Set shape and dtype based on `keras.Input`s.
if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
try:
y.set_shape(y.shape.merge_with(x.shape))
except ValueError:
logging.warning(
'Model was constructed with shape {} for input {}, but it was '
're-called on a Tensor with incompatible shape {}.'
.format(x, x.shape, y.shape))
if isinstance(x, (ops.Tensor, composite_tensor.CompositeTensor)):
y = math_ops.cast(y, dtype=x.dtype)
y = self._conform_to_reference_input(y, ref_input=x)
x_id = str(id(x))
tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]
@ -925,6 +907,53 @@ class Network(base_layer.Layer):
output_tensors = nest.pack_sequence_as(self._nested_outputs, output_tensors)
return output_tensors
def _flatten_to_reference_inputs(self, tensors):
"""Maps `tensors` to their respective `keras.Input`."""
if self._nested_inputs_are_flat_list and isinstance(tensors, dict):
# Backwards compat: Allows passing a dict to a Model constructed with a
# list. Matches dict keys to input names.
tensors = [
tensors[inp._keras_history.layer.name] for inp in self._nested_inputs
]
else:
# Otherwise both self.inputs and tensors will be flattened in same order.
tensors = nest.flatten(tensors)
return tensors
def _conform_to_reference_input(self, tensor, ref_input):
"""Set shape and dtype based on `keras.Input`s."""
# Shape handling (only for non-CompositeTensors).
if isinstance(tensor, ops.Tensor) and isinstance(ref_input, ops.Tensor):
# Allow (None,) and (None, 1) Tensors to be passed interchangably. Use the
# shape specified by the `keras.Input`.
if tensor.shape.rank is not None and ref_input.shape.rank is not None:
should_squeeze_last_dim = (
tensor.shape.rank == ref_input.shape.rank + 1 and
tensor.shape[-1] == 1)
should_expand_last_dim = (
tensor.shape.rank == ref_input.shape.rank - 1 and
ref_input.shape[-1] == 1)
if should_squeeze_last_dim:
tensor = array_ops.squeeze_v2(tensor, axis=-1)
elif should_expand_last_dim:
tensor = array_ops.expand_dims_v2(tensor, axis=-1)
# Add shape hints to Tensors that might have None shape dims but have
# shapes defined by the `keras.Input`.
try:
tensor.set_shape(tensor.shape.merge_with(ref_input.shape))
except ValueError:
logging.warning(
'Model was constructed with shape {} for input {}, but it was '
'called on an input with incompatible shape {}.'.format(
ref_input.shape, ref_input, tensor.shape))
# Dtype handling.
if isinstance(ref_input, (ops.Tensor, composite_tensor.CompositeTensor)):
tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
return tensor
def get_config(self):
if not self._is_graph_network:
raise NotImplementedError

View File

@ -1879,6 +1879,24 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
for i in range(999, 1024):
self.assertEqual(network.compute_output_shape((1, i, 32)), (1, i, 2))
def test_2d_inputs_squeezed_to_1d(self):
input_1d = input_layer_lib.Input(shape=())
outputs = input_1d * 2.
net = network_lib.Network(input_1d, outputs)
x = np.ones((10, 1))
y = net(x)
self.assertEqual(y.shape.rank, 1)
def test_1d_inputs_expanded_to_2d(self):
input_1d = input_layer_lib.Input(shape=(1,))
outputs = input_1d * 2.
net = network_lib.Network(input_1d, outputs)
x = np.ones((10,))
y = net(x)
self.assertEqual(y.shape.rank, 2)
if __name__ == '__main__':
test.main()