352 lines
12 KiB
Python
352 lines
12 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
# pylint: disable=protected-access
|
|
# pylint: disable=g-import-not-at-top
|
|
"""Utilities related to model visualization."""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import sys
|
|
from tensorflow.python.keras.utils.io_utils import path_to_string
|
|
from tensorflow.python.util import nest
|
|
from tensorflow.python.util.tf_export import keras_export
|
|
|
|
|
|
try:
|
|
# pydot-ng is a fork of pydot that is better maintained.
|
|
import pydot_ng as pydot
|
|
except ImportError:
|
|
# pydotplus is an improved version of pydot
|
|
try:
|
|
import pydotplus as pydot
|
|
except ImportError:
|
|
# Fall back on pydot if necessary.
|
|
try:
|
|
import pydot
|
|
except ImportError:
|
|
pydot = None
|
|
|
|
|
|
def check_pydot():
  """Reports whether PyDot and Graphviz can both be used.

  Returns:
    `False` if no pydot flavour was importable (module-level `pydot` is
    `None`) or if rendering an empty graph raises `OSError` or
    `pydot.InvocationException`; `True` otherwise.
  """
  if pydot is None:
    return False
  try:
    # Rendering a blank graph exercises the whole pydot -> graphviz pipeline,
    # which is the only reliable way to check the installation.
    pydot.Dot.create(pydot.Dot())
  except (OSError, pydot.InvocationException):
    return False
  return True
|
|
|
|
|
|
def is_wrapped_model(layer):
  """Returns True if `layer` is a `Wrapper` whose inner layer is `Functional`.

  Args:
    layer: the object to inspect.

  Returns:
    `True` only when `layer` is a `wrappers.Wrapper` instance and its
    `layer.layer` attribute is a `functional.Functional` instance.
  """
  from tensorflow.python.keras.engine import functional
  from tensorflow.python.keras.layers import wrappers
  if not isinstance(layer, wrappers.Wrapper):
    return False
  return isinstance(layer.layer, functional.Functional)
|
|
|
|
|
|
def add_edge(dot, src, dst):
  """Adds an edge from `src` to `dst` unless `dot` already reports one.

  Args:
    dot: graph object exposing `get_edge` and `add_edge`.
    src: name of the source node.
    dst: name of the destination node.
  """
  if dot.get_edge(src, dst):
    # An edge between the two nodes is already present; do not duplicate it.
    return
  dot.add_edge(pydot.Edge(src, dst))
