fixit for get_json_type.
PiperOrigin-RevId: 322446524 Change-Id: I4301abc31713637cdc1a5992a92b4322d405806f
This commit is contained in:
parent
cc532b863f
commit
33d842c295
@ -28,9 +28,9 @@ from tensorflow.python.framework import tensor_util
|
|||||||
from tensorflow.python.keras import backend
|
from tensorflow.python.keras import backend
|
||||||
from tensorflow.python.keras.engine import base_layer_utils
|
from tensorflow.python.keras.engine import base_layer_utils
|
||||||
from tensorflow.python.keras.engine import keras_tensor
|
from tensorflow.python.keras.engine import keras_tensor
|
||||||
|
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||||
from tensorflow.python.keras.utils import tf_utils
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import serialization
|
|
||||||
|
|
||||||
_CONSTANT_VALUE = '_CONSTANT_VALUE'
|
_CONSTANT_VALUE = '_CONSTANT_VALUE'
|
||||||
|
|
||||||
@ -171,7 +171,7 @@ class Node(object):
|
|||||||
|
|
||||||
kwargs = nest.map_structure(_serialize_keras_tensor, kwargs)
|
kwargs = nest.map_structure(_serialize_keras_tensor, kwargs)
|
||||||
try:
|
try:
|
||||||
json.dumps(kwargs, default=serialization.get_json_type)
|
json.dumps(kwargs, default=json_utils.get_json_type)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
kwarg_types = nest.map_structure(type, kwargs)
|
kwarg_types = nest.map_structure(type, kwargs)
|
||||||
raise TypeError('Layer ' + self.layer.name +
|
raise TypeError('Layer ' + self.layer.name +
|
||||||
|
@ -52,6 +52,7 @@ from tensorflow.python.keras.engine import training_utils
|
|||||||
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
|
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
|
||||||
from tensorflow.python.keras.saving import hdf5_format
|
from tensorflow.python.keras.saving import hdf5_format
|
||||||
from tensorflow.python.keras.saving import save
|
from tensorflow.python.keras.saving import save
|
||||||
|
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||||
from tensorflow.python.keras.saving.saved_model import model_serialization
|
from tensorflow.python.keras.saving.saved_model import model_serialization
|
||||||
from tensorflow.python.keras.utils import generic_utils
|
from tensorflow.python.keras.utils import generic_utils
|
||||||
from tensorflow.python.keras.utils import layer_utils
|
from tensorflow.python.keras.utils import layer_utils
|
||||||
@ -77,7 +78,6 @@ from tensorflow.python.training.tracking import layer_utils as trackable_layer_u
|
|||||||
from tensorflow.python.training.tracking import util as trackable_utils
|
from tensorflow.python.training.tracking import util as trackable_utils
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import serialization
|
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
from tensorflow.tools.docs import doc_controls
|
from tensorflow.tools.docs import doc_controls
|
||||||
@ -2262,7 +2262,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
|||||||
"""
|
"""
|
||||||
model_config = self._updated_config()
|
model_config = self._updated_config()
|
||||||
return json.dumps(
|
return json.dumps(
|
||||||
model_config, default=serialization.get_json_type, **kwargs)
|
model_config, default=json_utils.get_json_type, **kwargs)
|
||||||
|
|
||||||
def to_yaml(self, **kwargs):
|
def to_yaml(self, **kwargs):
|
||||||
"""Returns a yaml string containing the network configuration.
|
"""Returns a yaml string containing the network configuration.
|
||||||
|
@ -24,7 +24,7 @@ from tensorflow.python.feature_column import feature_column_v2 as fc
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.keras import backend
|
from tensorflow.python.keras import backend
|
||||||
from tensorflow.python.keras.feature_column import base_feature_layer as kfc
|
from tensorflow.python.keras.feature_column import base_feature_layer as kfc
|
||||||
from tensorflow.python.util import serialization
|
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
|
|
||||||
|
|
||||||
@ -112,7 +112,7 @@ class DenseFeatures(kfc._BaseFeaturesLayer): # pylint: disable=protected-access
|
|||||||
"""
|
"""
|
||||||
metadata = json.loads(super(DenseFeatures, self)._tracking_metadata)
|
metadata = json.loads(super(DenseFeatures, self)._tracking_metadata)
|
||||||
metadata['_is_feature_layer'] = True
|
metadata['_is_feature_layer'] = True
|
||||||
return json.dumps(metadata, default=serialization.get_json_type)
|
return json.dumps(metadata, default=json_utils.get_json_type)
|
||||||
|
|
||||||
def _target_shape(self, input_shape, total_elements):
|
def _target_shape(self, input_shape, total_elements):
|
||||||
return (input_shape[0], total_elements)
|
return (input_shape[0], total_elements)
|
||||||
|
@ -35,7 +35,6 @@ from tensorflow.python.keras.utils.generic_utils import LazyLoader
|
|||||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||||
from tensorflow.python.ops import variables as variables_module
|
from tensorflow.python.ops import variables as variables_module
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.util import serialization
|
|
||||||
|
|
||||||
# pylint: disable=g-import-not-at-top
|
# pylint: disable=g-import-not-at-top
|
||||||
try:
|
try:
|
||||||
@ -111,7 +110,7 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
|
|||||||
for k, v in model_metadata.items():
|
for k, v in model_metadata.items():
|
||||||
if isinstance(v, (dict, list, tuple)):
|
if isinstance(v, (dict, list, tuple)):
|
||||||
f.attrs[k] = json.dumps(
|
f.attrs[k] = json.dumps(
|
||||||
v, default=serialization.get_json_type).encode('utf8')
|
v, default=json_utils.get_json_type).encode('utf8')
|
||||||
else:
|
else:
|
||||||
f.attrs[k] = v
|
f.attrs[k] = v
|
||||||
|
|
||||||
|
@ -26,10 +26,19 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import wrapt
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.util import serialization
|
from tensorflow.python.util import serialization
|
||||||
|
|
||||||
|
try:
|
||||||
|
# This import only works on python 3.3 and above.
|
||||||
|
import collections.abc as collections_abc # pylint: disable=unused-import, g-import-not-at-top
|
||||||
|
except ImportError:
|
||||||
|
import collections as collections_abc # pylint: disable=unused-import, g-import-not-at-top
|
||||||
|
|
||||||
|
|
||||||
class Encoder(json.JSONEncoder):
|
class Encoder(json.JSONEncoder):
|
||||||
"""JSON encoder and decoder that handles TensorShapes and tuples."""
|
"""JSON encoder and decoder that handles TensorShapes and tuples."""
|
||||||
@ -67,3 +76,50 @@ def _decode_helper(obj):
|
|||||||
elif obj['class_name'] == '__tuple__':
|
elif obj['class_name'] == '__tuple__':
|
||||||
return tuple(_decode_helper(i) for i in obj['items'])
|
return tuple(_decode_helper(i) for i in obj['items'])
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def get_json_type(obj):
|
||||||
|
"""Serializes any object to a JSON-serializable structure.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
obj: the object to serialize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON-serializable structure representing `obj`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `obj` cannot be serialized.
|
||||||
|
"""
|
||||||
|
# if obj is a serializable Keras class instance
|
||||||
|
# e.g. optimizer, layer
|
||||||
|
if hasattr(obj, 'get_config'):
|
||||||
|
return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
|
||||||
|
|
||||||
|
# if obj is any numpy type
|
||||||
|
if type(obj).__module__ == np.__name__:
|
||||||
|
if isinstance(obj, np.ndarray):
|
||||||
|
return obj.tolist()
|
||||||
|
else:
|
||||||
|
return obj.item()
|
||||||
|
|
||||||
|
# misc functions (e.g. loss function)
|
||||||
|
if callable(obj):
|
||||||
|
return obj.__name__
|
||||||
|
|
||||||
|
# if obj is a python 'type'
|
||||||
|
if type(obj).__name__ == type.__name__:
|
||||||
|
return obj.__name__
|
||||||
|
|
||||||
|
if isinstance(obj, tensor_shape.TensorShape):
|
||||||
|
return obj.as_list()
|
||||||
|
|
||||||
|
if isinstance(obj, dtypes.DType):
|
||||||
|
return obj.name
|
||||||
|
|
||||||
|
if isinstance(obj, collections_abc.Mapping):
|
||||||
|
return dict(obj)
|
||||||
|
|
||||||
|
if isinstance(obj, wrapt.ObjectProxy):
|
||||||
|
return obj.__wrapped__
|
||||||
|
|
||||||
|
raise TypeError('Not JSON Serializable:', obj)
|
||||||
|
@ -27,8 +27,8 @@ from tensorflow.python.keras.engine import input_layer
|
|||||||
from tensorflow.python.keras.engine import sequential
|
from tensorflow.python.keras.engine import sequential
|
||||||
from tensorflow.python.keras.engine import training
|
from tensorflow.python.keras.engine import training
|
||||||
from tensorflow.python.keras.layers import core
|
from tensorflow.python.keras.layers import core
|
||||||
|
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.util import serialization
|
|
||||||
|
|
||||||
|
|
||||||
@combinations.generate(combinations.combine(mode=["graph", "eager"]))
|
@combinations.generate(combinations.combine(mode=["graph", "eager"]))
|
||||||
@ -38,7 +38,7 @@ class SerializationTests(keras_parameterized.TestCase):
|
|||||||
dense = core.Dense(3)
|
dense = core.Dense(3)
|
||||||
dense(constant_op.constant([[4.]]))
|
dense(constant_op.constant([[4.]]))
|
||||||
round_trip = json.loads(json.dumps(
|
round_trip = json.loads(json.dumps(
|
||||||
dense, default=serialization.get_json_type))
|
dense, default=json_utils.get_json_type))
|
||||||
self.assertEqual(3, round_trip["config"]["units"])
|
self.assertEqual(3, round_trip["config"]["units"])
|
||||||
|
|
||||||
def test_serialize_sequential(self):
|
def test_serialize_sequential(self):
|
||||||
@ -47,7 +47,7 @@ class SerializationTests(keras_parameterized.TestCase):
|
|||||||
model.add(core.Dense(5))
|
model.add(core.Dense(5))
|
||||||
model(constant_op.constant([[1.]]))
|
model(constant_op.constant([[1.]]))
|
||||||
sequential_round_trip = json.loads(
|
sequential_round_trip = json.loads(
|
||||||
json.dumps(model, default=serialization.get_json_type))
|
json.dumps(model, default=json_utils.get_json_type))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
# Note that `config['layers'][0]` will be an InputLayer in V2
|
# Note that `config['layers'][0]` will be an InputLayer in V2
|
||||||
# (but not in V1)
|
# (but not in V1)
|
||||||
@ -59,7 +59,7 @@ class SerializationTests(keras_parameterized.TestCase):
|
|||||||
model = training.Model(x, y)
|
model = training.Model(x, y)
|
||||||
model(constant_op.constant([[1., 1., 1.]]))
|
model(constant_op.constant([[1., 1., 1.]]))
|
||||||
model_round_trip = json.loads(
|
model_round_trip = json.loads(
|
||||||
json.dumps(model, default=serialization.get_json_type))
|
json.dumps(model, default=json_utils.get_json_type))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
10, model_round_trip["config"]["layers"][1]["config"]["units"])
|
10, model_round_trip["config"]["layers"][1]["config"]["units"])
|
||||||
|
|
||||||
|
@ -30,6 +30,7 @@ from tensorflow.python.eager import def_function
|
|||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||||
from tensorflow.python.layers import core as non_keras_core
|
from tensorflow.python.layers import core as non_keras_core
|
||||||
from tensorflow.python.module import module
|
from tensorflow.python.module import module
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -39,7 +40,6 @@ from tensorflow.python.training.tracking import data_structures
|
|||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
from tensorflow.python.training.tracking import util
|
from tensorflow.python.training.tracking import util
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import serialization
|
|
||||||
|
|
||||||
|
|
||||||
class ListTests(test.TestCase):
|
class ListTests(test.TestCase):
|
||||||
@ -47,7 +47,7 @@ class ListTests(test.TestCase):
|
|||||||
def testJSONSerialization(self):
|
def testJSONSerialization(self):
|
||||||
obj = tracking.AutoTrackable()
|
obj = tracking.AutoTrackable()
|
||||||
obj.l = [1]
|
obj.l = [1]
|
||||||
json.dumps(obj.l, default=serialization.get_json_type)
|
json.dumps(obj.l, default=json_utils.get_json_type)
|
||||||
|
|
||||||
def testNotTrackable(self):
|
def testNotTrackable(self):
|
||||||
class NotTrackable(object):
|
class NotTrackable(object):
|
||||||
@ -337,7 +337,7 @@ class MappingTests(test.TestCase):
|
|||||||
def testJSONSerialization(self):
|
def testJSONSerialization(self):
|
||||||
obj = tracking.AutoTrackable()
|
obj = tracking.AutoTrackable()
|
||||||
obj.d = {"a": 2}
|
obj.d = {"a": 2}
|
||||||
json.dumps(obj.d, default=serialization.get_json_type)
|
json.dumps(obj.d, default=json_utils.get_json_type)
|
||||||
|
|
||||||
def testNoOverwrite(self):
|
def testNoOverwrite(self):
|
||||||
mapping = data_structures.Mapping()
|
mapping = data_structures.Mapping()
|
||||||
@ -519,7 +519,7 @@ class TupleTests(test.TestCase, parameterized.TestCase):
|
|||||||
def testJSONSerialization(self):
|
def testJSONSerialization(self):
|
||||||
obj = tracking.AutoTrackable()
|
obj = tracking.AutoTrackable()
|
||||||
obj.l = (1,)
|
obj.l = (1,)
|
||||||
json.dumps(obj.l, default=serialization.get_json_type)
|
json.dumps(obj.l, default=json_utils.get_json_type)
|
||||||
|
|
||||||
def testNonLayerVariables(self):
|
def testNonLayerVariables(self):
|
||||||
v = resource_variable_ops.ResourceVariable([1.])
|
v = resource_variable_ops.ResourceVariable([1.])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user