Added reverse dependency back to Keras.
Also the original code also passes since the saving code has been updated recently. PiperOrigin-RevId: 303782079 Change-Id: I0ec9ca84d6584aa30bb5fc3c51906afc7de4320d
This commit is contained in:
parent
ad331f9797
commit
39b77be0ed
tensorflow/python/util
@ -23,7 +23,6 @@ import wrapt
|
||||
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
|
||||
|
||||
def get_json_type(obj):
|
||||
@ -41,10 +40,7 @@ def get_json_type(obj):
|
||||
# if obj is a serializable Keras class instance
|
||||
# e.g. optimizer, layer
|
||||
if hasattr(obj, 'get_config'):
|
||||
return {
|
||||
'class_name': generic_utils.get_registered_name(obj.__class__),
|
||||
'config': obj.get_config()
|
||||
}
|
||||
return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
|
||||
|
||||
# if obj is any numpy type
|
||||
if type(obj).__module__ == np.__name__:
|
||||
|
@ -23,12 +23,10 @@ import json
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import losses
|
||||
from tensorflow.python.keras.engine import input_layer
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.utils import losses_utils, generic_utils
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import serialization
|
||||
|
||||
@ -71,41 +69,5 @@ class SerializationTests(test.TestCase):
|
||||
self.assertEqual(
|
||||
10, model_round_trip["config"]["layers"][1]["config"]["units"])
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_serialize_custom_model_compile(self):
|
||||
with generic_utils.custom_object_scope():
|
||||
|
||||
@generic_utils.register_keras_serializable(package="dummy-package")
|
||||
class DummySparseCategoricalCrossentropyLoss(losses.LossFunctionWrapper):
|
||||
# This loss is identical equal to tf.keras.losses.SparseCategoricalCrossentropy
|
||||
def __init__(
|
||||
self,
|
||||
from_logits=False,
|
||||
reduction=losses_utils.ReductionV2.AUTO,
|
||||
name="dummy_sparse_categorical_crossentropy_loss",
|
||||
):
|
||||
super(DummySparseCategoricalCrossentropyLoss, self).__init__(
|
||||
losses.sparse_categorical_crossentropy,
|
||||
name=name,
|
||||
reduction=reduction,
|
||||
from_logits=from_logits,
|
||||
)
|
||||
|
||||
x = input_layer.Input(shape=[3])
|
||||
y = core.Dense(10)(x)
|
||||
model = training.Model(x, y)
|
||||
model.compile(
|
||||
loss=DummySparseCategoricalCrossentropyLoss(from_logits=True))
|
||||
model_round_trip = json.loads(
|
||||
json.dumps(model.loss, default=serialization.get_json_type))
|
||||
|
||||
# check if class name with package scope
|
||||
self.assertEqual("dummy-package>DummySparseCategoricalCrossentropyLoss",
|
||||
model_round_trip["class_name"])
|
||||
|
||||
# check if configure is correctly
|
||||
self.assertEqual(True, model_round_trip["config"]["from_logits"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user