|
|
|
|
|
|
@keras_export('keras.utils.model_to_dot')
def model_to_dot(model,
                 show_shapes=False,
                 show_dtype=False,
                 show_layer_names=True,
                 rankdir='TB',
                 expand_nested=False,
                 dpi=96,
                 subgraph=False):
  """Convert a Keras model to dot format.

  Args:
    model: A Keras model instance.
    show_shapes: whether to display shape information.
    show_dtype: whether to display layer dtypes.
    show_layer_names: whether to display layer names.
    rankdir: `rankdir` argument passed to PyDot,
        a string specifying the format of the plot:
        'TB' creates a vertical plot;
        'LR' creates a horizontal plot.
    expand_nested: whether to expand nested models into clusters.
    dpi: Dots per inch.
    subgraph: whether to return a `pydot.Cluster` instance.

  Returns:
    A `pydot.Dot` instance representing the Keras model or
    a `pydot.Cluster` instance representing nested model if
    `subgraph=True`. Returns `None` (after printing the installation hint)
    when pydot/graphviz are unavailable and the
    `IPython.core.magics.namespace` module is loaded.

  Raises:
    ImportError: if graphviz or pydot are not available.
  """
  from tensorflow.python.keras.layers import wrappers
  from tensorflow.python.keras.engine import sequential
  from tensorflow.python.keras.engine import functional

  if not check_pydot():
    # Adjacent string literals are concatenated into ONE message string.
    # (A stray comma here previously turned the message into a tuple.)
    message = (
        'You must install pydot (`pip install pydot`) '
        'and install graphviz '
        '(see instructions at https://graphviz.gitlab.io/download/) '
        'for plot_model/model_to_dot to work.')
    if 'IPython.core.magics.namespace' in sys.modules:
      # We don't raise an exception here in order to avoid crashing notebook
      # tests where graphviz is not available.
      print(message)
      return None
    raise ImportError(message)

  if subgraph:
    dot = pydot.Cluster(style='dashed', graph_name=model.name)
    dot.set('label', model.name)
    dot.set('labeljust', 'l')
  else:
    dot = pydot.Dot()
    dot.set('rankdir', rankdir)
    dot.set('concentrate', True)
    dot.set('dpi', dpi)
    dot.set_node_defaults(shape='record')

  # First/last pydot node of every expanded sub-model, keyed by the
  # sub-model's name. `sub_n_*` holds models nested directly as layers,
  # `sub_w_*` holds models nested inside a `Wrapper` layer.
  sub_n_first_node = {}
  sub_n_last_node = {}
  sub_w_first_node = {}
  sub_w_last_node = {}

  layers = model.layers
  if not model._is_graph_network:
    # Not a graph network: the model is drawn as a single node.
    node = pydot.Node(str(id(model)), label=model.name)
    dot.add_node(node)
    return dot
  elif isinstance(model, sequential.Sequential):
    if not model.built:
      model.build()
    # Bypass `Sequential.layers` and read the parent class' property instead.
    layers = super(sequential.Sequential, model).layers

  def format_dtype(dtype):
    """Returns the text shown for a layer dtype ('?' when it is None)."""
    if dtype is None:
      return '?'
    return str(dtype)

  def format_shape(shape):
    """Returns the text shown for an input/output shape."""
    return str(shape)

  # Create graph nodes.
  for layer in layers:
    layer_id = str(id(layer))

    # Append a wrapped layer's label to node's label, if it exists.
    layer_name = layer.name
    class_name = layer.__class__.__name__

    if isinstance(layer, wrappers.Wrapper):
      if expand_nested and isinstance(layer.layer, functional.Functional):
        submodel_wrapper = model_to_dot(
            layer.layer,
            show_shapes,
            show_dtype,
            show_layer_names,
            rankdir,
            expand_nested,
            subgraph=True)
        # sub_w : submodel_wrapper
        sub_w_nodes = submodel_wrapper.get_nodes()
        sub_w_first_node[layer.layer.name] = sub_w_nodes[0]
        sub_w_last_node[layer.layer.name] = sub_w_nodes[-1]
        dot.add_subgraph(submodel_wrapper)
      else:
        layer_name = '{}({})'.format(layer_name, layer.layer.name)
        child_class_name = layer.layer.__class__.__name__
        class_name = '{}({})'.format(class_name, child_class_name)

    is_expanded_model = (
        expand_nested and isinstance(layer, functional.Functional))

    if is_expanded_model:
      submodel_not_wrapper = model_to_dot(
          layer,
          show_shapes,
          show_dtype,
          show_layer_names,
          rankdir,
          expand_nested,
          subgraph=True)
      # sub_n : submodel_not_wrapper
      sub_n_nodes = submodel_not_wrapper.get_nodes()
      sub_n_first_node[layer.name] = sub_n_nodes[0]
      sub_n_last_node[layer.name] = sub_n_nodes[-1]
      dot.add_subgraph(submodel_not_wrapper)

    # Create node's label.
    if show_layer_names:
      label = '{}: {}'.format(layer_name, class_name)
    else:
      label = class_name

    # Rebuild the label as a table including the layer's dtype.
    if show_dtype:
      label = '%s|%s' % (label, format_dtype(layer.dtype))

    # Rebuild the label as a table including input/output shapes.
    if show_shapes:
      try:
        outputlabels = format_shape(layer.output_shape)
      except AttributeError:
        outputlabels = '?'
      if hasattr(layer, 'input_shape'):
        inputlabels = format_shape(layer.input_shape)
      elif hasattr(layer, 'input_shapes'):
        inputlabels = ', '.join(
            [format_shape(ishape) for ishape in layer.input_shapes])
      else:
        inputlabels = '?'
      label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label,
                                                     inputlabels,
                                                     outputlabels)

    # An expanded nested model is represented by its cluster, so it gets no
    # node of its own.
    if not is_expanded_model:
      node = pydot.Node(layer_id, label=label)
      dot.add_node(node)

  # Connect nodes with edges.
  for layer in layers:
    layer_id = str(id(layer))
    for i, node in enumerate(layer._inbound_nodes):
      node_key = layer.name + '_ib-' + str(i)
      # Only inbound nodes that are registered on this model are drawn.
      if node_key in model._network_nodes:
        for inbound_layer in nest.flatten(node.inbound_layers):
          inbound_layer_id = str(id(inbound_layer))
          if not expand_nested:
            assert dot.get_node(inbound_layer_id)
            assert dot.get_node(layer_id)
            add_edge(dot, inbound_layer_id, layer_id)
          else:
            # if inbound_layer is not Model or wrapped Model
            if (not isinstance(inbound_layer, functional.Functional) and
                not is_wrapped_model(inbound_layer)):
              # if current layer is not Model or wrapped Model
              if (not isinstance(layer, functional.Functional) and
                  not is_wrapped_model(layer)):
                assert dot.get_node(inbound_layer_id)
                assert dot.get_node(layer_id)
                add_edge(dot, inbound_layer_id, layer_id)
              # if current layer is Model
              elif isinstance(layer, functional.Functional):
                add_edge(dot, inbound_layer_id,
                         sub_n_first_node[layer.name].get_name())
              # if current layer is wrapped Model
              elif is_wrapped_model(layer):
                add_edge(dot, inbound_layer_id, layer_id)
                name = sub_w_first_node[layer.layer.name].get_name()
                add_edge(dot, layer_id, name)
            # if inbound_layer is Model
            elif isinstance(inbound_layer, functional.Functional):
              name = sub_n_last_node[inbound_layer.name].get_name()
              if isinstance(layer, functional.Functional):
                output_name = sub_n_first_node[layer.name].get_name()
                add_edge(dot, name, output_name)
              else:
                add_edge(dot, name, layer_id)
            # if inbound_layer is wrapped Model
            elif is_wrapped_model(inbound_layer):
              inbound_layer_name = inbound_layer.layer.name
              add_edge(dot,
                       sub_w_last_node[inbound_layer_name].get_name(),
                       layer_id)
  return dot
