Added reverse dependency back to Keras.
Also the original code also passes since the saving code has been updated recently. PiperOrigin-RevId: 303782079 Change-Id: I0ec9ca84d6584aa30bb5fc3c51906afc7de4320d
This commit is contained in:
parent
ad331f9797
commit
39b77be0ed
@ -23,7 +23,6 @@ import wrapt
|
|||||||
|
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.util.compat import collections_abc
|
from tensorflow.python.util.compat import collections_abc
|
||||||
from tensorflow.python.keras.utils import generic_utils
|
|
||||||
|
|
||||||
|
|
||||||
def get_json_type(obj):
|
def get_json_type(obj):
|
||||||
@ -41,10 +40,7 @@ def get_json_type(obj):
|
|||||||
# if obj is a serializable Keras class instance
|
# if obj is a serializable Keras class instance
|
||||||
# e.g. optimizer, layer
|
# e.g. optimizer, layer
|
||||||
if hasattr(obj, 'get_config'):
|
if hasattr(obj, 'get_config'):
|
||||||
return {
|
return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
|
||||||
'class_name': generic_utils.get_registered_name(obj.__class__),
|
|
||||||
'config': obj.get_config()
|
|
||||||
}
|
|
||||||
|
|
||||||
# if obj is any numpy type
|
# if obj is any numpy type
|
||||||
if type(obj).__module__ == np.__name__:
|
if type(obj).__module__ == np.__name__:
|
||||||
|
@ -23,12 +23,10 @@ import json
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.keras import losses
|
|
||||||
from tensorflow.python.keras.engine import input_layer
|
from tensorflow.python.keras.engine import input_layer
|
||||||
from tensorflow.python.keras.engine import sequential
|
from tensorflow.python.keras.engine import sequential
|
||||||
from tensorflow.python.keras.engine import training
|
from tensorflow.python.keras.engine import training
|
||||||
from tensorflow.python.keras.layers import core
|
from tensorflow.python.keras.layers import core
|
||||||
from tensorflow.python.keras.utils import losses_utils, generic_utils
|
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.util import serialization
|
from tensorflow.python.util import serialization
|
||||||
|
|
||||||
@ -71,41 +69,5 @@ class SerializationTests(test.TestCase):
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
10, model_round_trip["config"]["layers"][1]["config"]["units"])
|
10, model_round_trip["config"]["layers"][1]["config"]["units"])
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
|
||||||
def test_serialize_custom_model_compile(self):
|
|
||||||
with generic_utils.custom_object_scope():
|
|
||||||
|
|
||||||
@generic_utils.register_keras_serializable(package="dummy-package")
|
|
||||||
class DummySparseCategoricalCrossentropyLoss(losses.LossFunctionWrapper):
|
|
||||||
# This loss is identical equal to tf.keras.losses.SparseCategoricalCrossentropy
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
from_logits=False,
|
|
||||||
reduction=losses_utils.ReductionV2.AUTO,
|
|
||||||
name="dummy_sparse_categorical_crossentropy_loss",
|
|
||||||
):
|
|
||||||
super(DummySparseCategoricalCrossentropyLoss, self).__init__(
|
|
||||||
losses.sparse_categorical_crossentropy,
|
|
||||||
name=name,
|
|
||||||
reduction=reduction,
|
|
||||||
from_logits=from_logits,
|
|
||||||
)
|
|
||||||
|
|
||||||
x = input_layer.Input(shape=[3])
|
|
||||||
y = core.Dense(10)(x)
|
|
||||||
model = training.Model(x, y)
|
|
||||||
model.compile(
|
|
||||||
loss=DummySparseCategoricalCrossentropyLoss(from_logits=True))
|
|
||||||
model_round_trip = json.loads(
|
|
||||||
json.dumps(model.loss, default=serialization.get_json_type))
|
|
||||||
|
|
||||||
# check if class name with package scope
|
|
||||||
self.assertEqual("dummy-package>DummySparseCategoricalCrossentropyLoss",
|
|
||||||
model_round_trip["class_name"])
|
|
||||||
|
|
||||||
# check if configure is correctly
|
|
||||||
self.assertEqual(True, model_round_trip["config"]["from_logits"])
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user