Update Keras tests under engine package to use combinations.
1. Change all test_util.run_all_in_graph_and_eager_modes to combination. 2. Replace import tensorflow.python.keras with explicit module import. 3. Update BUILD file to not rely on the overall Keras target. PiperOrigin-RevId: 300434594 Change-Id: I7ca776f66ac3097ba1343a5c58969ec2e6e0df3d
This commit is contained in:
parent
690fa965f9
commit
00fbbd9036
@ -183,7 +183,10 @@ tf_py_test(
|
||||
deps = [
|
||||
":base_layer_utils",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:lookup_ops",
|
||||
"//tensorflow/python/keras:backend",
|
||||
"//tensorflow/python/keras:combinations",
|
||||
],
|
||||
)
|
||||
|
||||
@ -437,10 +440,28 @@ tf_py_test(
|
||||
"nomac", # TODO(mihaimaruseac): b/127695564
|
||||
],
|
||||
deps = [
|
||||
":base_layer",
|
||||
":engine",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/keras:backend",
|
||||
"//tensorflow/python/keras:combinations",
|
||||
"//tensorflow/python/keras:initializers",
|
||||
"//tensorflow/python/keras:models",
|
||||
"//tensorflow/python/keras:testing_utils",
|
||||
"//tensorflow/python/keras/layers",
|
||||
"//tensorflow/python/keras/utils:layer_utils",
|
||||
"//tensorflow/python/keras/utils:tf_utils",
|
||||
"//tensorflow/python/ops/ragged:ragged_factory_ops",
|
||||
"//tensorflow/python/training/tracking:util",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
@ -455,10 +476,36 @@ tf_py_test(
|
||||
"nomac", # TODO(mihaimaruseac): b/127695564
|
||||
],
|
||||
deps = [
|
||||
":base_layer",
|
||||
":engine",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:summary",
|
||||
"//tensorflow/python:summary_ops_v2",
|
||||
"//tensorflow/python:tensor_array_ops",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/keras:backend",
|
||||
"//tensorflow/python/keras:combinations",
|
||||
"//tensorflow/python/keras:regularizers",
|
||||
"//tensorflow/python/keras:testing_utils",
|
||||
"//tensorflow/python/keras/layers",
|
||||
"//tensorflow/python/keras/mixed_precision/experimental:policy",
|
||||
"//tensorflow/python/keras/optimizer_v2",
|
||||
"//tensorflow/python/keras/utils:tf_utils",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,13 +1,13 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@ -17,15 +17,15 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
class TrackableWeightHandlerTest(keras_parameterized.TestCase):
|
||||
|
||||
def get_table_handler(self):
|
||||
@ -44,7 +44,7 @@ class TrackableWeightHandlerTest(keras_parameterized.TestCase):
|
||||
def test_get_and_set_weights(self):
|
||||
table_handler = self.get_table_handler()
|
||||
|
||||
table_data = {b"a": 1, b"b": 2, b"c": 3}
|
||||
table_data = {b'a': 1, b'b': 2, b'c': 3}
|
||||
table_handler.set_weights(
|
||||
[list(table_data.keys()),
|
||||
list(table_data.values())])
|
||||
@ -54,7 +54,7 @@ class TrackableWeightHandlerTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_get_and_set_weights_does_not_add_ops(self):
|
||||
table_handler = self.get_table_handler()
|
||||
table_data = {b"a": 1, b"b": 2, b"c": 3}
|
||||
table_data = {b'a': 1, b'b': 2, b'c': 3}
|
||||
table_handler.set_weights(
|
||||
[list(table_data.keys()),
|
||||
list(table_data.values())])
|
||||
@ -66,5 +66,5 @@ class TrackableWeightHandlerTest(keras_parameterized.TestCase):
|
||||
_ = backend.batch_get_value(table_handler.get_tensors())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user