fixit for get_json_type.
PiperOrigin-RevId: 322446524 Change-Id: I4301abc31713637cdc1a5992a92b4322d405806f
This commit is contained in:
parent
cc532b863f
commit
33d842c295
@ -28,9 +28,9 @@ from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import serialization
|
||||
|
||||
_CONSTANT_VALUE = '_CONSTANT_VALUE'
|
||||
|
||||
@ -171,7 +171,7 @@ class Node(object):
|
||||
|
||||
kwargs = nest.map_structure(_serialize_keras_tensor, kwargs)
|
||||
try:
|
||||
json.dumps(kwargs, default=serialization.get_json_type)
|
||||
json.dumps(kwargs, default=json_utils.get_json_type)
|
||||
except TypeError:
|
||||
kwarg_types = nest.map_structure(type, kwargs)
|
||||
raise TypeError('Layer ' + self.layer.name +
|
||||
|
@ -52,6 +52,7 @@ from tensorflow.python.keras.engine import training_utils
|
||||
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
|
||||
from tensorflow.python.keras.saving import hdf5_format
|
||||
from tensorflow.python.keras.saving import save
|
||||
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||
from tensorflow.python.keras.saving.saved_model import model_serialization
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import layer_utils
|
||||
@ -77,7 +78,6 @@ from tensorflow.python.training.tracking import layer_utils as trackable_layer_u
|
||||
from tensorflow.python.training.tracking import util as trackable_utils
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import serialization
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
from tensorflow.tools.docs import doc_controls
|
||||
@ -2262,7 +2262,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
"""
|
||||
model_config = self._updated_config()
|
||||
return json.dumps(
|
||||
model_config, default=serialization.get_json_type, **kwargs)
|
||||
model_config, default=json_utils.get_json_type, **kwargs)
|
||||
|
||||
def to_yaml(self, **kwargs):
|
||||
"""Returns a yaml string containing the network configuration.
|
||||
|
@ -24,7 +24,7 @@ from tensorflow.python.feature_column import feature_column_v2 as fc
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.feature_column import base_feature_layer as kfc
|
||||
from tensorflow.python.util import serialization
|
||||
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
@ -112,7 +112,7 @@ class DenseFeatures(kfc._BaseFeaturesLayer): # pylint: disable=protected-access
|
||||
"""
|
||||
metadata = json.loads(super(DenseFeatures, self)._tracking_metadata)
|
||||
metadata['_is_feature_layer'] = True
|
||||
return json.dumps(metadata, default=serialization.get_json_type)
|
||||
return json.dumps(metadata, default=json_utils.get_json_type)
|
||||
|
||||
def _target_shape(self, input_shape, total_elements):
|
||||
return (input_shape[0], total_elements)
|
||||
|
@ -35,7 +35,6 @@ from tensorflow.python.keras.utils.generic_utils import LazyLoader
|
||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from tensorflow.python.ops import variables as variables_module
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import serialization
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
try:
|
||||
@ -111,7 +110,7 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
|
||||
for k, v in model_metadata.items():
|
||||
if isinstance(v, (dict, list, tuple)):
|
||||
f.attrs[k] = json.dumps(
|
||||
v, default=serialization.get_json_type).encode('utf8')
|
||||
v, default=json_utils.get_json_type).encode('utf8')
|
||||
else:
|
||||
f.attrs[k] = v
|
||||
|
||||
|
@ -26,10 +26,19 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import wrapt
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.util import serialization
|
||||
|
||||
try:
|
||||
# This import only works on python 3.3 and above.
|
||||
import collections.abc as collections_abc # pylint: disable=unused-import, g-import-not-at-top
|
||||
except ImportError:
|
||||
import collections as collections_abc # pylint: disable=unused-import, g-import-not-at-top
|
||||
|
||||
|
||||
class Encoder(json.JSONEncoder):
|
||||
"""JSON encoder and decoder that handles TensorShapes and tuples."""
|
||||
@ -67,3 +76,50 @@ def _decode_helper(obj):
|
||||
elif obj['class_name'] == '__tuple__':
|
||||
return tuple(_decode_helper(i) for i in obj['items'])
|
||||
return obj
|
||||
|
||||
|
||||
def get_json_type(obj):
|
||||
"""Serializes any object to a JSON-serializable structure.
|
||||
|
||||
Arguments:
|
||||
obj: the object to serialize
|
||||
|
||||
Returns:
|
||||
JSON-serializable structure representing `obj`.
|
||||
|
||||
Raises:
|
||||
TypeError: if `obj` cannot be serialized.
|
||||
"""
|
||||
# if obj is a serializable Keras class instance
|
||||
# e.g. optimizer, layer
|
||||
if hasattr(obj, 'get_config'):
|
||||
return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
|
||||
|
||||
# if obj is any numpy type
|
||||
if type(obj).__module__ == np.__name__:
|
||||
if isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
else:
|
||||
return obj.item()
|
||||
|
||||
# misc functions (e.g. loss function)
|
||||
if callable(obj):
|
||||
return obj.__name__
|
||||
|
||||
# if obj is a python 'type'
|
||||
if type(obj).__name__ == type.__name__:
|
||||
return obj.__name__
|
||||
|
||||
if isinstance(obj, tensor_shape.TensorShape):
|
||||
return obj.as_list()
|
||||
|
||||
if isinstance(obj, dtypes.DType):
|
||||
return obj.name
|
||||
|
||||
if isinstance(obj, collections_abc.Mapping):
|
||||
return dict(obj)
|
||||
|
||||
if isinstance(obj, wrapt.ObjectProxy):
|
||||
return obj.__wrapped__
|
||||
|
||||
raise TypeError('Not JSON Serializable:', obj)
|
||||
|
@ -27,8 +27,8 @@ from tensorflow.python.keras.engine import input_layer
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import serialization
|
||||
|
||||
|
||||
@combinations.generate(combinations.combine(mode=["graph", "eager"]))
|
||||
@ -38,7 +38,7 @@ class SerializationTests(keras_parameterized.TestCase):
|
||||
dense = core.Dense(3)
|
||||
dense(constant_op.constant([[4.]]))
|
||||
round_trip = json.loads(json.dumps(
|
||||
dense, default=serialization.get_json_type))
|
||||
dense, default=json_utils.get_json_type))
|
||||
self.assertEqual(3, round_trip["config"]["units"])
|
||||
|
||||
def test_serialize_sequential(self):
|
||||
@ -47,7 +47,7 @@ class SerializationTests(keras_parameterized.TestCase):
|
||||
model.add(core.Dense(5))
|
||||
model(constant_op.constant([[1.]]))
|
||||
sequential_round_trip = json.loads(
|
||||
json.dumps(model, default=serialization.get_json_type))
|
||||
json.dumps(model, default=json_utils.get_json_type))
|
||||
self.assertEqual(
|
||||
# Note that `config['layers'][0]` will be an InputLayer in V2
|
||||
# (but not in V1)
|
||||
@ -59,7 +59,7 @@ class SerializationTests(keras_parameterized.TestCase):
|
||||
model = training.Model(x, y)
|
||||
model(constant_op.constant([[1., 1., 1.]]))
|
||||
model_round_trip = json.loads(
|
||||
json.dumps(model, default=serialization.get_json_type))
|
||||
json.dumps(model, default=json_utils.get_json_type))
|
||||
self.assertEqual(
|
||||
10, model_round_trip["config"]["layers"][1]["config"]["units"])
|
||||
|
||||
|
@ -30,6 +30,7 @@ from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||
from tensorflow.python.layers import core as non_keras_core
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -39,7 +40,6 @@ from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.training.tracking import util
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import serialization
|
||||
|
||||
|
||||
class ListTests(test.TestCase):
|
||||
@ -47,7 +47,7 @@ class ListTests(test.TestCase):
|
||||
def testJSONSerialization(self):
|
||||
obj = tracking.AutoTrackable()
|
||||
obj.l = [1]
|
||||
json.dumps(obj.l, default=serialization.get_json_type)
|
||||
json.dumps(obj.l, default=json_utils.get_json_type)
|
||||
|
||||
def testNotTrackable(self):
|
||||
class NotTrackable(object):
|
||||
@ -337,7 +337,7 @@ class MappingTests(test.TestCase):
|
||||
def testJSONSerialization(self):
|
||||
obj = tracking.AutoTrackable()
|
||||
obj.d = {"a": 2}
|
||||
json.dumps(obj.d, default=serialization.get_json_type)
|
||||
json.dumps(obj.d, default=json_utils.get_json_type)
|
||||
|
||||
def testNoOverwrite(self):
|
||||
mapping = data_structures.Mapping()
|
||||
@ -519,7 +519,7 @@ class TupleTests(test.TestCase, parameterized.TestCase):
|
||||
def testJSONSerialization(self):
|
||||
obj = tracking.AutoTrackable()
|
||||
obj.l = (1,)
|
||||
json.dumps(obj.l, default=serialization.get_json_type)
|
||||
json.dumps(obj.l, default=json_utils.get_json_type)
|
||||
|
||||
def testNonLayerVariables(self):
|
||||
v = resource_variable_ops.ResourceVariable([1.])
|
||||
|
Loading…
x
Reference in New Issue
Block a user