Fix connectivity metadata missing issue on functional model.
PiperOrigin-RevId: 315381292 Change-Id: I3a7e61a0afbda0ae1984a4152dd297e350b29775
This commit is contained in:
parent
b546504c58
commit
939772a64e
@ -544,6 +544,7 @@ class Functional(training_lib.Model):
|
||||
t_rank = t_shape.rank
|
||||
ref_shape = ref_input.shape
|
||||
ref_rank = ref_shape.rank
|
||||
keras_history = getattr(tensor, '_keras_history', None)
|
||||
if t_rank is not None and ref_rank is not None:
|
||||
# Should squeeze last dimension.
|
||||
# True if tensor is (BATCH, ..., 1) and reference is (BATCH, ...).
|
||||
@ -553,6 +554,8 @@ class Functional(training_lib.Model):
|
||||
# True if tensor is (BATCH, ...) and reference is (BATCH, ..., 1).
|
||||
elif (t_rank == ref_rank - 1 and ref_shape[-1] == 1):
|
||||
tensor = array_ops.expand_dims_v2(tensor, axis=-1)
|
||||
if keras_history is not None: # Restore keras history.
|
||||
tensor._keras_history = keras_history
|
||||
|
||||
# Add shape hints to Tensors that may have None shape dims but have shapes
|
||||
# defined by the `keras.Input` (not applicable in eager mode).
|
||||
|
@ -2068,5 +2068,18 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
|
||||
# `None` value passed during construction is overridden.
|
||||
self.assertAllEqual(network(x, training=False), x * 0.0)
|
||||
|
||||
def test_keras_history_propagation_(self):
|
||||
for input_shape in [(1,), (1, 1)]:
|
||||
sub_in = input_layer_lib.Input((1,))
|
||||
relu_layer = layers.ReLU()
|
||||
sub_out = relu_layer(sub_in)
|
||||
submodel = functional.Functional(sub_in, sub_out)
|
||||
self.assertLen(relu_layer._inbound_nodes, 1)
|
||||
|
||||
inp = input_layer_lib.Input(input_shape)
|
||||
submodel(inp)
|
||||
self.assertLen(relu_layer._inbound_nodes, 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user