Add get_layer_policy() function.
tf.keras.mixed_precision.experimental.get_layer_policy() takes in a layer and returns its dtype policy. PiperOrigin-RevId: 288031274 Change-Id: Id31848fd1e03fc324c8745305b5df3ee67b5f5a4
This commit is contained in:
parent
90093f0f82
commit
67330f736e
@ -64,6 +64,7 @@ keras_packages = [
|
|||||||
"tensorflow.python.keras.layers.wrappers",
|
"tensorflow.python.keras.layers.wrappers",
|
||||||
"tensorflow.python.keras.losses",
|
"tensorflow.python.keras.losses",
|
||||||
"tensorflow.python.keras.metrics",
|
"tensorflow.python.keras.metrics",
|
||||||
|
"tensorflow.python.keras.mixed_precision.experimental.get_layer_policy",
|
||||||
"tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer",
|
"tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer",
|
||||||
"tensorflow.python.keras.mixed_precision.experimental.policy",
|
"tensorflow.python.keras.mixed_precision.experimental.policy",
|
||||||
"tensorflow.python.keras.models",
|
"tensorflow.python.keras.models",
|
||||||
|
@ -31,6 +31,7 @@ py_library(
|
|||||||
srcs = ["__init__.py"],
|
srcs = ["__init__.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":get_layer_policy",
|
||||||
":loss_scale_optimizer",
|
":loss_scale_optimizer",
|
||||||
":policy",
|
":policy",
|
||||||
],
|
],
|
||||||
@ -88,6 +89,27 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "get_layer_policy",
|
||||||
|
srcs = ["get_layer_policy.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python/keras:base_layer",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "get_layer_policy_test",
|
||||||
|
srcs = ["get_layer_policy_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":get_layer_policy",
|
||||||
|
":policy",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python/keras:layers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "autocast_variable",
|
name = "autocast_variable",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -0,0 +1,41 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Contains the get_layer_policy function.
|
||||||
|
|
||||||
|
This is a separate file from policy.py to avoid a circular dependency.
|
||||||
|
get_layer_policy() relies on base_layer.py, itself which relies on policy.py.
|
||||||
|
"""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.keras.engine import base_layer
|
||||||
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
|
|
||||||
|
|
||||||
|
@keras_export('keras.mixed_precision.experimental.get_layer_policy')
|
||||||
|
def get_layer_policy(layer):
|
||||||
|
"""Returns the dtype policy of a layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer: A `tf.keras.layers.Layer`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The `tf.keras.mixed_precision.experimental.Policy` of the layer.
|
||||||
|
"""
|
||||||
|
if not isinstance(layer, base_layer.Layer):
|
||||||
|
raise ValueError('get_policy can only be called on a layer, but got: %s'
|
||||||
|
% (layer,))
|
||||||
|
return layer._dtype_policy # pylint: disable=protected-access
|
@ -0,0 +1,49 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests the get_layer_policy function."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.keras.engine import base_layer_utils
|
||||||
|
from tensorflow.python.keras.layers import core
|
||||||
|
from tensorflow.python.keras.mixed_precision.experimental import get_layer_policy
|
||||||
|
from tensorflow.python.keras.mixed_precision.experimental import policy
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class GetLayerPolicyTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_get_layer_policy(self):
|
||||||
|
layer = core.Dense(4)
|
||||||
|
self.assertEqual(get_layer_policy.get_layer_policy(layer).name, 'float32')
|
||||||
|
|
||||||
|
p = policy.Policy('mixed_float16')
|
||||||
|
layer = core.Dense(4, dtype=p)
|
||||||
|
self.assertIs(get_layer_policy.get_layer_policy(layer), p)
|
||||||
|
|
||||||
|
layer = core.Dense(4, dtype='float64')
|
||||||
|
self.assertEqual(get_layer_policy.get_layer_policy(layer).name, 'float64')
|
||||||
|
|
||||||
|
def test_error(self):
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, 'get_policy can only be called on a layer, but got: 1'):
|
||||||
|
get_layer_policy.get_layer_policy(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
base_layer_utils.enable_v2_dtype_behavior()
|
||||||
|
test.main()
|
@ -41,6 +41,7 @@ from tensorflow.python.keras.engine import base_layer
|
|||||||
from tensorflow.python.keras.engine import base_layer_utils
|
from tensorflow.python.keras.engine import base_layer_utils
|
||||||
from tensorflow.python.keras.engine import input_spec
|
from tensorflow.python.keras.engine import input_spec
|
||||||
from tensorflow.python.keras.layers import core
|
from tensorflow.python.keras.layers import core
|
||||||
|
from tensorflow.python.keras.mixed_precision.experimental import get_layer_policy
|
||||||
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
|
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
|
||||||
from tensorflow.python.keras.mixed_precision.experimental import policy
|
from tensorflow.python.keras.mixed_precision.experimental import policy
|
||||||
from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
|
from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
|
||||||
@ -121,12 +122,14 @@ class KerasLayerTest(keras_parameterized.TestCase):
|
|||||||
with strategy_fn().scope(), policy.policy_scope(policy_name):
|
with strategy_fn().scope(), policy.policy_scope(policy_name):
|
||||||
layer = mp_test_util.AddLayer(assert_type=dtype)
|
layer = mp_test_util.AddLayer(assert_type=dtype)
|
||||||
self.assertEqual(layer.dtype, dtypes.float32)
|
self.assertEqual(layer.dtype, dtypes.float32)
|
||||||
self.assertEqual(layer._dtype_policy._name, policy_name)
|
self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
|
||||||
|
policy_name)
|
||||||
y = layer(x)
|
y = layer(x)
|
||||||
self.assertEqual(layer.v.dtype, dtypes.float32)
|
self.assertEqual(layer.v.dtype, dtypes.float32)
|
||||||
self.assertEqual(y.dtype, dtype)
|
self.assertEqual(y.dtype, dtype)
|
||||||
self.assertEqual(layer.dtype, dtypes.float32)
|
self.assertEqual(layer.dtype, dtypes.float32)
|
||||||
self.assertEqual(layer._dtype_policy._name, policy_name)
|
self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
|
||||||
|
policy_name)
|
||||||
self.evaluate(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
self.assertEqual(self.evaluate(y), 2.)
|
self.assertEqual(self.evaluate(y), 2.)
|
||||||
|
|
||||||
@ -581,7 +584,8 @@ class KerasModelTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
# Ensure various dtype-related aspects of the layer are correct
|
# Ensure various dtype-related aspects of the layer are correct
|
||||||
self.assertEqual(layer.dtype, 'float32')
|
self.assertEqual(layer.dtype, 'float32')
|
||||||
self.assertEqual(layer._dtype_policy.name, 'mixed_float16')
|
self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
|
||||||
|
'mixed_float16')
|
||||||
self.assertEqual(layer.v.dtype, 'float32')
|
self.assertEqual(layer.v.dtype, 'float32')
|
||||||
self.assertEqual(layer(np.ones((2, 1))).dtype, 'float16')
|
self.assertEqual(layer(np.ones((2, 1))).dtype, 'float16')
|
||||||
|
|
||||||
|
@ -8,6 +8,10 @@ tf_module {
|
|||||||
name: "Policy"
|
name: "Policy"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_layer_policy"
|
||||||
|
argspec: "args=[\'layer\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "global_policy"
|
name: "global_policy"
|
||||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||||
|
@ -8,6 +8,10 @@ tf_module {
|
|||||||
name: "Policy"
|
name: "Policy"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_layer_policy"
|
||||||
|
argspec: "args=[\'layer\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "global_policy"
|
name: "global_policy"
|
||||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||||
|
Loading…
Reference in New Issue
Block a user