Fix incorrect usage of keras code in tf.
TF core shouldn't have any deps back to keras. PiperOrigin-RevId: 342142506 Change-Id: If0964333195bce741aa631972e2bb39ec5a9c25e
This commit is contained in:
parent
9785379b74
commit
f71ba34bba
@ -30,7 +30,6 @@ from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||
from tensorflow.python.layers import core as non_keras_core
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -40,6 +39,7 @@ from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.training.tracking import util
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import serialization
|
||||
|
||||
|
||||
class ListTests(test.TestCase):
|
||||
@ -47,7 +47,7 @@ class ListTests(test.TestCase):
|
||||
def testJSONSerialization(self):
|
||||
obj = tracking.AutoTrackable()
|
||||
obj.l = [1]
|
||||
json.dumps(obj.l, default=json_utils.get_json_type)
|
||||
json.dumps(obj.l, default=serialization.get_json_type)
|
||||
|
||||
def testNotTrackable(self):
|
||||
class NotTrackable(object):
|
||||
@ -337,7 +337,7 @@ class MappingTests(test.TestCase):
|
||||
def testJSONSerialization(self):
|
||||
obj = tracking.AutoTrackable()
|
||||
obj.d = {"a": 2}
|
||||
json.dumps(obj.d, default=json_utils.get_json_type)
|
||||
json.dumps(obj.d, default=serialization.get_json_type)
|
||||
|
||||
def testNoOverwrite(self):
|
||||
mapping = data_structures.Mapping()
|
||||
@ -519,7 +519,7 @@ class TupleTests(test.TestCase, parameterized.TestCase):
|
||||
def testJSONSerialization(self):
|
||||
obj = tracking.AutoTrackable()
|
||||
obj.l = (1,)
|
||||
json.dumps(obj.l, default=json_utils.get_json_type)
|
||||
json.dumps(obj.l, default=serialization.get_json_type)
|
||||
|
||||
def testNonLayerVariables(self):
|
||||
v = resource_variable_ops.ResourceVariable([1.])
|
||||
|
Loading…
x
Reference in New Issue
Block a user