Add get_layer_policy() function.

tf.keras.mixed_precision.experimental.get_layer_policy() takes in a layer and returns its dtype policy.

PiperOrigin-RevId: 288031274
Change-Id: Id31848fd1e03fc324c8745305b5df3ee67b5f5a4
This commit is contained in:
Reed Wanderman-Milne 2020-01-03 12:28:20 -08:00 committed by TensorFlower Gardener
parent 90093f0f82
commit 67330f736e
7 changed files with 128 additions and 3 deletions

View File

@ -64,6 +64,7 @@ keras_packages = [
"tensorflow.python.keras.layers.wrappers",
"tensorflow.python.keras.losses",
"tensorflow.python.keras.metrics",
"tensorflow.python.keras.mixed_precision.experimental.get_layer_policy",
"tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer",
"tensorflow.python.keras.mixed_precision.experimental.policy",
"tensorflow.python.keras.models",

View File

@ -31,6 +31,7 @@ py_library(
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
":get_layer_policy",
":loss_scale_optimizer",
":policy",
],
@ -88,6 +89,27 @@ cuda_py_test(
],
)
py_library(
name = "get_layer_policy",
srcs = ["get_layer_policy.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python/keras:base_layer",
],
)
py_test(
name = "get_layer_policy_test",
srcs = ["get_layer_policy_test.py"],
srcs_version = "PY2AND3",
deps = [
":get_layer_policy",
":policy",
"//tensorflow/python:client_testlib",
"//tensorflow/python/keras:layers",
],
)
py_library(
name = "autocast_variable",
srcs = [

View File

@ -0,0 +1,41 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the get_layer_policy function.
This is a separate file from policy.py to avoid a circular dependency.
get_layer_policy() relies on base_layer.py, itself which relies on policy.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.util.tf_export import keras_export
@keras_export('keras.mixed_precision.experimental.get_layer_policy')
def get_layer_policy(layer):
"""Returns the dtype policy of a layer.
Args:
layer: A `tf.keras.layers.Layer`.
Returns:
The `tf.keras.mixed_precision.experimental.Policy` of the layer.
"""
if not isinstance(layer, base_layer.Layer):
raise ValueError('get_policy can only be called on a layer, but got: %s'
% (layer,))
return layer._dtype_policy # pylint: disable=protected-access

View File

@ -0,0 +1,49 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests the get_layer_policy function."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.mixed_precision.experimental import get_layer_policy
from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.platform import test
class GetLayerPolicyTest(test.TestCase):
def test_get_layer_policy(self):
layer = core.Dense(4)
self.assertEqual(get_layer_policy.get_layer_policy(layer).name, 'float32')
p = policy.Policy('mixed_float16')
layer = core.Dense(4, dtype=p)
self.assertIs(get_layer_policy.get_layer_policy(layer), p)
layer = core.Dense(4, dtype='float64')
self.assertEqual(get_layer_policy.get_layer_policy(layer).name, 'float64')
def test_error(self):
with self.assertRaisesRegexp(
ValueError, 'get_policy can only be called on a layer, but got: 1'):
get_layer_policy.get_layer_policy(1)
if __name__ == '__main__':
base_layer_utils.enable_v2_dtype_behavior()
test.main()

View File

@ -41,6 +41,7 @@ from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.mixed_precision.experimental import get_layer_policy
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
@ -121,12 +122,14 @@ class KerasLayerTest(keras_parameterized.TestCase):
with strategy_fn().scope(), policy.policy_scope(policy_name):
layer = mp_test_util.AddLayer(assert_type=dtype)
self.assertEqual(layer.dtype, dtypes.float32)
self.assertEqual(layer._dtype_policy._name, policy_name)
self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
policy_name)
y = layer(x)
self.assertEqual(layer.v.dtype, dtypes.float32)
self.assertEqual(y.dtype, dtype)
self.assertEqual(layer.dtype, dtypes.float32)
self.assertEqual(layer._dtype_policy._name, policy_name)
self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
policy_name)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(self.evaluate(y), 2.)
@ -581,7 +584,8 @@ class KerasModelTest(keras_parameterized.TestCase):
# Ensure various dtype-related aspects of the layer are correct
self.assertEqual(layer.dtype, 'float32')
self.assertEqual(layer._dtype_policy.name, 'mixed_float16')
self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
'mixed_float16')
self.assertEqual(layer.v.dtype, 'float32')
self.assertEqual(layer(np.ones((2, 1))).dtype, 'float16')

View File

@ -8,6 +8,10 @@ tf_module {
name: "Policy"
mtype: "<type \'type\'>"
}
member_method {
name: "get_layer_policy"
argspec: "args=[\'layer\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "global_policy"
argspec: "args=[], varargs=None, keywords=None, defaults=None"

View File

@ -8,6 +8,10 @@ tf_module {
name: "Policy"
mtype: "<type \'type\'>"
}
member_method {
name: "get_layer_policy"
argspec: "args=[\'layer\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "global_policy"
argspec: "args=[], varargs=None, keywords=None, defaults=None"