fixit for get_json_type.

PiperOrigin-RevId: 322446524
Change-Id: I4301abc31713637cdc1a5992a92b4322d405806f
This commit is contained in:
Zhenyu Tan 2020-07-21 14:45:02 -07:00 committed by TensorFlower Gardener
parent cc532b863f
commit 33d842c295
7 changed files with 71 additions and 16 deletions

View File

@ -28,9 +28,9 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
_CONSTANT_VALUE = '_CONSTANT_VALUE'
@ -171,7 +171,7 @@ class Node(object):
kwargs = nest.map_structure(_serialize_keras_tensor, kwargs)
try:
json.dumps(kwargs, default=serialization.get_json_type)
json.dumps(kwargs, default=json_utils.get_json_type)
except TypeError:
kwarg_types = nest.map_structure(type, kwargs)
raise TypeError('Layer ' + self.layer.name +

View File

@ -52,6 +52,7 @@ from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
from tensorflow.python.keras.saving import hdf5_format
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.keras.saving.saved_model import model_serialization
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
@ -77,7 +78,6 @@ from tensorflow.python.training.tracking import layer_utils as trackable_layer_u
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls
@ -2262,7 +2262,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
"""
model_config = self._updated_config()
return json.dumps(
model_config, default=serialization.get_json_type, **kwargs)
model_config, default=json_utils.get_json_type, **kwargs)
def to_yaml(self, **kwargs):
"""Returns a yaml string containing the network configuration.

View File

@ -24,7 +24,7 @@ from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.feature_column import base_feature_layer as kfc
from tensorflow.python.util import serialization
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.util.tf_export import keras_export
@ -112,7 +112,7 @@ class DenseFeatures(kfc._BaseFeaturesLayer): # pylint: disable=protected-access
"""
metadata = json.loads(super(DenseFeatures, self)._tracking_metadata)
metadata['_is_feature_layer'] = True
return json.dumps(metadata, default=serialization.get_json_type)
return json.dumps(metadata, default=json_utils.get_json_type)
def _target_shape(self, input_shape, total_elements):
return (input_shape[0], total_elements)

View File

@ -35,7 +35,6 @@ from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import serialization
# pylint: disable=g-import-not-at-top
try:
@ -111,7 +110,7 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
for k, v in model_metadata.items():
if isinstance(v, (dict, list, tuple)):
f.attrs[k] = json.dumps(
v, default=serialization.get_json_type).encode('utf8')
v, default=json_utils.get_json_type).encode('utf8')
else:
f.attrs[k] = v

View File

@ -26,10 +26,19 @@ from __future__ import division
from __future__ import print_function
import json
import numpy as np
import wrapt
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import serialization
try:
# This import only works on python 3.3 and above.
import collections.abc as collections_abc # pylint: disable=unused-import, g-import-not-at-top
except ImportError:
import collections as collections_abc # pylint: disable=unused-import, g-import-not-at-top
class Encoder(json.JSONEncoder):
"""JSON encoder and decoder that handles TensorShapes and tuples."""
@ -67,3 +76,50 @@ def _decode_helper(obj):
elif obj['class_name'] == '__tuple__':
return tuple(_decode_helper(i) for i in obj['items'])
return obj
def get_json_type(obj):
"""Serializes any object to a JSON-serializable structure.
Arguments:
obj: the object to serialize
Returns:
JSON-serializable structure representing `obj`.
Raises:
TypeError: if `obj` cannot be serialized.
"""
# if obj is a serializable Keras class instance
# e.g. optimizer, layer
if hasattr(obj, 'get_config'):
return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
# if obj is any numpy type
if type(obj).__module__ == np.__name__:
if isinstance(obj, np.ndarray):
return obj.tolist()
else:
return obj.item()
# misc functions (e.g. loss function)
if callable(obj):
return obj.__name__
# if obj is a python 'type'
if type(obj).__name__ == type.__name__:
return obj.__name__
if isinstance(obj, tensor_shape.TensorShape):
return obj.as_list()
if isinstance(obj, dtypes.DType):
return obj.name
if isinstance(obj, collections_abc.Mapping):
return dict(obj)
if isinstance(obj, wrapt.ObjectProxy):
return obj.__wrapped__
raise TypeError('Not JSON Serializable:', obj)

View File

@ -27,8 +27,8 @@ from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.platform import test
from tensorflow.python.util import serialization
@combinations.generate(combinations.combine(mode=["graph", "eager"]))
@ -38,7 +38,7 @@ class SerializationTests(keras_parameterized.TestCase):
dense = core.Dense(3)
dense(constant_op.constant([[4.]]))
round_trip = json.loads(json.dumps(
dense, default=serialization.get_json_type))
dense, default=json_utils.get_json_type))
self.assertEqual(3, round_trip["config"]["units"])
def test_serialize_sequential(self):
@ -47,7 +47,7 @@ class SerializationTests(keras_parameterized.TestCase):
model.add(core.Dense(5))
model(constant_op.constant([[1.]]))
sequential_round_trip = json.loads(
json.dumps(model, default=serialization.get_json_type))
json.dumps(model, default=json_utils.get_json_type))
self.assertEqual(
# Note that `config['layers'][0]` will be an InputLayer in V2
# (but not in V1)
@ -59,7 +59,7 @@ class SerializationTests(keras_parameterized.TestCase):
model = training.Model(x, y)
model(constant_op.constant([[1., 1., 1.]]))
model_round_trip = json.loads(
json.dumps(model, default=serialization.get_json_type))
json.dumps(model, default=json_utils.get_json_type))
self.assertEqual(
10, model_round_trip["config"]["layers"][1]["config"]["units"])

View File

@ -30,6 +30,7 @@ from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.layers import core as non_keras_core
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
@ -39,7 +40,6 @@ from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
class ListTests(test.TestCase):
@ -47,7 +47,7 @@ class ListTests(test.TestCase):
def testJSONSerialization(self):
obj = tracking.AutoTrackable()
obj.l = [1]
json.dumps(obj.l, default=serialization.get_json_type)
json.dumps(obj.l, default=json_utils.get_json_type)
def testNotTrackable(self):
class NotTrackable(object):
@ -337,7 +337,7 @@ class MappingTests(test.TestCase):
def testJSONSerialization(self):
obj = tracking.AutoTrackable()
obj.d = {"a": 2}
json.dumps(obj.d, default=serialization.get_json_type)
json.dumps(obj.d, default=json_utils.get_json_type)
def testNoOverwrite(self):
mapping = data_structures.Mapping()
@ -519,7 +519,7 @@ class TupleTests(test.TestCase, parameterized.TestCase):
def testJSONSerialization(self):
obj = tracking.AutoTrackable()
obj.l = (1,)
json.dumps(obj.l, default=serialization.get_json_type)
json.dumps(obj.l, default=json_utils.get_json_type)
def testNonLayerVariables(self):
v = resource_variable_ops.ResourceVariable([1.])