Move the serialization_test to keras/tests
PiperOrigin-RevId: 316697481 Change-Id: I918b29a976b166662acf9045f87512aef485441b
This commit is contained in:
parent
12ec80d239
commit
cf599cade1
@ -359,6 +359,21 @@ cuda_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "serialization_util_test",
|
||||
size = "small",
|
||||
srcs = ["serialization_util_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/keras/engine",
|
||||
"//tensorflow/python/keras/layers:core",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "temporal_sample_weights_correctness_test",
|
||||
srcs = ["temporal_sample_weights_correctness_test.py"],
|
||||
|
||||
67
tensorflow/python/keras/tests/serialization_util_test.py
Normal file
67
tensorflow/python/keras/tests/serialization_util_test.py
Normal file
@ -0,0 +1,67 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for serialization functions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import input_layer
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import serialization
|
||||
|
||||
|
||||
class SerializationTests(test.TestCase):
|
||||
|
||||
def test_serialize_dense(self):
|
||||
dense = core.Dense(3)
|
||||
dense(constant_op.constant([[4.]]))
|
||||
round_trip = json.loads(json.dumps(
|
||||
dense, default=serialization.get_json_type))
|
||||
self.assertEqual(3, round_trip["config"]["units"])
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_serialize_sequential(self):
|
||||
model = sequential.Sequential()
|
||||
model.add(core.Dense(4))
|
||||
model.add(core.Dense(5))
|
||||
model(constant_op.constant([[1.]]))
|
||||
sequential_round_trip = json.loads(
|
||||
json.dumps(model, default=serialization.get_json_type))
|
||||
self.assertEqual(
|
||||
# Note that `config['layers'][0]` will be an InputLayer in V2
|
||||
# (but not in V1)
|
||||
5, sequential_round_trip["config"]["layers"][-1]["config"]["units"])
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_serialize_model(self):
|
||||
x = input_layer.Input(shape=[3])
|
||||
y = core.Dense(10)(x)
|
||||
model = training.Model(x, y)
|
||||
model(constant_op.constant([[1., 1., 1.]]))
|
||||
model_round_trip = json.loads(
|
||||
json.dumps(model, default=serialization.get_json_type))
|
||||
self.assertEqual(
|
||||
10, model_round_trip["config"]["layers"][1]["config"]["units"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
@ -20,26 +20,13 @@ from __future__ import print_function
|
||||
|
||||
import json
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import input_layer
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import serialization
|
||||
|
||||
|
||||
class SerializationTests(test.TestCase):
|
||||
|
||||
def test_serialize_dense(self):
|
||||
dense = core.Dense(3)
|
||||
dense(constant_op.constant([[4.]]))
|
||||
round_trip = json.loads(json.dumps(
|
||||
dense, default=serialization.get_json_type))
|
||||
self.assertEqual(3, round_trip["config"]["units"])
|
||||
|
||||
def test_serialize_shape(self):
|
||||
round_trip = json.loads(json.dumps(
|
||||
tensor_shape.TensorShape([None, 2, 3]),
|
||||
@ -47,29 +34,6 @@ class SerializationTests(test.TestCase):
|
||||
self.assertIs(round_trip[0], None)
|
||||
self.assertEqual(round_trip[1], 2)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_serialize_sequential(self):
|
||||
model = sequential.Sequential()
|
||||
model.add(core.Dense(4))
|
||||
model.add(core.Dense(5))
|
||||
model(constant_op.constant([[1.]]))
|
||||
sequential_round_trip = json.loads(
|
||||
json.dumps(model, default=serialization.get_json_type))
|
||||
self.assertEqual(
|
||||
# Note that `config['layers'][0]` will be an InputLayer in V2
|
||||
# (but not in V1)
|
||||
5, sequential_round_trip["config"]["layers"][-1]["config"]["units"])
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_serialize_model(self):
|
||||
x = input_layer.Input(shape=[3])
|
||||
y = core.Dense(10)(x)
|
||||
model = training.Model(x, y)
|
||||
model(constant_op.constant([[1., 1., 1.]]))
|
||||
model_round_trip = json.loads(
|
||||
json.dumps(model, default=serialization.get_json_type))
|
||||
self.assertEqual(
|
||||
10, model_round_trip["config"]["layers"][1]["config"]["units"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user