For compatibility with PyYAML v5+, try to use unsafe_load first and fallback to load.

Note that this retains an existing vulnerability to arbitrary code execution that existed in previous versions of PyYAML (and transitively in Keras):
https://github.com/yaml/pyyaml/wiki/PyYAML-yaml.load(input)-Deprecation

PiperOrigin-RevId: 302948409
Change-Id: Idb21aa27cc3ee02c30205f05e91d3f44f64ffa64
This commit is contained in:
David Kao 2020-03-25 12:42:43 -07:00 committed by TensorFlower Gardener
parent 0349c8ddf5
commit f43961dec7

View File

@ -85,7 +85,13 @@ def model_from_yaml(yaml_string, custom_objects=None):
"""
if yaml is None:
raise ImportError('Requires yaml module installed (`pip install pyyaml`).')
config = yaml.load(yaml_string)
# The method unsafe_load only exists in PyYAML 5.x+, so which branch of the
# try block is covered by tests depends on the installed version of PyYAML.
try:
# PyYAML 5.x+
config = yaml.unsafe_load(yaml_string)
except AttributeError:
config = yaml.load(yaml_string)
from tensorflow.python.keras.layers import deserialize # pylint: disable=g-import-not-at-top
return deserialize(config, custom_objects=custom_objects)