|
|
|
|
|
|
@keras_export('keras.utils.plot_model')
def plot_model(model,
               to_file='model.png',
               show_shapes=False,
               show_dtype=False,
               show_layer_names=True,
               rankdir='TB',
               expand_nested=False,
               dpi=96):
  """Renders a Keras model as a graph and writes the image to a file.

  Example:

  ```python
  input = tf.keras.Input(shape=(100,), dtype='int32', name='input')
  x = tf.keras.layers.Embedding(
      output_dim=512, input_dim=10000, input_length=100)(input)
  x = tf.keras.layers.LSTM(32)(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
  model = tf.keras.Model(inputs=[input], outputs=[output])
  dot_img_file = '/tmp/model_1.png'
  tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
  ```

  Args:
    model: A Keras model instance
    to_file: File name of the plot image. The output format is taken from
      its extension; 'png' is used when there is no extension.
    show_shapes: whether to display shape information.
    show_dtype: whether to display layer dtypes.
    show_layer_names: whether to display layer names.
    rankdir: `rankdir` argument passed to PyDot,
        a string specifying the format of the plot:
        'TB' creates a vertical plot;
        'LR' creates a horizontal plot.
    expand_nested: Whether to expand nested models into clusters.
    dpi: Dots per inch.

  Returns:
    An IPython `display.Image` for the written file when IPython can be
    imported and the output format is not 'pdf'; otherwise `None`. `None` is
    also returned, without writing anything, when `model_to_dot` yields no
    graph.
  """
  graph = model_to_dot(
      model,
      show_shapes=show_shapes,
      show_dtype=show_dtype,
      show_layer_names=show_layer_names,
      rankdir=rankdir,
      expand_nested=expand_nested,
      dpi=dpi)
  to_file = path_to_string(to_file)
  if graph is None:
    return

  # Pick the output format from the file suffix (without its leading dot).
  suffix = os.path.splitext(to_file)[1]
  extension = suffix[1:] if suffix else 'png'

  # Save image to disk.
  graph.write(to_file, format=extension)

  if extension == 'pdf':
    return

  # Hand the image back as an IPython Image so notebooks display it in-line.
  # Notebook execution cannot easily be detected, so the Image is returned
  # whenever IPython is importable.
  try:
    from IPython import display
    return display.Image(filename=to_file)
  except ImportError:
    pass
|