Fix config of network with disconnected input layers.

PiperOrigin-RevId: 254105005
This commit is contained in:
Katherine Wu 2019-06-19 17:27:40 -07:00 committed by TensorFlower Gardener
parent 9caacf0ac5
commit 401bbfc336
2 changed files with 15 additions and 0 deletions

View File

@ -1796,12 +1796,15 @@ def _map_graph_network(inputs, outputs):
nodes_depths[inbound_node] = max(depth + 1, previous_depth)
# Handle inputs that are not connected to outputs.
# We do not error out here because the inputs may be used to compute losses
# and metrics.
for input_t in inputs:
input_layer = input_t._keras_history[0]
if input_layer not in layers_depths:
layers_depths[input_layer] = 0
layer_indices[input_layer] = -1
nodes_depths[input_layer._inbound_nodes[0]] = 0
network_nodes.add(_make_node_key(input_layer.name, 0))
# Build a dict {depth: list of nodes with this depth}
nodes_by_depth = collections.defaultdict(list)

View File

@ -968,6 +968,18 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
w = model.add_weight('w', [], initializer=keras.initializers.Constant(1))
self.assertEqual(dtypes.int64, w.dtype)
def test_disconnected_inputs(self):
input_tensor1 = input_layer_lib.Input(shape=[200], name='a')
input_tensor2 = input_layer_lib.Input(shape=[10], name='b')
output_tensor1 = keras.layers.Dense(units=10)(input_tensor1)
net = keras.engine.network.Network(
inputs=[input_tensor1, input_tensor2], outputs=[output_tensor1])
net2 = keras.engine.network.Network.from_config(net.get_config())
self.assertLen(net2.inputs, 2)
self.assertEqual('a', net2.layers[0].name)
self.assertEqual('b', net2.layers[1].name)
class DeferredModeTest(test.TestCase):