Fix issue where calling plot_model on Functional model that uses add_loss would crash due to model._layers containing DictWrapper objects.
PiperOrigin-RevId: 313237777 Change-Id: I1e9685242f3c5d887340fbcfed6f4709681c7cb7
This commit is contained in:
parent
48296300d6
commit
dd7849ed4c
|
@ -129,6 +129,7 @@ def model_to_dot(model,
|
|||
sub_w_first_node = {}
|
||||
sub_w_last_node = {}
|
||||
|
||||
layers = model.layers
|
||||
if not model._is_graph_network:
|
||||
node = pydot.Node(str(id(model)), label=model.name)
|
||||
dot.add_node(node)
|
||||
|
@ -136,7 +137,7 @@ def model_to_dot(model,
|
|||
elif isinstance(model, sequential.Sequential):
|
||||
if not model.built:
|
||||
model.build()
|
||||
layers = model._layers
|
||||
layers = super(sequential.Sequential, model).layers
|
||||
|
||||
# Create graph nodes.
|
||||
for i, layer in enumerate(layers):
|
||||
|
|
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||
from tensorflow.python import keras
|
||||
from tensorflow.python.keras.utils import vis_utils
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
|
@ -67,6 +68,32 @@ class ModelToDotFormatTest(test.TestCase):
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
def test_plot_model_with_add_loss(self):
|
||||
inputs = keras.Input(shape=(None, 3))
|
||||
outputs = keras.layers.Dense(1)(inputs)
|
||||
model = keras.Model(inputs, outputs)
|
||||
model.add_loss(math_ops.reduce_mean(outputs))
|
||||
dot_img_file = 'model_3.png'
|
||||
try:
|
||||
vis_utils.plot_model(
|
||||
model, to_file=dot_img_file, show_shapes=True, expand_nested=True)
|
||||
self.assertTrue(file_io.file_exists(dot_img_file))
|
||||
file_io.delete_file(dot_img_file)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
model = keras.Sequential([
|
||||
keras.Input(shape=(None, 3)), keras.layers.Dense(1)])
|
||||
model.add_loss(math_ops.reduce_mean(model.output))
|
||||
dot_img_file = 'model_4.png'
|
||||
try:
|
||||
vis_utils.plot_model(
|
||||
model, to_file=dot_img_file, show_shapes=True, expand_nested=True)
|
||||
self.assertTrue(file_io.file_exists(dot_img_file))
|
||||
file_io.delete_file(dot_img_file)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
|
Loading…
Reference in New Issue