Fix connectivity metadata missing issue on functional model.
PiperOrigin-RevId: 315381292 Change-Id: I3a7e61a0afbda0ae1984a4152dd297e350b29775
This commit is contained in:
parent
b546504c58
commit
939772a64e
@ -544,6 +544,7 @@ class Functional(training_lib.Model):
|
|||||||
t_rank = t_shape.rank
|
t_rank = t_shape.rank
|
||||||
ref_shape = ref_input.shape
|
ref_shape = ref_input.shape
|
||||||
ref_rank = ref_shape.rank
|
ref_rank = ref_shape.rank
|
||||||
|
keras_history = getattr(tensor, '_keras_history', None)
|
||||||
if t_rank is not None and ref_rank is not None:
|
if t_rank is not None and ref_rank is not None:
|
||||||
# Should squeeze last dimension.
|
# Should squeeze last dimension.
|
||||||
# True if tensor is (BATCH, ..., 1) and reference is (BATCH, ...).
|
# True if tensor is (BATCH, ..., 1) and reference is (BATCH, ...).
|
||||||
@ -553,6 +554,8 @@ class Functional(training_lib.Model):
|
|||||||
# True if tensor is (BATCH, ...) and reference is (BATCH, ..., 1).
|
# True if tensor is (BATCH, ...) and reference is (BATCH, ..., 1).
|
||||||
elif (t_rank == ref_rank - 1 and ref_shape[-1] == 1):
|
elif (t_rank == ref_rank - 1 and ref_shape[-1] == 1):
|
||||||
tensor = array_ops.expand_dims_v2(tensor, axis=-1)
|
tensor = array_ops.expand_dims_v2(tensor, axis=-1)
|
||||||
|
if keras_history is not None: # Restore keras history.
|
||||||
|
tensor._keras_history = keras_history
|
||||||
|
|
||||||
# Add shape hints to Tensors that may have None shape dims but have shapes
|
# Add shape hints to Tensors that may have None shape dims but have shapes
|
||||||
# defined by the `keras.Input` (not applicable in eager mode).
|
# defined by the `keras.Input` (not applicable in eager mode).
|
||||||
|
@ -2068,5 +2068,18 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
|
|||||||
# `None` value passed during construction is overridden.
|
# `None` value passed during construction is overridden.
|
||||||
self.assertAllEqual(network(x, training=False), x * 0.0)
|
self.assertAllEqual(network(x, training=False), x * 0.0)
|
||||||
|
|
||||||
|
def test_keras_history_propagation_(self):
|
||||||
|
for input_shape in [(1,), (1, 1)]:
|
||||||
|
sub_in = input_layer_lib.Input((1,))
|
||||||
|
relu_layer = layers.ReLU()
|
||||||
|
sub_out = relu_layer(sub_in)
|
||||||
|
submodel = functional.Functional(sub_in, sub_out)
|
||||||
|
self.assertLen(relu_layer._inbound_nodes, 1)
|
||||||
|
|
||||||
|
inp = input_layer_lib.Input(input_shape)
|
||||||
|
submodel(inp)
|
||||||
|
self.assertLen(relu_layer._inbound_nodes, 2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user