Move the serialization_test to keras/tests
PiperOrigin-RevId: 316697481 Change-Id: I918b29a976b166662acf9045f87512aef485441b
This commit is contained in:
		
							parent
							
								
									12ec80d239
								
							
						
					
					
						commit
						cf599cade1
					
				| @ -359,6 +359,21 @@ cuda_py_test( | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| tf_py_test( | ||||
|     name = "serialization_util_test", | ||||
|     size = "small", | ||||
|     srcs = ["serialization_util_test.py"], | ||||
|     python_version = "PY3", | ||||
|     deps = [ | ||||
|         "//tensorflow/python:client_testlib", | ||||
|         "//tensorflow/python:constant_op", | ||||
|         "//tensorflow/python:framework_test_lib", | ||||
|         "//tensorflow/python:util", | ||||
|         "//tensorflow/python/keras/engine", | ||||
|         "//tensorflow/python/keras/layers:core", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| tf_py_test( | ||||
|     name = "temporal_sample_weights_correctness_test", | ||||
|     srcs = ["temporal_sample_weights_correctness_test.py"], | ||||
|  | ||||
							
								
								
									
										67
									
								
								tensorflow/python/keras/tests/serialization_util_test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								tensorflow/python/keras/tests/serialization_util_test.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,67 @@ | ||||
| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """Tests for serialization functions.""" | ||||
| 
 | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| import json | ||||
| 
 | ||||
| from tensorflow.python.framework import constant_op | ||||
| from tensorflow.python.framework import test_util | ||||
| from tensorflow.python.keras.engine import input_layer | ||||
| from tensorflow.python.keras.engine import sequential | ||||
| from tensorflow.python.keras.engine import training | ||||
| from tensorflow.python.keras.layers import core | ||||
| from tensorflow.python.platform import test | ||||
| from tensorflow.python.util import serialization | ||||
| 
 | ||||
| 
 | ||||
| class SerializationTests(test.TestCase): | ||||
| 
 | ||||
|   def test_serialize_dense(self): | ||||
|     dense = core.Dense(3) | ||||
|     dense(constant_op.constant([[4.]])) | ||||
|     round_trip = json.loads(json.dumps( | ||||
|         dense, default=serialization.get_json_type)) | ||||
|     self.assertEqual(3, round_trip["config"]["units"]) | ||||
| 
 | ||||
|   @test_util.run_in_graph_and_eager_modes | ||||
|   def test_serialize_sequential(self): | ||||
|     model = sequential.Sequential() | ||||
|     model.add(core.Dense(4)) | ||||
|     model.add(core.Dense(5)) | ||||
|     model(constant_op.constant([[1.]])) | ||||
|     sequential_round_trip = json.loads( | ||||
|         json.dumps(model, default=serialization.get_json_type)) | ||||
|     self.assertEqual( | ||||
|         # Note that `config['layers'][0]` will be an InputLayer in V2 | ||||
|         # (but not in V1) | ||||
|         5, sequential_round_trip["config"]["layers"][-1]["config"]["units"]) | ||||
| 
 | ||||
|   @test_util.run_in_graph_and_eager_modes | ||||
|   def test_serialize_model(self): | ||||
|     x = input_layer.Input(shape=[3]) | ||||
|     y = core.Dense(10)(x) | ||||
|     model = training.Model(x, y) | ||||
|     model(constant_op.constant([[1., 1., 1.]])) | ||||
|     model_round_trip = json.loads( | ||||
|         json.dumps(model, default=serialization.get_json_type)) | ||||
|     self.assertEqual( | ||||
|         10, model_round_trip["config"]["layers"][1]["config"]["units"]) | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|   test.main() | ||||
| @ -20,26 +20,13 @@ from __future__ import print_function | ||||
| 
 | ||||
| import json | ||||
| 
 | ||||
| from tensorflow.python.framework import constant_op | ||||
| from tensorflow.python.framework import tensor_shape | ||||
| from tensorflow.python.framework import test_util | ||||
| from tensorflow.python.keras.engine import input_layer | ||||
| from tensorflow.python.keras.engine import sequential | ||||
| from tensorflow.python.keras.engine import training | ||||
| from tensorflow.python.keras.layers import core | ||||
| from tensorflow.python.platform import test | ||||
| from tensorflow.python.util import serialization | ||||
| 
 | ||||
| 
 | ||||
| class SerializationTests(test.TestCase): | ||||
| 
 | ||||
|   def test_serialize_dense(self): | ||||
|     dense = core.Dense(3) | ||||
|     dense(constant_op.constant([[4.]])) | ||||
|     round_trip = json.loads(json.dumps( | ||||
|         dense, default=serialization.get_json_type)) | ||||
|     self.assertEqual(3, round_trip["config"]["units"]) | ||||
| 
 | ||||
|   def test_serialize_shape(self): | ||||
|     round_trip = json.loads(json.dumps( | ||||
|         tensor_shape.TensorShape([None, 2, 3]), | ||||
| @ -47,29 +34,6 @@ class SerializationTests(test.TestCase): | ||||
|     self.assertIs(round_trip[0], None) | ||||
|     self.assertEqual(round_trip[1], 2) | ||||
| 
 | ||||
|   @test_util.run_in_graph_and_eager_modes | ||||
|   def test_serialize_sequential(self): | ||||
|     model = sequential.Sequential() | ||||
|     model.add(core.Dense(4)) | ||||
|     model.add(core.Dense(5)) | ||||
|     model(constant_op.constant([[1.]])) | ||||
|     sequential_round_trip = json.loads( | ||||
|         json.dumps(model, default=serialization.get_json_type)) | ||||
|     self.assertEqual( | ||||
|         # Note that `config['layers'][0]` will be an InputLayer in V2 | ||||
|         # (but not in V1) | ||||
|         5, sequential_round_trip["config"]["layers"][-1]["config"]["units"]) | ||||
| 
 | ||||
|   @test_util.run_in_graph_and_eager_modes | ||||
|   def test_serialize_model(self): | ||||
|     x = input_layer.Input(shape=[3]) | ||||
|     y = core.Dense(10)(x) | ||||
|     model = training.Model(x, y) | ||||
|     model(constant_op.constant([[1., 1., 1.]])) | ||||
|     model_round_trip = json.loads( | ||||
|         json.dumps(model, default=serialization.get_json_type)) | ||||
|     self.assertEqual( | ||||
|         10, model_round_trip["config"]["layers"][1]["config"]["units"]) | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|   test.main() | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user