Roll forward of more robust Network Input conformance.
PiperOrigin-RevId: 297470882 Change-Id: Ic78cbd2448cf3c28282ebadcb4491096f0d38c3f
This commit is contained in:
parent
2bc72ad541
commit
cddf574279
@ -51,6 +51,7 @@ from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import layer_utils
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
@ -238,6 +239,9 @@ class Network(base_layer.Layer):
|
||||
outputs = outputs[0]
|
||||
self._nested_outputs = outputs
|
||||
self._nested_inputs = inputs
|
||||
self._nested_inputs_are_flat_list = (
|
||||
isinstance(self._nested_inputs, (list, tuple)) and
|
||||
not any(nest.is_sequence(t) for t in self._nested_inputs))
|
||||
self.inputs = nest.flatten(inputs)
|
||||
self.outputs = nest.flatten(outputs)
|
||||
|
||||
@ -814,40 +818,18 @@ class Network(base_layer.Layer):
|
||||
# use masking, it does not interfere with regular behavior at all and you
|
||||
# can ignore it.
|
||||
|
||||
if isinstance(inputs, dict) and isinstance(self._nested_inputs,
|
||||
(list, tuple)):
|
||||
# Backwards compat: Allows passing a dict to a Model constructed with a
|
||||
# list. Matches dict keys to input names.
|
||||
inputs = [
|
||||
inputs[inp._keras_history.layer.name] for inp in self._nested_inputs
|
||||
]
|
||||
else:
|
||||
inputs = nest.flatten(inputs)
|
||||
|
||||
inputs = self._flatten_to_reference_inputs(inputs)
|
||||
if mask is None:
|
||||
masks = [None for _ in range(len(inputs))]
|
||||
else:
|
||||
masks = nest.flatten(mask)
|
||||
|
||||
masks = self._flatten_to_reference_inputs(mask)
|
||||
for input_t, mask in zip(inputs, masks):
|
||||
input_t._keras_mask = mask
|
||||
|
||||
# Dictionary mapping reference tensors to computed tensors.
|
||||
tensor_dict = {}
|
||||
|
||||
for x, y in zip(self.inputs, inputs):
|
||||
# Set shape and dtype based on `keras.Input`s.
|
||||
if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
|
||||
try:
|
||||
y.set_shape(y.shape.merge_with(x.shape))
|
||||
except ValueError:
|
||||
logging.warning(
|
||||
'Model was constructed with shape {} for input {}, but it was '
|
||||
're-called on a Tensor with incompatible shape {}.'
|
||||
.format(x, x.shape, y.shape))
|
||||
if isinstance(x, (ops.Tensor, composite_tensor.CompositeTensor)):
|
||||
y = math_ops.cast(y, dtype=x.dtype)
|
||||
|
||||
y = self._conform_to_reference_input(y, ref_input=x)
|
||||
x_id = str(id(x))
|
||||
tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]
|
||||
|
||||
@ -925,6 +907,53 @@ class Network(base_layer.Layer):
|
||||
output_tensors = nest.pack_sequence_as(self._nested_outputs, output_tensors)
|
||||
return output_tensors
|
||||
|
||||
def _flatten_to_reference_inputs(self, tensors):
|
||||
"""Maps `tensors` to their respective `keras.Input`."""
|
||||
if self._nested_inputs_are_flat_list and isinstance(tensors, dict):
|
||||
# Backwards compat: Allows passing a dict to a Model constructed with a
|
||||
# list. Matches dict keys to input names.
|
||||
tensors = [
|
||||
tensors[inp._keras_history.layer.name] for inp in self._nested_inputs
|
||||
]
|
||||
else:
|
||||
# Otherwise both self.inputs and tensors will be flattened in same order.
|
||||
tensors = nest.flatten(tensors)
|
||||
return tensors
|
||||
|
||||
def _conform_to_reference_input(self, tensor, ref_input):
|
||||
"""Set shape and dtype based on `keras.Input`s."""
|
||||
# Shape handling (only for non-CompositeTensors).
|
||||
if isinstance(tensor, ops.Tensor) and isinstance(ref_input, ops.Tensor):
|
||||
# Allow (None,) and (None, 1) Tensors to be passed interchangably. Use the
|
||||
# shape specified by the `keras.Input`.
|
||||
if tensor.shape.rank is not None and ref_input.shape.rank is not None:
|
||||
should_squeeze_last_dim = (
|
||||
tensor.shape.rank == ref_input.shape.rank + 1 and
|
||||
tensor.shape[-1] == 1)
|
||||
should_expand_last_dim = (
|
||||
tensor.shape.rank == ref_input.shape.rank - 1 and
|
||||
ref_input.shape[-1] == 1)
|
||||
if should_squeeze_last_dim:
|
||||
tensor = array_ops.squeeze_v2(tensor, axis=-1)
|
||||
elif should_expand_last_dim:
|
||||
tensor = array_ops.expand_dims_v2(tensor, axis=-1)
|
||||
|
||||
# Add shape hints to Tensors that might have None shape dims but have
|
||||
# shapes defined by the `keras.Input`.
|
||||
try:
|
||||
tensor.set_shape(tensor.shape.merge_with(ref_input.shape))
|
||||
except ValueError:
|
||||
logging.warning(
|
||||
'Model was constructed with shape {} for input {}, but it was '
|
||||
'called on an input with incompatible shape {}.'.format(
|
||||
ref_input.shape, ref_input, tensor.shape))
|
||||
|
||||
# Dtype handling.
|
||||
if isinstance(ref_input, (ops.Tensor, composite_tensor.CompositeTensor)):
|
||||
tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
|
||||
|
||||
return tensor
|
||||
|
||||
def get_config(self):
|
||||
if not self._is_graph_network:
|
||||
raise NotImplementedError
|
||||
|
@ -1879,6 +1879,24 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
|
||||
for i in range(999, 1024):
|
||||
self.assertEqual(network.compute_output_shape((1, i, 32)), (1, i, 2))
|
||||
|
||||
def test_2d_inputs_squeezed_to_1d(self):
|
||||
input_1d = input_layer_lib.Input(shape=())
|
||||
outputs = input_1d * 2.
|
||||
net = network_lib.Network(input_1d, outputs)
|
||||
|
||||
x = np.ones((10, 1))
|
||||
y = net(x)
|
||||
self.assertEqual(y.shape.rank, 1)
|
||||
|
||||
def test_1d_inputs_expanded_to_2d(self):
|
||||
input_1d = input_layer_lib.Input(shape=(1,))
|
||||
outputs = input_1d * 2.
|
||||
net = network_lib.Network(input_1d, outputs)
|
||||
|
||||
x = np.ones((10,))
|
||||
y = net(x)
|
||||
self.assertEqual(y.shape.rank, 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user