Roll forward of more robust Network Input conformance.
PiperOrigin-RevId: 297470882 Change-Id: Ic78cbd2448cf3c28282ebadcb4491096f0d38c3f
This commit is contained in:
parent
2bc72ad541
commit
cddf574279
@ -51,6 +51,7 @@ from tensorflow.python.keras.utils import generic_utils
|
|||||||
from tensorflow.python.keras.utils import layer_utils
|
from tensorflow.python.keras.utils import layer_utils
|
||||||
from tensorflow.python.keras.utils import tf_utils
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
@ -238,6 +239,9 @@ class Network(base_layer.Layer):
|
|||||||
outputs = outputs[0]
|
outputs = outputs[0]
|
||||||
self._nested_outputs = outputs
|
self._nested_outputs = outputs
|
||||||
self._nested_inputs = inputs
|
self._nested_inputs = inputs
|
||||||
|
self._nested_inputs_are_flat_list = (
|
||||||
|
isinstance(self._nested_inputs, (list, tuple)) and
|
||||||
|
not any(nest.is_sequence(t) for t in self._nested_inputs))
|
||||||
self.inputs = nest.flatten(inputs)
|
self.inputs = nest.flatten(inputs)
|
||||||
self.outputs = nest.flatten(outputs)
|
self.outputs = nest.flatten(outputs)
|
||||||
|
|
||||||
@ -814,40 +818,18 @@ class Network(base_layer.Layer):
|
|||||||
# use masking, it does not interfere with regular behavior at all and you
|
# use masking, it does not interfere with regular behavior at all and you
|
||||||
# can ignore it.
|
# can ignore it.
|
||||||
|
|
||||||
if isinstance(inputs, dict) and isinstance(self._nested_inputs,
|
inputs = self._flatten_to_reference_inputs(inputs)
|
||||||
(list, tuple)):
|
|
||||||
# Backwards compat: Allows passing a dict to a Model constructed with a
|
|
||||||
# list. Matches dict keys to input names.
|
|
||||||
inputs = [
|
|
||||||
inputs[inp._keras_history.layer.name] for inp in self._nested_inputs
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
inputs = nest.flatten(inputs)
|
|
||||||
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
masks = [None for _ in range(len(inputs))]
|
masks = [None for _ in range(len(inputs))]
|
||||||
else:
|
else:
|
||||||
masks = nest.flatten(mask)
|
masks = self._flatten_to_reference_inputs(mask)
|
||||||
|
|
||||||
for input_t, mask in zip(inputs, masks):
|
for input_t, mask in zip(inputs, masks):
|
||||||
input_t._keras_mask = mask
|
input_t._keras_mask = mask
|
||||||
|
|
||||||
# Dictionary mapping reference tensors to computed tensors.
|
# Dictionary mapping reference tensors to computed tensors.
|
||||||
tensor_dict = {}
|
tensor_dict = {}
|
||||||
|
|
||||||
for x, y in zip(self.inputs, inputs):
|
for x, y in zip(self.inputs, inputs):
|
||||||
# Set shape and dtype based on `keras.Input`s.
|
y = self._conform_to_reference_input(y, ref_input=x)
|
||||||
if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
|
|
||||||
try:
|
|
||||||
y.set_shape(y.shape.merge_with(x.shape))
|
|
||||||
except ValueError:
|
|
||||||
logging.warning(
|
|
||||||
'Model was constructed with shape {} for input {}, but it was '
|
|
||||||
're-called on a Tensor with incompatible shape {}.'
|
|
||||||
.format(x, x.shape, y.shape))
|
|
||||||
if isinstance(x, (ops.Tensor, composite_tensor.CompositeTensor)):
|
|
||||||
y = math_ops.cast(y, dtype=x.dtype)
|
|
||||||
|
|
||||||
x_id = str(id(x))
|
x_id = str(id(x))
|
||||||
tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]
|
tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]
|
||||||
|
|
||||||
@ -925,6 +907,53 @@ class Network(base_layer.Layer):
|
|||||||
output_tensors = nest.pack_sequence_as(self._nested_outputs, output_tensors)
|
output_tensors = nest.pack_sequence_as(self._nested_outputs, output_tensors)
|
||||||
return output_tensors
|
return output_tensors
|
||||||
|
|
||||||
|
def _flatten_to_reference_inputs(self, tensors):
|
||||||
|
"""Maps `tensors` to their respective `keras.Input`."""
|
||||||
|
if self._nested_inputs_are_flat_list and isinstance(tensors, dict):
|
||||||
|
# Backwards compat: Allows passing a dict to a Model constructed with a
|
||||||
|
# list. Matches dict keys to input names.
|
||||||
|
tensors = [
|
||||||
|
tensors[inp._keras_history.layer.name] for inp in self._nested_inputs
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# Otherwise both self.inputs and tensors will be flattened in same order.
|
||||||
|
tensors = nest.flatten(tensors)
|
||||||
|
return tensors
|
||||||
|
|
||||||
|
def _conform_to_reference_input(self, tensor, ref_input):
|
||||||
|
"""Set shape and dtype based on `keras.Input`s."""
|
||||||
|
# Shape handling (only for non-CompositeTensors).
|
||||||
|
if isinstance(tensor, ops.Tensor) and isinstance(ref_input, ops.Tensor):
|
||||||
|
# Allow (None,) and (None, 1) Tensors to be passed interchangably. Use the
|
||||||
|
# shape specified by the `keras.Input`.
|
||||||
|
if tensor.shape.rank is not None and ref_input.shape.rank is not None:
|
||||||
|
should_squeeze_last_dim = (
|
||||||
|
tensor.shape.rank == ref_input.shape.rank + 1 and
|
||||||
|
tensor.shape[-1] == 1)
|
||||||
|
should_expand_last_dim = (
|
||||||
|
tensor.shape.rank == ref_input.shape.rank - 1 and
|
||||||
|
ref_input.shape[-1] == 1)
|
||||||
|
if should_squeeze_last_dim:
|
||||||
|
tensor = array_ops.squeeze_v2(tensor, axis=-1)
|
||||||
|
elif should_expand_last_dim:
|
||||||
|
tensor = array_ops.expand_dims_v2(tensor, axis=-1)
|
||||||
|
|
||||||
|
# Add shape hints to Tensors that might have None shape dims but have
|
||||||
|
# shapes defined by the `keras.Input`.
|
||||||
|
try:
|
||||||
|
tensor.set_shape(tensor.shape.merge_with(ref_input.shape))
|
||||||
|
except ValueError:
|
||||||
|
logging.warning(
|
||||||
|
'Model was constructed with shape {} for input {}, but it was '
|
||||||
|
'called on an input with incompatible shape {}.'.format(
|
||||||
|
ref_input.shape, ref_input, tensor.shape))
|
||||||
|
|
||||||
|
# Dtype handling.
|
||||||
|
if isinstance(ref_input, (ops.Tensor, composite_tensor.CompositeTensor)):
|
||||||
|
tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
|
||||||
|
|
||||||
|
return tensor
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
if not self._is_graph_network:
|
if not self._is_graph_network:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -1879,6 +1879,24 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
|
|||||||
for i in range(999, 1024):
|
for i in range(999, 1024):
|
||||||
self.assertEqual(network.compute_output_shape((1, i, 32)), (1, i, 2))
|
self.assertEqual(network.compute_output_shape((1, i, 32)), (1, i, 2))
|
||||||
|
|
||||||
|
def test_2d_inputs_squeezed_to_1d(self):
|
||||||
|
input_1d = input_layer_lib.Input(shape=())
|
||||||
|
outputs = input_1d * 2.
|
||||||
|
net = network_lib.Network(input_1d, outputs)
|
||||||
|
|
||||||
|
x = np.ones((10, 1))
|
||||||
|
y = net(x)
|
||||||
|
self.assertEqual(y.shape.rank, 1)
|
||||||
|
|
||||||
|
def test_1d_inputs_expanded_to_2d(self):
|
||||||
|
input_1d = input_layer_lib.Input(shape=(1,))
|
||||||
|
outputs = input_1d * 2.
|
||||||
|
net = network_lib.Network(input_1d, outputs)
|
||||||
|
|
||||||
|
x = np.ones((10,))
|
||||||
|
y = net(x)
|
||||||
|
self.assertEqual(y.shape.rank, 2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user