Fix incorrect usage of keras code in tf.

TF core shouldn't have any deps back to keras.

PiperOrigin-RevId: 342142506
Change-Id: If0964333195bce741aa631972e2bb39ec5a9c25e
This commit is contained in:
Scott Zhu 2020-11-12 15:21:13 -08:00 committed by TensorFlower Gardener
parent 9785379b74
commit f71ba34bba

View File

@ -30,7 +30,6 @@ from tensorflow.python.eager import def_function
from tensorflow.python.eager import test from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.layers import core as non_keras_core from tensorflow.python.layers import core as non_keras_core
from tensorflow.python.module import module from tensorflow.python.module import module
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -40,6 +39,7 @@ from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import serialization
class ListTests(test.TestCase): class ListTests(test.TestCase):
@ -47,7 +47,7 @@ class ListTests(test.TestCase):
def testJSONSerialization(self): def testJSONSerialization(self):
obj = tracking.AutoTrackable() obj = tracking.AutoTrackable()
obj.l = [1] obj.l = [1]
json.dumps(obj.l, default=json_utils.get_json_type) json.dumps(obj.l, default=serialization.get_json_type)
def testNotTrackable(self): def testNotTrackable(self):
class NotTrackable(object): class NotTrackable(object):
@ -337,7 +337,7 @@ class MappingTests(test.TestCase):
def testJSONSerialization(self): def testJSONSerialization(self):
obj = tracking.AutoTrackable() obj = tracking.AutoTrackable()
obj.d = {"a": 2} obj.d = {"a": 2}
json.dumps(obj.d, default=json_utils.get_json_type) json.dumps(obj.d, default=serialization.get_json_type)
def testNoOverwrite(self): def testNoOverwrite(self):
mapping = data_structures.Mapping() mapping = data_structures.Mapping()
@ -519,7 +519,7 @@ class TupleTests(test.TestCase, parameterized.TestCase):
def testJSONSerialization(self): def testJSONSerialization(self):
obj = tracking.AutoTrackable() obj = tracking.AutoTrackable()
obj.l = (1,) obj.l = (1,)
json.dumps(obj.l, default=json_utils.get_json_type) json.dumps(obj.l, default=serialization.get_json_type)
def testNonLayerVariables(self): def testNonLayerVariables(self):
v = resource_variable_ops.ResourceVariable([1.]) v = resource_variable_ops.ResourceVariable([1.])