fixit for get_json_type.

PiperOrigin-RevId: 322446524
Change-Id: I4301abc31713637cdc1a5992a92b4322d405806f
This commit is contained in:
Zhenyu Tan 2020-07-21 14:45:02 -07:00 committed by TensorFlower Gardener
parent cc532b863f
commit 33d842c295
7 changed files with 71 additions and 16 deletions

View File

@ -28,9 +28,9 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import keras_tensor from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import serialization
_CONSTANT_VALUE = '_CONSTANT_VALUE' _CONSTANT_VALUE = '_CONSTANT_VALUE'
@ -171,7 +171,7 @@ class Node(object):
kwargs = nest.map_structure(_serialize_keras_tensor, kwargs) kwargs = nest.map_structure(_serialize_keras_tensor, kwargs)
try: try:
json.dumps(kwargs, default=serialization.get_json_type) json.dumps(kwargs, default=json_utils.get_json_type)
except TypeError: except TypeError:
kwarg_types = nest.map_structure(type, kwargs) kwarg_types = nest.map_structure(type, kwargs)
raise TypeError('Layer ' + self.layer.name + raise TypeError('Layer ' + self.layer.name +

View File

@ -52,6 +52,7 @@ from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
from tensorflow.python.keras.saving import hdf5_format from tensorflow.python.keras.saving import hdf5_format
from tensorflow.python.keras.saving import save from tensorflow.python.keras.saving import save
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.keras.saving.saved_model import model_serialization from tensorflow.python.keras.saving.saved_model import model_serialization
from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
@ -77,7 +78,6 @@ from tensorflow.python.training.tracking import layer_utils as trackable_layer_u
from tensorflow.python.training.tracking import util as trackable_utils from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import serialization
from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls from tensorflow.tools.docs import doc_controls
@ -2262,7 +2262,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
""" """
model_config = self._updated_config() model_config = self._updated_config()
return json.dumps( return json.dumps(
model_config, default=serialization.get_json_type, **kwargs) model_config, default=json_utils.get_json_type, **kwargs)
def to_yaml(self, **kwargs): def to_yaml(self, **kwargs):
"""Returns a yaml string containing the network configuration. """Returns a yaml string containing the network configuration.

View File

@ -24,7 +24,7 @@ from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras.feature_column import base_feature_layer as kfc from tensorflow.python.keras.feature_column import base_feature_layer as kfc
from tensorflow.python.util import serialization from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -112,7 +112,7 @@ class DenseFeatures(kfc._BaseFeaturesLayer): # pylint: disable=protected-access
""" """
metadata = json.loads(super(DenseFeatures, self)._tracking_metadata) metadata = json.loads(super(DenseFeatures, self)._tracking_metadata)
metadata['_is_feature_layer'] = True metadata['_is_feature_layer'] = True
return json.dumps(metadata, default=serialization.get_json_type) return json.dumps(metadata, default=json_utils.get_json_type)
def _target_shape(self, input_shape, total_elements): def _target_shape(self, input_shape, total_elements):
return (input_shape[0], total_elements) return (input_shape[0], total_elements)

View File

@ -35,7 +35,6 @@ from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.ops import variables as variables_module from tensorflow.python.ops import variables as variables_module
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import serialization
# pylint: disable=g-import-not-at-top # pylint: disable=g-import-not-at-top
try: try:
@ -111,7 +110,7 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
for k, v in model_metadata.items(): for k, v in model_metadata.items():
if isinstance(v, (dict, list, tuple)): if isinstance(v, (dict, list, tuple)):
f.attrs[k] = json.dumps( f.attrs[k] = json.dumps(
v, default=serialization.get_json_type).encode('utf8') v, default=json_utils.get_json_type).encode('utf8')
else: else:
f.attrs[k] = v f.attrs[k] = v

View File

@ -26,10 +26,19 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import json import json
import numpy as np
import wrapt
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import serialization from tensorflow.python.util import serialization
try:
# This import only works on python 3.3 and above.
import collections.abc as collections_abc # pylint: disable=unused-import, g-import-not-at-top
except ImportError:
import collections as collections_abc # pylint: disable=unused-import, g-import-not-at-top
class Encoder(json.JSONEncoder): class Encoder(json.JSONEncoder):
"""JSON encoder and decoder that handles TensorShapes and tuples.""" """JSON encoder and decoder that handles TensorShapes and tuples."""
@ -67,3 +76,50 @@ def _decode_helper(obj):
elif obj['class_name'] == '__tuple__': elif obj['class_name'] == '__tuple__':
return tuple(_decode_helper(i) for i in obj['items']) return tuple(_decode_helper(i) for i in obj['items'])
return obj return obj
def get_json_type(obj):
"""Serializes any object to a JSON-serializable structure.
Arguments:
obj: the object to serialize
Returns:
JSON-serializable structure representing `obj`.
Raises:
TypeError: if `obj` cannot be serialized.
"""
# if obj is a serializable Keras class instance
# e.g. optimizer, layer
if hasattr(obj, 'get_config'):
return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
# if obj is any numpy type
if type(obj).__module__ == np.__name__:
if isinstance(obj, np.ndarray):
return obj.tolist()
else:
return obj.item()
# misc functions (e.g. loss function)
if callable(obj):
return obj.__name__
# if obj is a python 'type'
if type(obj).__name__ == type.__name__:
return obj.__name__
if isinstance(obj, tensor_shape.TensorShape):
return obj.as_list()
if isinstance(obj, dtypes.DType):
return obj.name
if isinstance(obj, collections_abc.Mapping):
return dict(obj)
if isinstance(obj, wrapt.ObjectProxy):
return obj.__wrapped__
raise TypeError('Not JSON Serializable:', obj)

View File

@ -27,8 +27,8 @@ from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import sequential from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core from tensorflow.python.keras.layers import core
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.util import serialization
@combinations.generate(combinations.combine(mode=["graph", "eager"])) @combinations.generate(combinations.combine(mode=["graph", "eager"]))
@ -38,7 +38,7 @@ class SerializationTests(keras_parameterized.TestCase):
dense = core.Dense(3) dense = core.Dense(3)
dense(constant_op.constant([[4.]])) dense(constant_op.constant([[4.]]))
round_trip = json.loads(json.dumps( round_trip = json.loads(json.dumps(
dense, default=serialization.get_json_type)) dense, default=json_utils.get_json_type))
self.assertEqual(3, round_trip["config"]["units"]) self.assertEqual(3, round_trip["config"]["units"])
def test_serialize_sequential(self): def test_serialize_sequential(self):
@ -47,7 +47,7 @@ class SerializationTests(keras_parameterized.TestCase):
model.add(core.Dense(5)) model.add(core.Dense(5))
model(constant_op.constant([[1.]])) model(constant_op.constant([[1.]]))
sequential_round_trip = json.loads( sequential_round_trip = json.loads(
json.dumps(model, default=serialization.get_json_type)) json.dumps(model, default=json_utils.get_json_type))
self.assertEqual( self.assertEqual(
# Note that `config['layers'][0]` will be an InputLayer in V2 # Note that `config['layers'][0]` will be an InputLayer in V2
# (but not in V1) # (but not in V1)
@ -59,7 +59,7 @@ class SerializationTests(keras_parameterized.TestCase):
model = training.Model(x, y) model = training.Model(x, y)
model(constant_op.constant([[1., 1., 1.]])) model(constant_op.constant([[1., 1., 1.]]))
model_round_trip = json.loads( model_round_trip = json.loads(
json.dumps(model, default=serialization.get_json_type)) json.dumps(model, default=json_utils.get_json_type))
self.assertEqual( self.assertEqual(
10, model_round_trip["config"]["layers"][1]["config"]["units"]) 10, model_round_trip["config"]["layers"][1]["config"]["units"])

View File

@ -30,6 +30,7 @@ from tensorflow.python.eager import def_function
from tensorflow.python.eager import test from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.layers import core as non_keras_core from tensorflow.python.layers import core as non_keras_core
from tensorflow.python.module import module from tensorflow.python.module import module
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -39,7 +40,6 @@ from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import serialization
class ListTests(test.TestCase): class ListTests(test.TestCase):
@ -47,7 +47,7 @@ class ListTests(test.TestCase):
def testJSONSerialization(self): def testJSONSerialization(self):
obj = tracking.AutoTrackable() obj = tracking.AutoTrackable()
obj.l = [1] obj.l = [1]
json.dumps(obj.l, default=serialization.get_json_type) json.dumps(obj.l, default=json_utils.get_json_type)
def testNotTrackable(self): def testNotTrackable(self):
class NotTrackable(object): class NotTrackable(object):
@ -337,7 +337,7 @@ class MappingTests(test.TestCase):
def testJSONSerialization(self): def testJSONSerialization(self):
obj = tracking.AutoTrackable() obj = tracking.AutoTrackable()
obj.d = {"a": 2} obj.d = {"a": 2}
json.dumps(obj.d, default=serialization.get_json_type) json.dumps(obj.d, default=json_utils.get_json_type)
def testNoOverwrite(self): def testNoOverwrite(self):
mapping = data_structures.Mapping() mapping = data_structures.Mapping()
@ -519,7 +519,7 @@ class TupleTests(test.TestCase, parameterized.TestCase):
def testJSONSerialization(self): def testJSONSerialization(self):
obj = tracking.AutoTrackable() obj = tracking.AutoTrackable()
obj.l = (1,) obj.l = (1,)
json.dumps(obj.l, default=serialization.get_json_type) json.dumps(obj.l, default=json_utils.get_json_type)
def testNonLayerVariables(self): def testNonLayerVariables(self):
v = resource_variable_ops.ResourceVariable([1.]) v = resource_variable_ops.ResourceVariable([1.])