Fix config of network with disconnected input layers.
PiperOrigin-RevId: 254105005
This commit is contained in:
parent
9caacf0ac5
commit
401bbfc336
@ -1796,12 +1796,15 @@ def _map_graph_network(inputs, outputs):
|
|||||||
nodes_depths[inbound_node] = max(depth + 1, previous_depth)
|
nodes_depths[inbound_node] = max(depth + 1, previous_depth)
|
||||||
|
|
||||||
# Handle inputs that are not connected to outputs.
|
# Handle inputs that are not connected to outputs.
|
||||||
|
# We do not error out here because the inputs may be used to compute losses
|
||||||
|
# and metrics.
|
||||||
for input_t in inputs:
|
for input_t in inputs:
|
||||||
input_layer = input_t._keras_history[0]
|
input_layer = input_t._keras_history[0]
|
||||||
if input_layer not in layers_depths:
|
if input_layer not in layers_depths:
|
||||||
layers_depths[input_layer] = 0
|
layers_depths[input_layer] = 0
|
||||||
layer_indices[input_layer] = -1
|
layer_indices[input_layer] = -1
|
||||||
nodes_depths[input_layer._inbound_nodes[0]] = 0
|
nodes_depths[input_layer._inbound_nodes[0]] = 0
|
||||||
|
network_nodes.add(_make_node_key(input_layer.name, 0))
|
||||||
|
|
||||||
# Build a dict {depth: list of nodes with this depth}
|
# Build a dict {depth: list of nodes with this depth}
|
||||||
nodes_by_depth = collections.defaultdict(list)
|
nodes_by_depth = collections.defaultdict(list)
|
||||||
|
@ -968,6 +968,18 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
|||||||
w = model.add_weight('w', [], initializer=keras.initializers.Constant(1))
|
w = model.add_weight('w', [], initializer=keras.initializers.Constant(1))
|
||||||
self.assertEqual(dtypes.int64, w.dtype)
|
self.assertEqual(dtypes.int64, w.dtype)
|
||||||
|
|
||||||
|
def test_disconnected_inputs(self):
|
||||||
|
input_tensor1 = input_layer_lib.Input(shape=[200], name='a')
|
||||||
|
input_tensor2 = input_layer_lib.Input(shape=[10], name='b')
|
||||||
|
output_tensor1 = keras.layers.Dense(units=10)(input_tensor1)
|
||||||
|
|
||||||
|
net = keras.engine.network.Network(
|
||||||
|
inputs=[input_tensor1, input_tensor2], outputs=[output_tensor1])
|
||||||
|
net2 = keras.engine.network.Network.from_config(net.get_config())
|
||||||
|
self.assertLen(net2.inputs, 2)
|
||||||
|
self.assertEqual('a', net2.layers[0].name)
|
||||||
|
self.assertEqual('b', net2.layers[1].name)
|
||||||
|
|
||||||
|
|
||||||
class DeferredModeTest(test.TestCase):
|
class DeferredModeTest(test.TestCase):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user