Added reverse dependency back to Keras.

Also the original code also passes since the saving code has been updated recently.

PiperOrigin-RevId: 303782079
Change-Id: I0ec9ca84d6584aa30bb5fc3c51906afc7de4320d
This commit is contained in:
Scott Zhu 2020-03-30 11:20:00 -07:00 committed by TensorFlower Gardener
parent ad331f9797
commit 39b77be0ed
2 changed files with 1 additions and 43 deletions

View File

@ -23,7 +23,6 @@ import wrapt
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.keras.utils import generic_utils
def get_json_type(obj):
@ -41,10 +40,7 @@ def get_json_type(obj):
# if obj is a serializable Keras class instance
# e.g. optimizer, layer
if hasattr(obj, 'get_config'):
return {
'class_name': generic_utils.get_registered_name(obj.__class__),
'config': obj.get_config()
}
return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
# if obj is any numpy type
if type(obj).__module__ == np.__name__:

View File

@ -23,12 +23,10 @@ import json
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.keras import losses
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.utils import losses_utils, generic_utils
from tensorflow.python.platform import test
from tensorflow.python.util import serialization
@ -71,41 +69,5 @@ class SerializationTests(test.TestCase):
self.assertEqual(
10, model_round_trip["config"]["layers"][1]["config"]["units"])
@test_util.run_in_graph_and_eager_modes
def test_serialize_custom_model_compile(self):
with generic_utils.custom_object_scope():
@generic_utils.register_keras_serializable(package="dummy-package")
class DummySparseCategoricalCrossentropyLoss(losses.LossFunctionWrapper):
# This loss is identical equal to tf.keras.losses.SparseCategoricalCrossentropy
def __init__(
self,
from_logits=False,
reduction=losses_utils.ReductionV2.AUTO,
name="dummy_sparse_categorical_crossentropy_loss",
):
super(DummySparseCategoricalCrossentropyLoss, self).__init__(
losses.sparse_categorical_crossentropy,
name=name,
reduction=reduction,
from_logits=from_logits,
)
x = input_layer.Input(shape=[3])
y = core.Dense(10)(x)
model = training.Model(x, y)
model.compile(
loss=DummySparseCategoricalCrossentropyLoss(from_logits=True))
model_round_trip = json.loads(
json.dumps(model.loss, default=serialization.get_json_type))
# check if class name with package scope
self.assertEqual("dummy-package>DummySparseCategoricalCrossentropyLoss",
model_round_trip["class_name"])
# check if configure is correctly
self.assertEqual(True, model_round_trip["config"]["from_logits"])
if __name__ == "__main__":
test.main()