Fix config of network with disconnected input layers.
PiperOrigin-RevId: 254105005
This commit is contained in:
parent
9caacf0ac5
commit
401bbfc336
@ -1796,12 +1796,15 @@ def _map_graph_network(inputs, outputs):
|
||||
nodes_depths[inbound_node] = max(depth + 1, previous_depth)
|
||||
|
||||
# Handle inputs that are not connected to outputs.
|
||||
# We do not error out here because the inputs may be used to compute losses
|
||||
# and metrics.
|
||||
for input_t in inputs:
|
||||
input_layer = input_t._keras_history[0]
|
||||
if input_layer not in layers_depths:
|
||||
layers_depths[input_layer] = 0
|
||||
layer_indices[input_layer] = -1
|
||||
nodes_depths[input_layer._inbound_nodes[0]] = 0
|
||||
network_nodes.add(_make_node_key(input_layer.name, 0))
|
||||
|
||||
# Build a dict {depth: list of nodes with this depth}
|
||||
nodes_by_depth = collections.defaultdict(list)
|
||||
|
@ -968,6 +968,18 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
w = model.add_weight('w', [], initializer=keras.initializers.Constant(1))
|
||||
self.assertEqual(dtypes.int64, w.dtype)
|
||||
|
||||
def test_disconnected_inputs(self):
|
||||
input_tensor1 = input_layer_lib.Input(shape=[200], name='a')
|
||||
input_tensor2 = input_layer_lib.Input(shape=[10], name='b')
|
||||
output_tensor1 = keras.layers.Dense(units=10)(input_tensor1)
|
||||
|
||||
net = keras.engine.network.Network(
|
||||
inputs=[input_tensor1, input_tensor2], outputs=[output_tensor1])
|
||||
net2 = keras.engine.network.Network.from_config(net.get_config())
|
||||
self.assertLen(net2.inputs, 2)
|
||||
self.assertEqual('a', net2.layers[0].name)
|
||||
self.assertEqual('b', net2.layers[1].name)
|
||||
|
||||
|
||||
class DeferredModeTest(test.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user