Initial support for mixed precision policies.
A tf.keras.mixed_precision.experimental.Policy determines the dtype of layer computations and variables. In this initial implementation, policies only determine layer variable dtypes; support for determining layer computation dtypes will come later.

Co-authored-by: James Qin <jamesqin@google.com>
PiperOrigin-RevId: 236010048
This commit is contained in:
parent
97e36fc65e
commit
f22e833a2b
Changed golden API files under tensorflow/python/tools/api/golden:

v1:
tensorflow.keras.mixed_precision.experimental.-policy.pbtxt
tensorflow.keras.mixed_precision.experimental.pbtxt
tensorflow.keras.mixed_precision.pbtxt
tensorflow.keras.pbtxt
tensorflow.layers.-average-pooling1-d.pbtxt
tensorflow.layers.-average-pooling2-d.pbtxt
tensorflow.layers.-average-pooling3-d.pbtxt
tensorflow.layers.-batch-normalization.pbtxt
tensorflow.layers.-conv1-d.pbtxt
tensorflow.layers.-conv2-d-transpose.pbtxt
tensorflow.layers.-conv2-d.pbtxt
tensorflow.layers.-conv3-d-transpose.pbtxt
tensorflow.layers.-conv3-d.pbtxt
tensorflow.layers.-dense.pbtxt
tensorflow.layers.-dropout.pbtxt
tensorflow.layers.-flatten.pbtxt
tensorflow.layers.-layer.pbtxt
tensorflow.layers.-max-pooling1-d.pbtxt
tensorflow.layers.-max-pooling2-d.pbtxt
tensorflow.layers.-max-pooling3-d.pbtxt
tensorflow.layers.-separable-conv1-d.pbtxt
tensorflow.layers.-separable-conv2-d.pbtxt
tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt
tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt
tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt

v2:
tensorflow.keras.mixed_precision.experimental.-policy.pbtxt
tensorflow.keras.mixed_precision.experimental.pbtxt
tensorflow.keras.mixed_precision.pbtxt
tensorflow.keras.pbtxt
tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt
tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt
tensorflow.nn.-dropout-wrapper.pbtxt
tensorflow.nn.-residual-wrapper.pbtxt
tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
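For orientation before the diffs: a minimal sketch of how the new API is meant to be exercised, pieced together from the Policy docstring and the tests added below (layer sizes are arbitrary, and this assumes a TensorFlow build containing this commit):

```python
import tensorflow as tf

# Set the global policy: layers infer their computation dtype from their
# inputs, but create their variables in float32.
tf.keras.mixed_precision.experimental.set_policy('infer_float32_vars')

layer = tf.keras.layers.Dense(10)
y = layer(tf.zeros((2, 4), dtype='float16'))  # float16 inputs and outputs
print(layer.kernel.dtype)  # float32 variable, auto-cast to float16 in call()

# Equivalently, a policy can be scoped to a single layer via `dtype`:
policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
dense = tf.keras.layers.Dense(10, dtype=policy)
```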
@@ -64,6 +64,7 @@ py_library(
        ":saving",
        "//tensorflow/python:training",
        "//tensorflow/python/keras/mixed_precision/experimental:autocast_variable",
        "//tensorflow/python/keras/mixed_precision/experimental:policy",
        "//tensorflow/python/keras/optimizer_v2",
        "//tensorflow/python/saved_model",
        "@keras_applications_archive//:keras_applications",
@@ -171,6 +172,8 @@ py_library(
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/distribute:input_lib",
        "//tensorflow/python/distribute:reduce_util",
        "//tensorflow/python/keras/mixed_precision/experimental:autocast_variable",
        "//tensorflow/python/keras/mixed_precision/experimental:policy",
        "//tensorflow/python/training/tracking:data_structures",
        "//tensorflow/tools/docs:doc_controls",
        "@six_archive//:six",
@@ -26,6 +26,7 @@ import numpy as np
from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.core.framework import node_def_pb2
from tensorflow.python.distribute import values as distribute_values
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
@@ -38,6 +39,8 @@ from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
# Modules that only depend on `keras.layers` import these from here.
@@ -172,7 +175,9 @@ class Layer(trackable.Trackable):
    # A dictionary that maps metric names to metric result tensors. The results
    # are the running averages of metric values over an epoch.
    self._metrics_tensors = {}
    self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name

    self._set_dtype_and_policy(dtype)

    self._call_fn_args = function_utils.fn_args(self.call)
    self._compute_previous_mask = ('mask' in self._call_fn_args or
                                   hasattr(self, 'compute_mask'))
@@ -308,10 +313,13 @@ class Layer(trackable.Trackable):
    shape = shape or ()
    # Validate optional keyword arguments.
    for kwarg in kwargs:
      if kwarg not in ['getter', 'collections']:
      if kwarg not in ['getter', 'collections', 'experimental_autocast']:
        raise TypeError('Unknown keyword argument:', kwarg)
    getter = kwargs.pop('getter', None)
    collections = kwargs.pop('collections', None)
    # 'experimental_autocast' can be set to False by the caller to indicate an
    # AutoCastVariable should never be created.
    autocast = kwargs.pop('experimental_autocast', True)

    if dtype is None:
      dtype = self.dtype or backend.floatx()
@@ -368,6 +376,12 @@ class Layer(trackable.Trackable):
        aggregation=aggregation)
    backend.track_variable(variable)

    if autocast and self._mixed_precision_policy.should_cast_variables:
      if isinstance(variable, distribute_values.DistributedVariable):
        variable = autocast_variable.AutoCastDistributedVariable(variable)
      else:
        variable = autocast_variable.AutoCastVariable(variable)

    if regularizer is not None:
      # TODO(fchollet): in the future, this should be handled at the
      # level of variable creation, and weight regularization losses
@@ -402,6 +416,7 @@ class Layer(trackable.Trackable):
      config['batch_input_shape'] = self._batch_input_shape
    if hasattr(self, 'dtype'):
      config['dtype'] = self.dtype
    # TODO(reedwm): Handle serializing self._mixed_precision_policy.
    return config

  @classmethod
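The new `experimental_autocast=False` kwarg lets a layer opt individual weights out of AutoCastVariable wrapping, which the normalization layers further down use for their scale/offset and moving statistics. A minimal sketch of a custom layer doing the same, modeled on `AddLayerWithoutAutoCast` in the new keras_test.py (the `ScaleLayer` itself is hypothetical):

```python
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.ops import math_ops


class ScaleLayer(base_layer.Layer):  # illustrative layer, not part of the commit
  """Multiplies its input by a scale kept in float32."""

  def build(self, input_shape):
    # Opt out of autocasting: even under an 'infer_float32_vars' policy,
    # reads of self.scale return float32, so we cast manually in call().
    self.scale = self.add_weight('scale', (), initializer='ones',
                                 dtype='float32',
                                 experimental_autocast=False)
    self.built = True

  def call(self, inputs):
    return inputs * math_ops.cast(self.scale, inputs.dtype)
```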
@@ -588,8 +603,11 @@ class Layer(trackable.Trackable):
              kwargs['training'] = backend.learning_phase()
          if not self.dynamic:
            try:
              with base_layer_utils.AutoAddUpdates(self,
                                                   inputs) as auto_updater:
              with base_layer_utils.autocast_context_manager(
                  input_list,
                  self._mixed_precision_policy.should_cast_variables), (
                      base_layer_utils.AutoAddUpdates(self,
                                                      inputs)) as auto_updater:
                outputs = self.call(inputs, *args, **kwargs)
                auto_updater.set_outputs(outputs)

@@ -636,7 +654,9 @@ class Layer(trackable.Trackable):
      # Eager execution on data tensors.
      with ops.name_scope(self._name_scope()):
        self._maybe_build(inputs)
        outputs = self.call(inputs, *args, **kwargs)
        with base_layer_utils.autocast_context_manager(
            input_list, self._mixed_precision_policy.should_cast_variables):
          outputs = self.call(inputs, *args, **kwargs)
        self._handle_activity_regularization(inputs, outputs)
        self._set_mask_metadata(inputs, outputs, previous_mask)
@@ -1328,6 +1348,24 @@ class Layer(trackable.Trackable):
  # Methods & attributes below are all private and only used by the framework. #
  ##############################################################################

  def _set_dtype_and_policy(self, dtype):
    """Sets self._dtype and self._mixed_precision_policy."""
    if dtype:
      if isinstance(dtype, policy.Policy):
        self._mixed_precision_policy = dtype
        self._dtype = self._mixed_precision_policy.default_variable_dtype
      else:
        # If a non-policy dtype is passed, no casting should be done. So we use
        # the "infer" policy, which does no casting.
        self._mixed_precision_policy = policy.Policy('infer')
        self._dtype = dtypes.as_dtype(dtype).name
    else:
      self._mixed_precision_policy = policy.global_policy()
      # If the global policy has not been set, it will be an "infer" policy
      # without a default variable dtype, and so self._dtype will be None. In
      # that case, self._dtype will be set when the layer is built or called.
      self._dtype = self._mixed_precision_policy.default_variable_dtype

  def _name_scope(self):
    return self.name
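`_set_dtype_and_policy` gives the layer `dtype` argument three distinct behaviors. A short illustration, mirroring the assertions in `test_passing_policy_to_layer` further down:

```python
from tensorflow.python.keras import layers
from tensorflow.python.keras.mixed_precision.experimental import policy

# 1. A plain dtype pins the layer to the 'infer' policy with that dtype;
#    no casting of variables is done.
layers.Dense(10, dtype='float64')

# 2. A Policy object sets both the policy and the variable dtype.
layer = layers.Dense(10, dtype=policy.Policy('infer_float32_vars'))
assert layer.dtype == 'float32'  # layer.dtype is the variable dtype

# 3. No dtype falls back to the global policy; under the default 'infer'
#    policy, layer.dtype stays None until the layer is built or called.
layer = layers.Dense(10)
assert layer.dtype is None
```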
@@ -514,3 +514,32 @@ class AutoAddUpdates(object):
      self.layer.add_update(list(unconditional_updates))
    if conditional_updates:
      self.layer.add_update(list(conditional_updates), inputs=self.inputs)


def _get_var_read_dtype(input_list, should_cast):
  """Gets the dtype that AutoCastVariables should be read in."""
  if should_cast and input_list and input_list[0].dtype.is_floating:
    return input_list[0].dtype.base_dtype
  else:
    return None


def autocast_context_manager(input_list, should_cast):
  """Returns a context manager to autocast AutoCastVariables.

  Under this context manager, if `should_cast` is True, AutoCastVariables will
  be cast. If `should_cast` is False, AutoCastVariables will not be cast,
  which can be used to disable autocasting if nested under another
  call to `autocast_context_manager`.

  Args:
    input_list: The inputs to the layer with the AutoCastVariables.
    should_cast: Whether AutoCastVariables should be cast.

  Returns:
    A context manager to automatically cast AutoCastVariables.
  """
  var_read_dtype = _get_var_read_dtype(input_list, should_cast)
  return ops.get_default_graph()._enable_auto_casting_variables(  # pylint: disable=protected-access
      var_read_dtype)
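The `_enable_auto_casting_variables` hook used here is internal, but the observable behavior it enables can be sketched with a simplified stand-in (illustrative only, not the real `autocast_variable.py` implementation):

```python
import tensorflow as tf


class FakeAutoCastVariable(object):  # illustrative stand-in only
  """Wraps a float32 variable; reads are cast to a scope-selected dtype."""

  read_dtype = None  # set by an enclosing autocast scope

  def __init__(self, variable):
    self._variable = variable  # stored in float32

  def value(self):
    value = self._variable.value()
    if FakeAutoCastVariable.read_dtype is not None:
      # Inside a layer called with float16 inputs, read_dtype is float16,
      # so ops consuming this variable see a float16 tensor.
      value = tf.cast(value, FakeAutoCastVariable.read_dtype)
    return value
```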
@@ -36,6 +36,7 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.keras.saving import hdf5_format
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
@@ -209,6 +210,12 @@ class Network(base_layer.Layer):
    self._trackable_saver = (
        trackable_utils.saver_with_op_caching(self))

    # Networks do not need to do any casting of inputs or variables, because
    # each of its layers will handle casting through the layer's own
    # implementation. Therefore networks use the 'infer' policy, which does no
    # casting.
    self._mixed_precision_policy = policy.Policy('infer')

  @trackable.no_automatic_dependency_tracking
  def _init_graph_network(self, inputs, outputs, name=None):
    self._call_convention = (base_layer_utils
@@ -1001,7 +1001,11 @@ class Dense(Layer):
        output_shape = shape[:-1] + [self.units]
        outputs.set_shape(output_shape)
    else:
      inputs = math_ops.cast(inputs, self.dtype)
      # Cast the inputs to self.dtype, which is the variable dtype. We do not
      # cast if `should_cast_variables` is True, as in that case the variable
      # will be automatically cast to inputs.dtype.
      if not self._mixed_precision_policy.should_cast_variables:
        inputs = math_ops.cast(inputs, self.dtype)
      outputs = gen_math_ops.mat_mul(inputs, self.kernel)
    if self.use_bias:
      outputs = nn.bias_add(outputs, self.bias)
@@ -25,6 +25,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -298,6 +299,14 @@ class CoreLayersTest(keras_parameterized.TestCase):
    outputs = layer(inputs)
    self.assertEqual(outputs.dtype, 'float32')

  def test_dense_with_policy(self):
    inputs = ops.convert_to_tensor(
        np.random.randint(low=0, high=7, size=(2, 2)), dtype='float16')
    layer = keras.layers.Dense(5, dtype=policy.Policy('infer_float32_vars'))
    outputs = layer(inputs)
    self.assertEqual(outputs.dtype, 'float16')
    self.assertEqual(layer.kernel.dtype, 'float32')

  def test_dense_regularization(self):
    layer = keras.layers.Dense(
        3,
@@ -334,7 +334,8 @@ class BatchNormalizationV2(Layer):
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True)
          trainable=True,
          experimental_autocast=False)
    else:
      self.gamma = None
      if self.fused:
@@ -349,7 +350,8 @@ class BatchNormalizationV2(Layer):
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True)
          trainable=True,
          experimental_autocast=False)
    else:
      self.beta = None
      if self.fused:
@@ -370,7 +372,8 @@ class BatchNormalizationV2(Layer):
          initializer=self.moving_mean_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN)
          aggregation=tf_variables.VariableAggregation.MEAN,
          experimental_autocast=False)

      self.moving_variance = self.add_weight(
          name='moving_variance',
@@ -379,7 +382,8 @@ class BatchNormalizationV2(Layer):
          initializer=self.moving_variance_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN)
          aggregation=tf_variables.VariableAggregation.MEAN,
          experimental_autocast=False)

      if self.renorm:
        # Create variables to maintain the moving mean and standard deviation.
@@ -390,6 +394,7 @@ class BatchNormalizationV2(Layer):
        # stack to be cleared. The nested ones use a `lambda` to set the desired
        # device and ignore any devices that may be set by the custom getter.
        def _renorm_variable(name, shape):
          """Create a renorm variable."""
          var = self.add_weight(
              name=name,
              shape=shape,
@@ -397,7 +402,8 @@ class BatchNormalizationV2(Layer):
              initializer=init_ops.zeros_initializer(),
              synchronization=tf_variables.VariableSynchronization.ON_READ,
              trainable=False,
              aggregation=tf_variables.VariableAggregation.MEAN)
              aggregation=tf_variables.VariableAggregation.MEAN,
              experimental_autocast=False)
          return var

        with distribution_strategy_context.get_strategy(
@@ -958,7 +964,8 @@ class LayerNormalization(Layer):
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True)
          trainable=True,
          experimental_autocast=False)
    else:
      self.gamma = None

@@ -969,7 +976,8 @@ class LayerNormalization(Layer):
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True)
          trainable=True,
          experimental_autocast=False)
    else:
      self.beta = None
@@ -26,6 +26,7 @@ from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers import normalization
from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
@@ -143,6 +144,19 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
    _run_batchnorm_correctness_test(
        normalization.BatchNormalization, dtype='float16', fused=False)

  @tf_test_util.run_in_graph_and_eager_modes
  def test_batchnorm_policy(self):
    norm = keras.layers.BatchNormalization(
        axis=-1,
        input_shape=(4, 4, 3),
        momentum=0.8,
        dtype=policy.Policy('infer_float32_vars'))
    x = np.random.normal(size=(10, 4, 4, 3)).astype('float16')
    y = norm(x)
    self.assertEqual(y.dtype, 'float16')
    self.assertEqual(norm.beta.dtype.base_dtype, 'float32')
    self.assertEqual(norm.gamma.dtype.base_dtype, 'float32')


class BatchNormalizationV1Test(test.TestCase):
@@ -24,6 +24,31 @@ exports_files(["LICENSE"])

load("//tensorflow:tensorflow.bzl", "py_test")

py_library(
    name = "policy",
    srcs = [
        "policy.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:framework",
    ],
)

py_test(
    name = "policy_test",
    size = "medium",
    srcs = [
        "policy_test.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":policy",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:platform_test",
    ],
)

py_library(
    name = "autocast_variable",
    srcs = [
@@ -52,3 +77,16 @@ py_test(
        "@absl_py//absl/testing:parameterized",
    ],
)

py_test(
    name = "keras_test",
    size = "medium",
    srcs = ["keras_test.py"],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/distribute:mirrored_strategy",
        "//tensorflow/python/distribute:one_device_strategy",
        "//tensorflow/python/keras",
        "@absl_py//absl/testing:parameterized",
    ],
)
@@ -0,0 +1,360 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests mixed precision works correctly with Keras layers and models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import one_device_strategy
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import nest


class AssertTypeLayer(base_layer.Layer):
  """A layer which asserts its inputs are a certain type."""

  def __init__(self, assert_type=None, **kwargs):
    self._assert_type = assert_type
    super(AssertTypeLayer, self).__init__(**kwargs)

  def assert_input_types(self, inputs):
    """Asserts `inputs` are of the correct type. Should be called in call()."""
    if self._assert_type:
      inputs_flattened = nest.flatten(inputs)
      for inp in inputs_flattened:
        assert inp.dtype.base_dtype == self._assert_type, (
            'Input tensor has type %s which does not match assert type %s' %
            (inp.dtype.name, self._assert_type.name))


class AddLayer(AssertTypeLayer):
  """A layer which adds its input to a scalar variable."""

  def __init__(self, regularizer=None, use_operator=False, **kwargs):
    """Initializes the AddLayer.

    Args:
      regularizer: The regularizer on the scalar variable.
      use_operator: If True, add using the + operator. If False, add using
        tf.add.
      **kwargs: Passed to AssertTypeLayer constructor.
    """
    self._regularizer = regularizer
    self._use_operator = use_operator
    super(AddLayer, self).__init__(**kwargs)

  def build(self, _):
    self.v = self.add_weight('v', (), initializer='ones',
                             regularizer=self._regularizer)
    self.built = True

  def call(self, inputs):
    self.assert_input_types(inputs)
    assert inputs.dtype == self.v.dtype
    return self._add(inputs, self.v)

  def _add(self, x, y):
    if self._use_operator:
      return x + y
    else:
      return math_ops.add(x, y)


class AddLayerWithoutAutoCast(AddLayer):
  """Same as AddLayer, but does not use AutoCastVariables."""

  def build(self, _):
    dtype = self.dtype
    if dtype in ('float16', 'bfloat16'):
      dtype = 'float32'
    self.v = self.add_weight('v', (), initializer='ones', dtype=dtype,
                             experimental_autocast=False,
                             regularizer=self._regularizer)
    self.built = True

  def call(self, inputs):
    self.assert_input_types(inputs)
    assert self.v.dtype in (dtypes.float32, dtypes.float64)
    return self._add(inputs, math_ops.cast(self.v, inputs.dtype))


class IdentityRegularizer(regularizers.Regularizer):

  def __call__(self, x):
    assert x.dtype == dtypes.float32
    return array_ops.identity(x)


def create_one_device_strategy():
  return one_device_strategy.OneDeviceStrategy('cpu:0')


def create_mirrored_strategy():
  if context.num_gpus() >= 1:
    return mirrored_strategy.MirroredStrategy(['cpu:0', 'gpu:0'])
  else:
    return mirrored_strategy.MirroredStrategy(['cpu:0'])


TESTCASES = ({
    'testcase_name': 'base',
    'strategy_fn': create_one_device_strategy
}, {
    'testcase_name': 'distribute',
    'strategy_fn': create_mirrored_strategy
})


class KerasLayerTest(test.TestCase, parameterized.TestCase):
  """Test mixed precision with Keras layers."""

  @parameterized.named_parameters(*TESTCASES)
  @test_util.run_in_graph_and_eager_modes
  def test_variables_in_float32(self, strategy_fn):
    x = constant_op.constant([1.], dtype=dtypes.float16)
    with strategy_fn().scope():
      with policy.policy_scope('infer_float32_vars'):
        layer = AddLayer(assert_type=dtypes.float16)
        y = layer(x)
        self.assertEqual(layer.v.dtype, dtypes.float32)
        self.assertEqual(y.dtype, dtypes.float16)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(y), 2.)

  @parameterized.named_parameters(*TESTCASES)
  @test_util.run_in_graph_and_eager_modes
  def test_layer_with_non_autocast_variable(self, strategy_fn):
    x = constant_op.constant([1.], dtype=dtypes.float16)
    with strategy_fn().scope():
      with policy.policy_scope('infer_float32_vars'):
        layer = AddLayerWithoutAutoCast(assert_type=dtypes.float16)
        y = layer(x)
        self.assertEqual(layer.v.dtype, dtypes.float32)
        self.assertEqual(y.dtype, dtypes.float16)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(y), 2.)

  @parameterized.named_parameters(*TESTCASES)
  @test_util.run_in_graph_and_eager_modes
  def test_layer_regularizer_runs_in_float32(self, strategy_fn):
    x = constant_op.constant([1.], dtype=dtypes.float16)
    with strategy_fn().scope():
      with policy.policy_scope('infer_float32_vars'):
        # Test on AddLayer
        layer = AddLayer(assert_type=dtypes.float16,
                         regularizer=IdentityRegularizer())
        layer(x)
        (regularizer_loss,) = layer.losses
        self.assertEqual(regularizer_loss.dtype, dtypes.float32)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(regularizer_loss), 1.)

        # Test on AddLayerWithoutAutoCast
        layer = AddLayerWithoutAutoCast(assert_type=dtypes.float16,
                                        regularizer=IdentityRegularizer())
        layer(x)
        (regularizer_loss,) = layer.losses
        self.assertEqual(regularizer_loss.dtype, dtypes.float32)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(regularizer_loss), 1.)

  @parameterized.named_parameters(*TESTCASES)
  @test_util.run_in_graph_and_eager_modes
  def test_passing_policy_to_layer(self, strategy_fn):
    x = constant_op.constant([1.], dtype=dtypes.float16)
    with strategy_fn().scope():
      # Passing a Policy to 'dtype' sets the policy for that layer.
      layer = AddLayer(assert_type=dtypes.float16,
                       dtype=policy.Policy('infer_float32_vars'))
      # layer.dtype refers to the variable dtype
      self.assertEqual(layer.dtype, dtypes.float32)
      layer(x)
      self.assertEqual(layer.v.dtype, dtypes.float32)
      with policy.policy_scope('infer_float32_vars'):
        # Passing a Policy to dtype overrides the global Policy
        layer = AddLayer(assert_type=dtypes.float16,
                         dtype=policy.Policy('infer'))
        # layer dtype is not yet known
        self.assertEqual(layer.dtype, None)
        layer(x)
        self.assertEqual(layer.v.dtype, dtypes.float16)
        self.assertEqual(layer.dtype, dtypes.float16)

  @parameterized.named_parameters(*TESTCASES)
  @test_util.run_in_graph_and_eager_modes
  def test_gradient(self, strategy_fn):
    x = constant_op.constant([1.], dtype=dtypes.float16)
    with strategy_fn().scope() as strategy:
      with policy.policy_scope('infer_float32_vars'):
        layer = AddLayer(assert_type=dtypes.float16)
        def run_fn():
          with backprop.GradientTape() as tape:
            y = layer(x)
            # Divide by num_replicas_in_sync, as the effective total loss is
            # the sum of each replica's loss.
            y /= strategy.num_replicas_in_sync

          # Learning rate is small enough that if applied to a float16
          # variable, the variable will not change. So this tests that the
          # learning rate is not applied to a float16 value, but instead to
          # the float32 variable.
          opt = gradient_descent.SGD(2 ** -14)
          grad = tape.gradient(y, layer.v)
          return opt.apply_gradients([(grad, layer.v)])

        op = strategy.experimental_run(run_fn)
        if not context.executing_eagerly():
          self.evaluate(variables.global_variables_initializer())
          self.evaluate(op)
        # The gradient with respect to the variable is 1. Since the variable
        # is initialized with 1 and the learning rate is 2**-14, the new
        # variable value should be: init_val - gradient * learning_rate,
        # which is 1 - 1 * 2**-14
        self.assertEqual(self.evaluate(layer.v), 1 - 2 ** -14)


class KerasModelTest(test.TestCase, parameterized.TestCase):
  """Test mixed precision with Keras models."""

  @parameterized.named_parameters({
      'testcase_name': 'base',
      'strategy_fn': create_one_device_strategy,
  }, {
      'testcase_name': 'distribute',
      'strategy_fn': create_mirrored_strategy,
  }, {
      'testcase_name': 'operator',
      'strategy_fn': create_mirrored_strategy,
      'use_operator': True
  }, {
      'testcase_name': 'regularizer',
      'strategy_fn': create_mirrored_strategy,
      'use_regularizer': True
  })
  @test_util.run_in_graph_and_eager_modes
  def test_model(self, strategy_fn, use_operator=False, use_regularizer=False):
    regularizer = IdentityRegularizer() if use_regularizer else None
    with strategy_fn().scope():
      with policy.policy_scope('infer_float32_vars'):
        x = layers.Input(shape=(), batch_size=2, dtype=dtypes.float16)
        layer = AddLayer(assert_type=dtypes.float16, use_operator=use_operator,
                         regularizer=regularizer)
        y = layer(x)
        y = math_ops.cast(y, dtypes.float32)
        model = models.Model(inputs=x, outputs=y)

        def loss_fn(y_true, y_pred):
          del y_true
          return math_ops.reduce_mean(y_pred)

        # Learning rate is small enough that if applied to a float16 variable,
        # the variable will not change. So this tests that the learning rate
        # is not applied to a float16 value, but instead to the float32
        # variable.
        opt = gradient_descent.SGD(2 ** -14)
        model.compile(opt, loss=loss_fn)

        self.assertEqual(backend.eval(layer.v), 1)
        x = np.ones((2, 1))
        y = np.ones((2, 1))
        dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2)
        model.fit(dataset)
        # Variable starts at 1, and should have gradient of 2 ** -14
        # subtracted from it.
        expected = 1 - 2 ** -14
        if use_regularizer:
          # Regularizer adds another 2 ** -14 to the gradient.
          expected -= 2 ** -14
        self.assertEqual(backend.eval(layer.v), expected)

  @parameterized.named_parameters({
      'testcase_name': 'base',
      'strategy_fn': create_one_device_strategy,
  }, {
      'testcase_name': 'distribute',
      'strategy_fn': create_mirrored_strategy,
  })
  @test_util.run_in_graph_and_eager_modes
  def test_advanced_model(self, strategy_fn):

    # The advanced model tests mixed-precision-related features that would
    # occur in a resnet50 model. It tests a model that has:
    #  * Multiple layers, some of which use auto-cast variables and some of
    #    which do not
    #  * Regularization on some variables and not others.

    strategy = strategy_fn()

    learning_rate = 2 ** -14

    with strategy.scope():
      with policy.policy_scope(policy.Policy('infer_float32_vars')):
        x = layers.Input(shape=(), batch_size=2, dtype=dtypes.float16)
        layer1 = AddLayer(assert_type=dtypes.float16,
                          regularizer=IdentityRegularizer(), use_operator=True)
        layer2 = AddLayerWithoutAutoCast(assert_type=dtypes.float16,
                                         use_operator=True)
        layer3 = AddLayer(assert_type=dtypes.float16, use_operator=False)
        layer4 = AddLayerWithoutAutoCast(assert_type=dtypes.float16,
                                         regularizer=IdentityRegularizer(),
                                         use_operator=False)
        y = layer1(x)
        y = layer2(y)
        y = layer3(y)
        y = layer4(y)
        y = math_ops.cast(y, dtypes.float32)
        model = models.Model(inputs=x, outputs=y)

        def loss_fn(y_true, y_pred):
          self.assertEqual(y_true.dtype, dtypes.float32)
          self.assertEqual(y_pred.dtype, dtypes.float32)
          return math_ops.reduce_mean(y_pred)

        opt = gradient_descent.SGD(learning_rate)
        model.compile(opt, loss=loss_fn)

        x = np.ones((2, 1))
        y = np.ones((2, 1))
        dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2)
        model.fit(dataset)
        for layer in (layer1, layer2, layer3, layer4):
          if layer.losses:
            # Layer has weight regularizer
            self.assertEqual(backend.eval(layer.v), 1 - 2 * learning_rate)
          else:
            # Layer does not have weight regularizer
            self.assertEqual(backend.eval(layer.v), 1 - learning_rate)


if __name__ == '__main__':
  test.main()
tensorflow/python/keras/mixed_precision/experimental/policy.py (new file, 160 lines)
@@ -0,0 +1,160 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the Policy class for mixed precision training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.mixed_precision.experimental.Policy')
class Policy(object):
  """A mixed precision policy for a Keras layer.

  A mixed precision policy determines the floating-point dtype that Keras
  layers should create variables in. For non-default policies, if the variable
  dtype does not match the input dtype, variables will automatically be cast
  to the input dtype to avoid type errors. Policies can be passed to the
  'dtype' argument of layer constructors, or a global policy can be set with
  'set_policy'.

  In the near future, policies will also determine the computation dtype of
  layers, as well as the loss scaling algorithm.

  Policies are intended to enable mixed precision training, which requires
  using float32 variables and [b]float16 computations for most layers. The
  term "mixed precision" refers to the use of both float16 (or bfloat16) and
  float32 in a model. See https://arxiv.org/abs/1710.03740 for more
  information on mixed precision training.

  Policies are constructed by passing a string to the `name` constructor
  argument. `name` determines the behavior of the policy. Currently, `name`
  can be one of the following values.

    * 'infer': Infer the variable and computation dtypes from the input dtype.
      This is the default behavior.
    * 'infer_float32_vars': Infer the computation dtypes from the input
      dtype, but create variables in float32. Variables will be cast to the
      computation dtype. This is intended to enable mixed precision. Users can
      cast tensors to float16 before passing them to a layer, which causes the
      layer to run its computation in float16 while keeping variables in
      float32.

  To use mixed precision in a model, the 'infer_float32_vars' policy can be
  used alongside float16 input tensors, which results in float16 computations
  and float32 variables. For example:

  ```python
  tf.keras.mixed_precision.experimental.set_policy('infer_float32_vars')
  model = tf.keras.models.Sequential([
      tf.keras.layers.Input((100,), dtype='float16'),
      tf.keras.layers.Dense(10),
      tf.keras.layers.Dense(10),
      tf.keras.layers.Lambda(lambda x: tf.cast(x, 'float32')),
      tf.keras.layers.Activation('softmax')
  ])
  ```

  Alternatively, the policy can be passed to individual layers instead of
  setting the global policy with `set_policy`:

  ```python
  policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
  model = tf.keras.models.Sequential([
      tf.keras.layers.Input((100,), dtype='float16'),
      tf.keras.layers.Dense(10, dtype=policy),
      tf.keras.layers.Dense(10, dtype=policy),
      tf.keras.layers.Lambda(lambda x: tf.cast(x, 'float32')),
      tf.keras.layers.Activation('softmax')
  ])
  ```
  """

  def __init__(self, name):
    self._name = name
    if name == 'infer':
      self._default_variable_dtype = None
    elif name == 'infer_float32_vars':
      self._default_variable_dtype = 'float32'
    else:
      raise ValueError('"name" argument to Policy constructor must be "infer" '
                       'or "infer_float32_vars", but got: %s' % name)

  @property
  def name(self):
    """Returns the name of the policy: "infer" or "infer_float32_vars"."""
    return self._name

  @property
  def default_variable_dtype(self):
    """Returns the default variable dtype of this policy.

    This is the dtype layers will create their variables in, unless a layer
    explicitly chooses a different dtype. Layers will cast variables to the
    appropriate dtype to avoid type errors.

    Returns:
      The default variable dtype of this policy, or None if the default
      variable dtype should be derived from the inputs.
    """
    return self._default_variable_dtype

  @property
  def should_cast_variables(self):
    """Returns True if variables should be cast."""
    return self.default_variable_dtype is not None

  # TODO(reedwm): Implement get_config/from_config.


# TODO(reedwm): Make this thread local?
_global_policy = Policy('infer')


@keras_export('keras.mixed_precision.experimental.global_policy')
def global_policy():
  """Returns the global Policy.

  The global policy is the default policy used for layers, if no policy is
  passed to the layer constructor. When TensorFlow starts, the global policy
  is set to an "infer" policy, and can be changed with `set_policy`.

  Returns:
    The global Policy.
  """
  return _global_policy


@keras_export('keras.mixed_precision.experimental.set_policy')
def set_policy(policy):
  """Sets the global Policy."""
  global _global_policy
  if not isinstance(policy, Policy):
    policy = Policy(policy)
  _global_policy = policy


# TODO(reedwm): Make this thread local
@contextlib.contextmanager
def policy_scope(policy):
  old_policy = _global_policy
  try:
    set_policy(policy)
    yield
  finally:
    set_policy(old_policy)
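`policy_scope` has no docstring yet; its intended behavior, as exercised by `test_policy_scope` below, is to swap the global policy temporarily and restore the previous one on exit:

```python
from tensorflow.python.keras.mixed_precision.experimental import policy

assert policy.global_policy().name == 'infer'
with policy.policy_scope('infer_float32_vars'):
  # Layers constructed here default to float32 variables with inferred
  # computation dtypes.
  assert policy.global_policy().name == 'infer_float32_vars'
# The previous global policy is restored, even if the block raises.
assert policy.global_policy().name == 'infer'
```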
@@ -0,0 +1,69 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests Policies."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.mixed_precision.experimental import policy as mp_policy
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class PolicyTest(test.TestCase):
  """Tests Policies."""

  def test_infer(self):
    policy = mp_policy.Policy('infer')
    self.assertEqual(policy.name, 'infer')
    self.assertEqual(policy.default_variable_dtype, None)

  def test_infer_float32_vars(self):
    policy = mp_policy.Policy('infer_float32_vars')
    self.assertEqual(policy.name, 'infer_float32_vars')
    self.assertEqual(policy.default_variable_dtype, 'float32')

  def test_global_policy(self):
    self.assertEqual(mp_policy.global_policy().name, 'infer')
    default_policy = mp_policy.global_policy()
    try:
      mp_policy.set_policy('infer_float32_vars')
      self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars')
      self.assertEqual(mp_policy.global_policy().default_variable_dtype,
                       'float32')
      with ops.Graph().as_default():  # Policies are not associated with a graph
        self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars')
      mp_policy.set_policy('infer')
      self.assertEqual(mp_policy.global_policy().name, 'infer')
      self.assertEqual(mp_policy.global_policy().default_variable_dtype, None)
      policy = mp_policy.Policy('infer_float32_vars')
      mp_policy.set_policy(policy)
      self.assertIs(mp_policy.global_policy(), policy)
    finally:
      mp_policy.set_policy(default_policy)

  def test_policy_scope(self):
    with mp_policy.policy_scope('infer_float32_vars'):
      self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars')
      with mp_policy.policy_scope('infer'):
        self.assertEqual(mp_policy.global_policy().name, 'infer')
      self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars')
    self.assertEqual(mp_policy.global_policy().name, 'infer')


if __name__ == '__main__':
  test.main()
@@ -307,7 +307,8 @@ class Layer(base_layer.Layer):
                 use_resource=None,
                 synchronization=vs.VariableSynchronization.AUTO,
                 aggregation=vs.VariableAggregation.NONE,
                 partitioner=None):
                 partitioner=None,
                 **kwargs):
    """Adds a new variable to the layer, or gets an existing one; returns it.

    Arguments:
@@ -342,6 +343,7 @@ class Layer(base_layer.Layer):
        `tf.variable_axis_size_partitioner`. For more details, see the
        documentation of `tf.get_variable` and the "Variable Partitioners
        and Sharding" section of the API guide.
      **kwargs: Additional keyword arguments.

    Returns:
      The created variable. Usually either a `Variable` or `ResourceVariable`
@@ -354,6 +356,9 @@ class Layer(base_layer.Layer):
      ValueError: When trainable has been set to True with synchronization
        set as `ON_READ`.
    """
    for kwarg in kwargs:
      if kwarg != 'experimental_autocast':
        raise TypeError('Unknown keyword argument:', kwarg)
    if self._keras_style:
      return super(Layer, self).add_weight(
          name=name,
@@ -366,7 +371,8 @@ class Layer(base_layer.Layer):
          use_resource=use_resource,
          synchronization=vs.VariableSynchronization.AUTO,
          aggregation=vs.VariableAggregation.NONE,
          partitioner=partitioner)
          partitioner=partitioner,
          **kwargs)

    if synchronization == vs.VariableSynchronization.ON_READ:
      if trainable:
@@ -433,7 +439,8 @@ class Layer(base_layer.Layer):
          use_resource=use_resource,
          synchronization=synchronization,
          aggregation=aggregation,
          getter=vs.get_variable)
          getter=vs.get_variable,
          **kwargs)

    if regularizer:
      if (ops.executing_eagerly_outside_functions()
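This hunk forwards the new keyword through the v1 `tf.layers` base class (note the `keywords=kwargs` argspec updates in the golden files further down), so the same opt-out works for `tf.layers`-style layers. A hedged sketch with a hypothetical layer:

```python
import tensorflow as tf
from tensorflow.python.layers import base


class MyV1Layer(base.Layer):  # hypothetical tf.layers-style layer
  def build(self, input_shape):
    # Forwarded via **kwargs to the Keras base Layer.add_weight, where it
    # disables AutoCastVariable wrapping for this weight.
    self.bias = self.add_weight('bias', (int(input_shape[-1]),),
                                initializer=tf.zeros_initializer(),
                                experimental_autocast=False)
    self.built = True

  def call(self, inputs):
    return inputs + tf.cast(self.bias, inputs.dtype)
```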
@@ -90,6 +90,8 @@ KERAS_API_INIT_FILES = [
    "keras/layers/experimental/__init__.py",
    "keras/losses/__init__.py",
    "keras/metrics/__init__.py",
    "keras/mixed_precision/__init__.py",
    "keras/mixed_precision/experimental/__init__.py",
    "keras/models/__init__.py",
    "keras/optimizers/__init__.py",
    "keras/optimizers/schedules/__init__.py",
@@ -114,6 +114,8 @@ KERAS_API_INIT_FILES_V1 = [
    "keras/layers/experimental/__init__.py",
    "keras/losses/__init__.py",
    "keras/metrics/__init__.py",
    "keras/mixed_precision/__init__.py",
    "keras/mixed_precision/experimental/__init__.py",
    "keras/models/__init__.py",
    "keras/optimizers/__init__.py",
    "keras/optimizers/schedules/__init__.py",
@@ -0,0 +1,21 @@
path: "tensorflow.keras.mixed_precision.experimental.Policy"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.policy.Policy\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "default_variable_dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "name"
    mtype: "<type \'property\'>"
  }
  member {
    name: "should_cast_variables"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
  }
}
@@ -0,0 +1,15 @@
path: "tensorflow.keras.mixed_precision.experimental"
tf_module {
  member {
    name: "Policy"
    mtype: "<type \'type\'>"
  }
  member_method {
    name: "global_policy"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "set_policy"
    argspec: "args=[\'policy\'], varargs=None, keywords=None, defaults=None"
  }
}
@@ -0,0 +1,7 @@
path: "tensorflow.keras.mixed_precision"
tf_module {
  member {
    name: "experimental"
    mtype: "<type \'module\'>"
  }
}
@@ -56,6 +56,10 @@ tf_module {
    name: "metrics"
    mtype: "<type \'module\'>"
  }
  member {
    name: "mixed_precision"
    mtype: "<type \'module\'>"
  }
  member {
    name: "models"
    mtype: "<type \'module\'>"
@ -117,7 +117,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "add_weight"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
|
@ -117,7 +117,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "add_weight"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
|
@ -117,7 +117,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "add_weight"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
|
@ -117,7 +117,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "add_weight"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
|
@ -117,7 +117,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "add_weight"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
|
@ -118,7 +118,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "add_weight"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
|
@ -117,7 +117,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "add_weight"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
|
@ -118,7 +118,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "add_weight"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
|
@@ -117,7 +117,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -116,7 +116,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -116,7 +116,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -116,7 +116,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -114,7 +114,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -117,7 +117,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -117,7 +117,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -117,7 +117,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -118,7 +118,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -118,7 +118,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -124,7 +124,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -129,7 +129,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -124,7 +124,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -123,7 +123,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
@@ -0,0 +1,21 @@
+path: "tensorflow.keras.mixed_precision.experimental.Policy"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.policy.Policy\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "default_variable_dtype"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "should_cast_variables"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
+  }
+}
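The new Policy golden above fixes the initial public surface: a name-only constructor plus three read-only properties. A short usage sketch, assuming 'infer_float32_vars' is among the names the constructor accepts (the valid set is defined by the implementation, not by this golden file):

    import tensorflow as tf

    # Per the commit description, a policy in this first cut only
    # determines layer variable dtypes; computation dtypes come later.
    policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')

    print(policy.name)                    # the string the policy was built with
    print(policy.default_variable_dtype)  # dtype used for layer variables (may be None)
    print(policy.should_cast_variables)   # whether layer variables get casting wrappers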
@@ -0,0 +1,15 @@
+path: "tensorflow.keras.mixed_precision.experimental"
+tf_module {
+  member {
+    name: "Policy"
+    mtype: "<type \'type\'>"
+  }
+  member_method {
+    name: "global_policy"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_policy"
+    argspec: "args=[\'policy\'], varargs=None, keywords=None, defaults=None"
+  }
+}
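global_policy and set_policy suggest a process-wide default that layers can consult when constructed without an explicit policy. A sketch of the expected round trip, with the same assumption about the policy name:

    import tensorflow as tf

    mp = tf.keras.mixed_precision.experimental

    # Install a global default policy, then read it back.
    mp.set_policy(mp.Policy('infer_float32_vars'))
    print(mp.global_policy().name)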
@@ -0,0 +1,7 @@
+path: "tensorflow.keras.mixed_precision"
+tf_module {
+  member {
+    name: "experimental"
+    mtype: "<type \'module\'>"
+  }
+}
@@ -56,6 +56,10 @@ tf_module {
     name: "metrics"
     mtype: "<type \'module\'>"
   }
+  member {
+    name: "mixed_precision"
+    mtype: "<type \'module\'>"
+  }
   member {
     name: "models"
     mtype: "<type \'module\'>"
@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -132,7 +132,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -128,7 +128,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"

@@ -124,7 +124,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"