Add get_layer_policy() function.
tf.keras.mixed_precision.experimental.get_layer_policy() takes in a layer and returns its dtype policy. PiperOrigin-RevId: 288031274 Change-Id: Id31848fd1e03fc324c8745305b5df3ee67b5f5a4
This commit is contained in:
parent
90093f0f82
commit
67330f736e
tensorflow
python/keras
tools/api/golden
@ -64,6 +64,7 @@ keras_packages = [
|
||||
"tensorflow.python.keras.layers.wrappers",
|
||||
"tensorflow.python.keras.losses",
|
||||
"tensorflow.python.keras.metrics",
|
||||
"tensorflow.python.keras.mixed_precision.experimental.get_layer_policy",
|
||||
"tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer",
|
||||
"tensorflow.python.keras.mixed_precision.experimental.policy",
|
||||
"tensorflow.python.keras.models",
|
||||
|
@ -31,6 +31,7 @@ py_library(
|
||||
srcs = ["__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":get_layer_policy",
|
||||
":loss_scale_optimizer",
|
||||
":policy",
|
||||
],
|
||||
@ -88,6 +89,27 @@ cuda_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "get_layer_policy",
|
||||
srcs = ["get_layer_policy.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python/keras:base_layer",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "get_layer_policy_test",
|
||||
srcs = ["get_layer_policy_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":get_layer_policy",
|
||||
":policy",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras:layers",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "autocast_variable",
|
||||
srcs = [
|
||||
|
@ -0,0 +1,41 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Contains the get_layer_policy function.
|
||||
|
||||
This is a separate file from policy.py to avoid a circular dependency.
|
||||
get_layer_policy() relies on base_layer.py, itself which relies on policy.py.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.keras.engine import base_layer
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
@keras_export('keras.mixed_precision.experimental.get_layer_policy')
|
||||
def get_layer_policy(layer):
|
||||
"""Returns the dtype policy of a layer.
|
||||
|
||||
Args:
|
||||
layer: A `tf.keras.layers.Layer`.
|
||||
|
||||
Returns:
|
||||
The `tf.keras.mixed_precision.experimental.Policy` of the layer.
|
||||
"""
|
||||
if not isinstance(layer, base_layer.Layer):
|
||||
raise ValueError('get_policy can only be called on a layer, but got: %s'
|
||||
% (layer,))
|
||||
return layer._dtype_policy # pylint: disable=protected-access
|
@ -0,0 +1,49 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests the get_layer_policy function."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.mixed_precision.experimental import get_layer_policy
|
||||
from tensorflow.python.keras.mixed_precision.experimental import policy
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class GetLayerPolicyTest(test.TestCase):
|
||||
|
||||
def test_get_layer_policy(self):
|
||||
layer = core.Dense(4)
|
||||
self.assertEqual(get_layer_policy.get_layer_policy(layer).name, 'float32')
|
||||
|
||||
p = policy.Policy('mixed_float16')
|
||||
layer = core.Dense(4, dtype=p)
|
||||
self.assertIs(get_layer_policy.get_layer_policy(layer), p)
|
||||
|
||||
layer = core.Dense(4, dtype='float64')
|
||||
self.assertEqual(get_layer_policy.get_layer_policy(layer).name, 'float64')
|
||||
|
||||
def test_error(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'get_policy can only be called on a layer, but got: 1'):
|
||||
get_layer_policy.get_layer_policy(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
base_layer_utils.enable_v2_dtype_behavior()
|
||||
test.main()
|
@ -41,6 +41,7 @@ from tensorflow.python.keras.engine import base_layer
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine import input_spec
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.mixed_precision.experimental import get_layer_policy
|
||||
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
|
||||
from tensorflow.python.keras.mixed_precision.experimental import policy
|
||||
from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
|
||||
@ -121,12 +122,14 @@ class KerasLayerTest(keras_parameterized.TestCase):
|
||||
with strategy_fn().scope(), policy.policy_scope(policy_name):
|
||||
layer = mp_test_util.AddLayer(assert_type=dtype)
|
||||
self.assertEqual(layer.dtype, dtypes.float32)
|
||||
self.assertEqual(layer._dtype_policy._name, policy_name)
|
||||
self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
|
||||
policy_name)
|
||||
y = layer(x)
|
||||
self.assertEqual(layer.v.dtype, dtypes.float32)
|
||||
self.assertEqual(y.dtype, dtype)
|
||||
self.assertEqual(layer.dtype, dtypes.float32)
|
||||
self.assertEqual(layer._dtype_policy._name, policy_name)
|
||||
self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
|
||||
policy_name)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertEqual(self.evaluate(y), 2.)
|
||||
|
||||
@ -581,7 +584,8 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
|
||||
# Ensure various dtype-related aspects of the layer are correct
|
||||
self.assertEqual(layer.dtype, 'float32')
|
||||
self.assertEqual(layer._dtype_policy.name, 'mixed_float16')
|
||||
self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
|
||||
'mixed_float16')
|
||||
self.assertEqual(layer.v.dtype, 'float32')
|
||||
self.assertEqual(layer(np.ones((2, 1))).dtype, 'float16')
|
||||
|
||||
|
@ -8,6 +8,10 @@ tf_module {
|
||||
name: "Policy"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "get_layer_policy"
|
||||
argspec: "args=[\'layer\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "global_policy"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -8,6 +8,10 @@ tf_module {
|
||||
name: "Policy"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "get_layer_policy"
|
||||
argspec: "args=[\'layer\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "global_policy"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
|
Loading…
Reference in New Issue
Block a user