Fix issue where calling plot_model on Functional model that uses add_loss would crash due to model._layers containing DictWrapper objects.

PiperOrigin-RevId: 313237777
Change-Id: I1e9685242f3c5d887340fbcfed6f4709681c7cb7
This commit is contained in:
Francois Chollet 2020-05-26 11:50:13 -07:00 committed by TensorFlower Gardener
parent 48296300d6
commit dd7849ed4c
2 changed files with 29 additions and 1 deletions

View File

@ -129,6 +129,7 @@ def model_to_dot(model,
sub_w_first_node = {}
sub_w_last_node = {}
layers = model.layers
if not model._is_graph_network:
node = pydot.Node(str(id(model)), label=model.name)
dot.add_node(node)
@ -136,7 +137,7 @@ def model_to_dot(model,
elif isinstance(model, sequential.Sequential):
if not model.built:
model.build()
layers = model._layers
layers = super(sequential.Sequential, model).layers
# Create graph nodes.
for i, layer in enumerate(layers):

View File

@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python import keras
from tensorflow.python.keras.utils import vis_utils
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@ -67,6 +68,32 @@ class ModelToDotFormatTest(test.TestCase):
except ImportError:
pass
def test_plot_model_with_add_loss(self):
inputs = keras.Input(shape=(None, 3))
outputs = keras.layers.Dense(1)(inputs)
model = keras.Model(inputs, outputs)
model.add_loss(math_ops.reduce_mean(outputs))
dot_img_file = 'model_3.png'
try:
vis_utils.plot_model(
model, to_file=dot_img_file, show_shapes=True, expand_nested=True)
self.assertTrue(file_io.file_exists(dot_img_file))
file_io.delete_file(dot_img_file)
except ImportError:
pass
model = keras.Sequential([
keras.Input(shape=(None, 3)), keras.layers.Dense(1)])
model.add_loss(math_ops.reduce_mean(model.output))
dot_img_file = 'model_4.png'
try:
vis_utils.plot_model(
model, to_file=dot_img_file, show_shapes=True, expand_nested=True)
self.assertTrue(file_io.file_exists(dot_img_file))
file_io.delete_file(dot_img_file)
except ImportError:
pass
if __name__ == '__main__':
test.main()