diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index a2f023b8101..9fd6a4e0fec 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -64,6 +64,7 @@ py_library( ":saving", "//tensorflow/python:training", "//tensorflow/python/keras/mixed_precision/experimental:autocast_variable", + "//tensorflow/python/keras/mixed_precision/experimental:policy", "//tensorflow/python/keras/optimizer_v2", "//tensorflow/python/saved_model", "@keras_applications_archive//:keras_applications", @@ -171,6 +172,8 @@ py_library( "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/keras/mixed_precision/experimental:autocast_variable", + "//tensorflow/python/keras/mixed_precision/experimental:policy", "//tensorflow/python/training/tracking:data_structures", "//tensorflow/tools/docs:doc_controls", "@six_archive//:six", diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index d3218208fdb..404d3023344 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -26,6 +26,7 @@ import numpy as np from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.distribute import values as distribute_values from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import dtypes @@ -38,6 +39,8 @@ from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import input_spec +from tensorflow.python.keras.mixed_precision.experimental import autocast_variable +from tensorflow.python.keras.mixed_precision.experimental import policy from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_utils # A module that only depends on `keras.layers` import these from here. @@ -172,7 +175,9 @@ class Layer(trackable.Trackable): # A dictionary that maps metric names to metric result tensors. The results # are the running averages of metric values over an epoch. self._metrics_tensors = {} - self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name + + self._set_dtype_and_policy(dtype) + self._call_fn_args = function_utils.fn_args(self.call) self._compute_previous_mask = ('mask' in self._call_fn_args or hasattr(self, 'compute_mask')) @@ -308,10 +313,13 @@ class Layer(trackable.Trackable): shape = shape or () # Validate optional keyword arguments. for kwarg in kwargs: - if kwarg not in ['getter', 'collections']: + if kwarg not in ['getter', 'collections', 'experimental_autocast']: raise TypeError('Unknown keyword argument:', kwarg) getter = kwargs.pop('getter', None) collections = kwargs.pop('collections', None) + # 'experimental_autocast' can be set to False by the caller to indicate an + # AutoCastVariable should never be created. 
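+    # For example, the BatchNormalization changes later in this patch pass + # experimental_autocast=False for gamma, beta, and the moving statistics, so + # those weights are never wrapped in an AutoCastVariable.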
+ autocast = kwargs.pop('experimental_autocast', True) if dtype is None: dtype = self.dtype or backend.floatx() @@ -368,6 +376,12 @@ class Layer(trackable.Trackable): aggregation=aggregation) backend.track_variable(variable) + if autocast and self._mixed_precision_policy.should_cast_variables: + if isinstance(variable, distribute_values.DistributedVariable): + variable = autocast_variable.AutoCastDistributedVariable(variable) + else: + variable = autocast_variable.AutoCastVariable(variable) + if regularizer is not None: # TODO(fchollet): in the future, this should be handled at the # level of variable creation, and weight regularization losses @@ -402,6 +416,7 @@ class Layer(trackable.Trackable): config['batch_input_shape'] = self._batch_input_shape if hasattr(self, 'dtype'): config['dtype'] = self.dtype + # TODO(reedwm): Handle serializing self._mixed_precision_policy. return config @classmethod @@ -588,8 +603,11 @@ class Layer(trackable.Trackable): kwargs['training'] = backend.learning_phase() if not self.dynamic: try: - with base_layer_utils.AutoAddUpdates(self, - inputs) as auto_updater: + with base_layer_utils.autocast_context_manager( + input_list, + self._mixed_precision_policy.should_cast_variables), ( + base_layer_utils.AutoAddUpdates(self, + inputs)) as auto_updater: outputs = self.call(inputs, *args, **kwargs) auto_updater.set_outputs(outputs) @@ -636,7 +654,9 @@ class Layer(trackable.Trackable): # Eager execution on data tensors. with ops.name_scope(self._name_scope()): self._maybe_build(inputs) - outputs = self.call(inputs, *args, **kwargs) + with base_layer_utils.autocast_context_manager( + input_list, self._mixed_precision_policy.should_cast_variables): + outputs = self.call(inputs, *args, **kwargs) self._handle_activity_regularization(inputs, outputs) self._set_mask_metadata(inputs, outputs, previous_mask) @@ -1328,6 +1348,24 @@ class Layer(trackable.Trackable): # Methods & attributes below are all private and only used by the framework. # ############################################################################## + def _set_dtype_and_policy(self, dtype): + """Sets self._dtype and self._mixed_precision_policy.""" + if dtype: + if isinstance(dtype, policy.Policy): + self._mixed_precision_policy = dtype + self._dtype = self._mixed_precision_policy.default_variable_dtype + else: + # If a non-policy dtype is passed, no casting should be done. So we use + # the "infer" policy, which does no casting. + self._mixed_precision_policy = policy.Policy('infer') + self._dtype = dtypes.as_dtype(dtype).name + else: + self._mixed_precision_policy = policy.global_policy() + # If the global policy has not been set, it will be an "infer" policy + # without a default variable dtype, and so self._dtype will be None. In + # that case, self._dtype will be set when the layer is built or called. 
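+      # For example, with the global policy set to 'infer_float32_vars', a + # layer constructed without an explicit dtype gets self._dtype = 'float32' + # from the line below; under the default 'infer' policy it gets None.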
+ self._dtype = self._mixed_precision_policy.default_variable_dtype + def _name_scope(self): return self.name diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index 34b007ea953..95f709fd425 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -514,3 +514,32 @@ class AutoAddUpdates(object): self.layer.add_update(list(unconditional_updates)) if conditional_updates: self.layer.add_update(list(conditional_updates), inputs=self.inputs) + + +def _get_var_read_dtype(input_list, should_cast): + """Gets the dtype that AutoCastVariables should be read in.""" + if should_cast and input_list and input_list[0].dtype.is_floating: + return input_list[0].dtype.base_dtype + else: + return None + + +def autocast_context_manager(input_list, should_cast): + """Returns a context manager to autocast AutoCastVariables. + + Under this context manager, if `should_cast` is True, AutoCastVariables will + be cast. If `should_cast` is False, AutoCastVariables will not be cast, + which can be used to disable autocasting if nested under another + call to `autocast_context_manager`. + + Args: + input_list: The inputs to the layer with the AutoCastVariables. + should_cast: Whether AutoCastVariables should be cast. + + Returns: + A context manager to automatically cast AutoCastVariables. + """ + var_read_dtype = _get_var_read_dtype(input_list, should_cast) + return ops.get_default_graph()._enable_auto_casting_variables( # pylint: disable=protected-access + var_read_dtype) + diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 9178feab781..0bc10b0fb99 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -36,6 +36,7 @@ from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import training_utils +from tensorflow.python.keras.mixed_precision.experimental import policy from tensorflow.python.keras.saving import hdf5_format from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import layer_utils @@ -209,6 +210,12 @@ class Network(base_layer.Layer): self._trackable_saver = ( trackable_utils.saver_with_op_caching(self)) + # Networks do not need to do any casting of inputs or variables, because + # each of their layers will handle casting through the layer's own + # implementation. Therefore networks use the 'infer' policy, which does no + # casting. + self._mixed_precision_policy = policy.Policy('infer') + @trackable.no_automatic_dependency_tracking def _init_graph_network(self, inputs, outputs, name=None): self._call_convention = (base_layer_utils diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index 7c58ed7becc..5abac83cfc7 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -1001,7 +1001,11 @@ class Dense(Layer): output_shape = shape[:-1] + [self.units] outputs.set_shape(output_shape) else: - inputs = math_ops.cast(inputs, self.dtype) + # Cast the inputs to self.dtype, which is the variable dtype. We do not + # cast if `should_cast_variables` is True, as in that case the variable + # will be automatically cast to inputs.dtype.
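+      # For example, under the 'infer_float32_vars' policy with float16 + # inputs, self.kernel is a float32 AutoCastVariable that reads as float16 + # inside call(), so no explicit cast is needed here.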
+ if not self._mixed_precision_policy.should_cast_variables: + inputs = math_ops.cast(inputs, self.dtype) outputs = gen_math_ops.mat_mul(inputs, self.kernel) if self.use_bias: outputs = nn.bias_add(outputs, self.bias) diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py index 6ba59b2ff33..ad7b9978d13 100644 --- a/tensorflow/python/keras/layers/core_test.py +++ b/tensorflow/python/keras/layers/core_test.py @@ -25,6 +25,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.mixed_precision.experimental import policy from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -298,6 +299,14 @@ class CoreLayersTest(keras_parameterized.TestCase): outputs = layer(inputs) self.assertEqual(outputs.dtype, 'float32') + def test_dense_with_policy(self): + inputs = ops.convert_to_tensor( + np.random.randint(low=0, high=7, size=(2, 2)), dtype='float16') + layer = keras.layers.Dense(5, dtype=policy.Policy('infer_float32_vars')) + outputs = layer(inputs) + self.assertEqual(outputs.dtype, 'float16') + self.assertEqual(layer.kernel.dtype, 'float32') + def test_dense_regularization(self): layer = keras.layers.Dense( 3, diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index fdda7c92002..dab27d2d254 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -334,7 +334,8 @@ class BatchNormalizationV2(Layer): initializer=self.gamma_initializer, regularizer=self.gamma_regularizer, constraint=self.gamma_constraint, - trainable=True) + trainable=True, + experimental_autocast=False) else: self.gamma = None if self.fused: @@ -349,7 +350,8 @@ class BatchNormalizationV2(Layer): initializer=self.beta_initializer, regularizer=self.beta_regularizer, constraint=self.beta_constraint, - trainable=True) + trainable=True, + experimental_autocast=False) else: self.beta = None if self.fused: @@ -370,7 +372,8 @@ class BatchNormalizationV2(Layer): initializer=self.moving_mean_initializer, synchronization=tf_variables.VariableSynchronization.ON_READ, trainable=False, - aggregation=tf_variables.VariableAggregation.MEAN) + aggregation=tf_variables.VariableAggregation.MEAN, + experimental_autocast=False) self.moving_variance = self.add_weight( name='moving_variance', @@ -379,7 +382,8 @@ class BatchNormalizationV2(Layer): initializer=self.moving_variance_initializer, synchronization=tf_variables.VariableSynchronization.ON_READ, trainable=False, - aggregation=tf_variables.VariableAggregation.MEAN) + aggregation=tf_variables.VariableAggregation.MEAN, + experimental_autocast=False) if self.renorm: # Create variables to maintain the moving mean and standard deviation. @@ -390,6 +394,7 @@ class BatchNormalizationV2(Layer): # stack to be cleared. The nested ones use a `lambda` to set the desired # device and ignore any devices that may be set by the custom getter. 
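+      # Like the moving mean and variance above, the renorm statistics are + # created with experimental_autocast=False so they always remain in the + # variable dtype.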
def _renorm_variable(name, shape): + """Create a renorm variable.""" var = self.add_weight( name=name, shape=shape, @@ -397,7 +402,8 @@ class BatchNormalizationV2(Layer): initializer=init_ops.zeros_initializer(), synchronization=tf_variables.VariableSynchronization.ON_READ, trainable=False, - aggregation=tf_variables.VariableAggregation.MEAN) + aggregation=tf_variables.VariableAggregation.MEAN, + experimental_autocast=False) return var with distribution_strategy_context.get_strategy( @@ -958,7 +964,8 @@ class LayerNormalization(Layer): initializer=self.gamma_initializer, regularizer=self.gamma_regularizer, constraint=self.gamma_constraint, - trainable=True) + trainable=True, + experimental_autocast=False) else: self.gamma = None @@ -969,7 +976,8 @@ class LayerNormalization(Layer): initializer=self.beta_initializer, regularizer=self.beta_regularizer, constraint=self.beta_constraint, - trainable=True) + trainable=True, + experimental_autocast=False) else: self.beta = None diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py index 3815d1e673d..00e8d831b69 100644 --- a/tensorflow/python/keras/layers/normalization_test.py +++ b/tensorflow/python/keras/layers/normalization_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.keras.layers import normalization +from tensorflow.python.keras.mixed_precision.experimental import policy from tensorflow.python.platform import test from tensorflow.python.training import gradient_descent @@ -143,6 +144,19 @@ class BatchNormalizationTest(keras_parameterized.TestCase): _run_batchnorm_correctness_test( normalization.BatchNormalization, dtype='float16', fused=False) + @tf_test_util.run_in_graph_and_eager_modes + def test_batchnorm_policy(self): + norm = keras.layers.BatchNormalization( + axis=-1, + input_shape=(4, 4, 3), + momentum=0.8, + dtype=policy.Policy('infer_float32_vars')) + x = np.random.normal(size=(10, 4, 4, 3)).astype('float16') + y = norm(x) + self.assertEqual(y.dtype, 'float16') + self.assertEqual(norm.beta.dtype.base_dtype, 'float32') + self.assertEqual(norm.gamma.dtype.base_dtype, 'float32') + class BatchNormalizationV1Test(test.TestCase): diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/experimental/BUILD index f994ab9c70a..900d4e5aa88 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/BUILD +++ b/tensorflow/python/keras/mixed_precision/experimental/BUILD @@ -24,6 +24,31 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "py_test") +py_library( + name = "policy", + srcs = [ + "policy.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework", + ], +) + +py_test( + name = "policy_test", + size = "medium", + srcs = [ + "policy_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":policy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", + ], +) + py_library( name = "autocast_variable", srcs = [ @@ -52,3 +77,16 @@ py_test( "@absl_py//absl/testing:parameterized", ], ) + +py_test( + name = "keras_test", + size = "medium", + srcs = ["keras_test.py"], + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/distribute:one_device_strategy", + "//tensorflow/python/keras", + 
"@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py new file mode 100644 index 00000000000..66bd1e9d7e2 --- /dev/null +++ b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py @@ -0,0 +1,360 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests mixed precision works correctly with Keras layers and models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.distribute import one_device_strategy +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.keras import backend +from tensorflow.python.keras import layers +from tensorflow.python.keras import models +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.keras.optimizer_v2 import gradient_descent +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.util import nest + + +class AssertTypeLayer(base_layer.Layer): + """A layer which asserts it's inputs are a certain type.""" + + def __init__(self, assert_type=None, **kwargs): + self._assert_type = assert_type + super(AssertTypeLayer, self).__init__(**kwargs) + + def assert_input_types(self, inputs): + """Asserts `inputs` are of the correct type. Should be called in call().""" + if self._assert_type: + inputs_flattened = nest.flatten(inputs) + for inp in inputs_flattened: + assert inp.dtype.base_dtype == self._assert_type, ( + 'Input tensor has type %s which does not match assert type %s' % + (inp.dtype.name, self._assert_type.name)) + + +class AddLayer(AssertTypeLayer): + """A layer which adds it's input to a scalar variable.""" + + def __init__(self, regularizer=None, use_operator=False, **kwargs): + """Initializes the AddLayer. + + Args: + regularizer: The regularizer on the scalar variable. + use_operator: If True, add using the + operator. If False, add using + tf.add. + **kwargs: Passed to AssertTypeLayer constructor. 
+ """ + self._regularizer = regularizer + self._use_operator = use_operator + super(AddLayer, self).__init__(**kwargs) + + def build(self, _): + self.v = self.add_weight('v', (), initializer='ones', + regularizer=self._regularizer) + self.built = True + + def call(self, inputs): + self.assert_input_types(inputs) + assert inputs.dtype == self.v.dtype + return self._add(inputs, self.v) + + def _add(self, x, y): + if self._use_operator: + return x + y + else: + return math_ops.add(x, y) + + +class AddLayerWithoutAutoCast(AddLayer): + """Same as AddLayer, but does not use AutoCastVariables.""" + + def build(self, _): + dtype = self.dtype + if dtype in ('float16', 'bfloat16'): + dtype = 'float32' + self.v = self.add_weight('v', (), initializer='ones', dtype=dtype, + experimental_autocast=False, + regularizer=self._regularizer) + self.built = True + + def call(self, inputs): + self.assert_input_types(inputs) + assert self.v.dtype in (dtypes.float32, dtypes.float64) + return self._add(inputs, math_ops.cast(self.v, inputs.dtype)) + + +class IdentityRegularizer(regularizers.Regularizer): + + def __call__(self, x): + assert x.dtype == dtypes.float32 + return array_ops.identity(x) + + +def create_one_device_strategy(): + return one_device_strategy.OneDeviceStrategy('cpu:0') + + +def create_mirrored_strategy(): + if context.num_gpus() >= 1: + return mirrored_strategy.MirroredStrategy(['cpu:0', 'gpu:0']) + else: + return mirrored_strategy.MirroredStrategy(['cpu:0']) + + +TESTCASES = ({ + 'testcase_name': 'base', + 'strategy_fn': create_one_device_strategy +}, { + 'testcase_name': 'distribute', + 'strategy_fn': create_mirrored_strategy +}) + + +class KerasLayerTest(test.TestCase, parameterized.TestCase): + """Test mixed precision with Keras layers.""" + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_variables_in_float32(self, strategy_fn): + x = constant_op.constant([1.], dtype=dtypes.float16) + with strategy_fn().scope(): + with policy.policy_scope('infer_float32_vars'): + layer = AddLayer(assert_type=dtypes.float16) + y = layer(x) + self.assertEqual(layer.v.dtype, dtypes.float32) + self.assertEqual(y.dtype, dtypes.float16) + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(y), 2.) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_layer_with_non_autocast_variable(self, strategy_fn): + x = constant_op.constant([1.], dtype=dtypes.float16) + with strategy_fn().scope(): + with policy.policy_scope('infer_float32_vars'): + layer = AddLayerWithoutAutoCast(assert_type=dtypes.float16) + y = layer(x) + self.assertEqual(layer.v.dtype, dtypes.float32) + self.assertEqual(y.dtype, dtypes.float16) + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(y), 2.) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_layer_regularizer_runs_in_float32(self, strategy_fn): + x = constant_op.constant([1.], dtype=dtypes.float16) + with strategy_fn().scope(): + with policy.policy_scope('infer_float32_vars'): + # Test on AddLayer + layer = AddLayer(assert_type=dtypes.float16, + regularizer=IdentityRegularizer()) + layer(x) + (regularizer_loss,) = layer.losses + self.assertEqual(regularizer_loss.dtype, dtypes.float32) + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(regularizer_loss), 1.) 
+ + # Test on AddLayerWithoutAutoCast + layer = AddLayerWithoutAutoCast(assert_type=dtypes.float16, + regularizer=IdentityRegularizer()) + layer(x) + (regularizer_loss,) = layer.losses + self.assertEqual(regularizer_loss.dtype, dtypes.float32) + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(regularizer_loss), 1.) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_passing_policy_to_layer(self, strategy_fn): + x = constant_op.constant([1.], dtype=dtypes.float16) + with strategy_fn().scope(): + # Passing a Policy to 'dtype' sets the policy for that layer. + layer = AddLayer(assert_type=dtypes.float16, + dtype=policy.Policy('infer_float32_vars')) + # layer.dtype refers to the variable dtype. + self.assertEqual(layer.dtype, dtypes.float32) + layer(x) + self.assertEqual(layer.v.dtype, dtypes.float32) + with policy.policy_scope('infer_float32_vars'): + # Passing a Policy to dtype overrides the global Policy. + layer = AddLayer(assert_type=dtypes.float16, + dtype=policy.Policy('infer')) + # The layer's dtype is not yet known. + self.assertEqual(layer.dtype, None) + layer(x) + self.assertEqual(layer.v.dtype, dtypes.float16) + self.assertEqual(layer.dtype, dtypes.float16) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_gradient(self, strategy_fn): + x = constant_op.constant([1.], dtype=dtypes.float16) + with strategy_fn().scope() as strategy: + with policy.policy_scope('infer_float32_vars'): + layer = AddLayer(assert_type=dtypes.float16) + def run_fn(): + with backprop.GradientTape() as tape: + y = layer(x) + # Divide by num_replicas_in_sync, as the effective total loss is the + # sum of each replica's loss. + y /= strategy.num_replicas_in_sync + + # The learning rate is small enough that, if applied to a float16 + # variable, the variable would not change. So this tests that the + # learning rate is not applied to a float16 value, but instead to the + # float32 variable. + opt = gradient_descent.SGD(2 ** -14) + grad = tape.gradient(y, layer.v) + return opt.apply_gradients([(grad, layer.v)]) + + op = strategy.experimental_run(run_fn) + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + self.evaluate(op) + # The gradient with respect to the variable is 1.
Since the + # variable is initialized with 1 and the learning rate is 2**-14, the + # new variable value should be init_val - gradient * learning_rate, + # which is 1 - 1 * 2**-14. + self.assertEqual(self.evaluate(layer.v), 1 - 2 ** -14) + + +class KerasModelTest(test.TestCase, parameterized.TestCase): + """Test mixed precision with Keras models.""" + + @parameterized.named_parameters({ + 'testcase_name': 'base', + 'strategy_fn': create_one_device_strategy, + }, { + 'testcase_name': 'distribute', + 'strategy_fn': create_mirrored_strategy, + }, { + 'testcase_name': 'operator', + 'strategy_fn': create_mirrored_strategy, + 'use_operator': True + }, { + 'testcase_name': 'regularizer', + 'strategy_fn': create_mirrored_strategy, + 'use_regularizer': True + }) + @test_util.run_in_graph_and_eager_modes + def test_model(self, strategy_fn, use_operator=False, use_regularizer=False): + regularizer = IdentityRegularizer() if use_regularizer else None + with strategy_fn().scope(): + with policy.policy_scope('infer_float32_vars'): + x = layers.Input(shape=(), batch_size=2, dtype=dtypes.float16) + layer = AddLayer(assert_type=dtypes.float16, use_operator=use_operator, + regularizer=regularizer) + y = layer(x) + y = math_ops.cast(y, dtypes.float32) + model = models.Model(inputs=x, outputs=y) + + def loss_fn(y_true, y_pred): + del y_true + return math_ops.reduce_mean(y_pred) + + # The learning rate is small enough that, if applied to a float16 + # variable, the variable would not change. So this tests that the + # learning rate is not applied to a float16 value, but instead to the + # float32 variable. + opt = gradient_descent.SGD(2 ** -14) + model.compile(opt, loss=loss_fn) + + self.assertEqual(backend.eval(layer.v), 1) + x = np.ones((2, 1)) + y = np.ones((2, 1)) + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2) + model.fit(dataset) + # The variable starts at 1, and should have gradient * learning_rate = + # 2 ** -14 subtracted from it. + expected = 1 - 2 ** -14 + if use_regularizer: + # The regularizer adds another 2 ** -14 to the update. + expected -= 2 ** -14 + self.assertEqual(backend.eval(layer.v), expected) + + @parameterized.named_parameters({ + 'testcase_name': 'base', + 'strategy_fn': create_one_device_strategy, + }, { + 'testcase_name': 'distribute', + 'strategy_fn': create_mirrored_strategy, + }) + @test_util.run_in_graph_and_eager_modes + def test_advanced_model(self, strategy_fn): + + # The advanced model tests mixed-precision-related features that would occur + # in a resnet50 model. It tests a model that has: + # * Multiple layers, some of which use auto-cast variables and some of + # which do not + # * Regularization on some variables and not others.
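+    # Each Add layer computes dy/dv = 1 and each IdentityRegularizer + # contributes a gradient of 1 to its variable, so after one step the + # regularized variables should decrease by 2 * learning_rate and the + # unregularized ones by learning_rate, as asserted below.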
+ + strategy = strategy_fn() + + learning_rate = 2 ** -14 + + with strategy.scope(): + with policy.policy_scope(policy.Policy('infer_float32_vars')): + x = layers.Input(shape=(), batch_size=2, dtype=dtypes.float16) + layer1 = AddLayer(assert_type=dtypes.float16, + regularizer=IdentityRegularizer(), use_operator=True) + layer2 = AddLayerWithoutAutoCast(assert_type=dtypes.float16, + use_operator=True) + layer3 = AddLayer(assert_type=dtypes.float16, use_operator=False) + layer4 = AddLayerWithoutAutoCast(assert_type=dtypes.float16, + regularizer=IdentityRegularizer(), + use_operator=False) + y = layer1(x) + y = layer2(y) + y = layer3(y) + y = layer4(y) + y = math_ops.cast(y, dtypes.float32) + model = models.Model(inputs=x, outputs=y) + + def loss_fn(y_true, y_pred): + self.assertEqual(y_true.dtype, dtypes.float32) + self.assertEqual(y_pred.dtype, dtypes.float32) + return math_ops.reduce_mean(y_pred) + + opt = gradient_descent.SGD(learning_rate) + model.compile(opt, loss=loss_fn) + + x = np.ones((2, 1)) + y = np.ones((2, 1)) + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2) + model.fit(dataset) + for layer in (layer1, layer2, layer3, layer4): + if layer.losses: + # Layer has a weight regularizer. + self.assertEqual(backend.eval(layer.v), 1 - 2 * learning_rate) + else: + # Layer does not have a weight regularizer. + self.assertEqual(backend.eval(layer.v), 1 - learning_rate) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy.py b/tensorflow/python/keras/mixed_precision/experimental/policy.py new file mode 100644 index 00000000000..805031fe7e3 --- /dev/null +++ b/tensorflow/python/keras/mixed_precision/experimental/policy.py @@ -0,0 +1,160 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains the Policy class for mixed precision training.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib + +from tensorflow.python.util.tf_export import keras_export + + +@keras_export('keras.mixed_precision.experimental.Policy') +class Policy(object): + """A mixed precision policy for a Keras layer. + + A mixed precision policy determines the floating-point dtype that Keras layers + should create variables in. For non-default policies, if the variable dtype + does not match the input dtype, variables will automatically be cast to the + input dtype to avoid type errors. Policies can be passed to the 'dtype' + argument of layer constructors, or a global policy can be set with + 'set_policy'. + + In the near future, policies will also determine the computation dtype of + layers, as well as the loss scaling algorithm. + + Policies are intended to enable mixed precision training, which requires using + float32 variables and [b]float16 computations for most layers.
The term "mixed + precision" refers to the use of both float16 (or bfloat16) and float32 in a + model. See https://arxiv.org/abs/1710.03740 for more information on mixed + precision training. + + Policies are constructed by passing a string to the `name` constructor + argument. `name` determines the behavior of the policy. Currently, `name` can + be one of the following values. + + * 'infer': Infer the variable and computation dtypes from the input dtype. + This is the default behavior. + * 'infer_float32_vars': Infer the computation dtypes from the input + dtype, but create variables in float32. Variables will be casted to the + computation dtype. This is intended to enable mixed precision. Users can + cast tensors to float16 before passing them to a layer, which causes the + layer to run it's computation in float16 while keeping variables in + float32. + + To use mixed precision in a model, the 'infer_float32_vars' policy can be used + alongside float16 input tensors, which results in float16 computations and + float32 variables. For example: + + ```python + tf.keras.mixed_precision.experimental.set_policy('infer_float32_vars') + model = tf.keras.models.Sequential( + tf.keras.layers.Input((100,), dtype='float16'), + tf.keras.layers.Dense(10), + tf.keras.layers.Dense(10), + tf.keras.layers.Lambda(lambda x: tf.cast(x, 'float32')), + tf.keras.layers.Activation('Softmax') + ) + ``` + + Alternatively, the policy can be passed to individual layers instead of + setting the global policy with `set_policy`: + + ```python + policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars') + model = tf.keras.models.Sequential( + tf.keras.layers.Input((100,), dtype='float16'), + tf.keras.layers.Dense(10, dtype=policy), + tf.keras.layers.Dense(10, dtype=policy), + tf.keras.layers.Lambda(lambda x: tf.cast(x, 'float32')), + tf.keras.layers.Activation('Softmax') + ) + ``` + """ + + def __init__(self, name): + self._name = name + if name == 'infer': + self._default_variable_dtype = None + elif name == 'infer_float32_vars': + self._default_variable_dtype = 'float32' + else: + raise ValueError('"name" argument to Policy constructor must be "infer" ' + 'or "infer_float32_vars", but got: %s' % name) + + @property + def name(self): + """Returns the name of the policy: "infer" or "infer_float32_vars.""" + return self._name + + @property + def default_variable_dtype(self): + """Returns the default variable dtype of this policy. + + This is the dtype layers will create their variables in, unless a layer + explicit chooses a different dtype. Layers will cast variables to the + appropriate dtype to avoid type errors. + + Returns: + The default variable dtype of this policy, or None if the default variable + dtype should be derived from the inputs. + """ + return self._default_variable_dtype + + @property + def should_cast_variables(self): + """Returns true if variables should be casted.""" + return self.default_variable_dtype is not None + + # TODO(reedwm): Implement get_config/from_config. + + +# TODO(reedwm): Make this thread local? +_global_policy = Policy('infer') + + +@keras_export('keras.mixed_precision.experimental.global_policy') +def global_policy(): + """Returns the global Policy. + + The global policy is the default policy used for layers, if no policy is + passed to the layer constructor. When TensorFlow starts, the global policy is + set to an "infer" policy, and can be changed with `set_policy`. + + Returns: + The global Policy. 
+ """ + return _global_policy + + +@keras_export('keras.mixed_precision.experimental.set_policy') +def set_policy(policy): + """Sets the global Policy.""" + global _global_policy + if not isinstance(policy, Policy): + policy = Policy(policy) + _global_policy = policy + + +# TODO(reedwm): Make this thread local +@contextlib.contextmanager +def policy_scope(policy): + old_policy = _global_policy + try: + set_policy(policy) + yield + finally: + set_policy(old_policy) diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py new file mode 100644 index 00000000000..278f5211044 --- /dev/null +++ b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py @@ -0,0 +1,69 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests Policies.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.keras.mixed_precision.experimental import policy as mp_policy +from tensorflow.python.platform import test + + +@test_util.run_all_in_graph_and_eager_modes +class PolicyTest(test.TestCase): + """Tests Policies.""" + + def test_infer(self): + policy = mp_policy.Policy('infer') + self.assertEqual(policy.name, 'infer') + self.assertEqual(policy.default_variable_dtype, None) + + def test_infer_float32_vars(self): + policy = mp_policy.Policy('infer_float32_vars') + self.assertEqual(policy.name, 'infer_float32_vars') + self.assertEqual(policy.default_variable_dtype, 'float32') + + def test_global_policy(self): + self.assertEqual(mp_policy.global_policy().name, 'infer') + default_policy = mp_policy.global_policy() + try: + mp_policy.set_policy('infer_float32_vars') + self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars') + self.assertEqual(mp_policy.global_policy().default_variable_dtype, + 'float32') + with ops.Graph().as_default(): # Policies are not associated with a graph + self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars') + mp_policy.set_policy('infer') + self.assertEqual(mp_policy.global_policy().name, 'infer') + self.assertEqual(mp_policy.global_policy().default_variable_dtype, None) + policy = mp_policy.Policy('infer_float32_vars') + mp_policy.set_policy(policy) + self.assertIs(mp_policy.global_policy(), policy) + finally: + mp_policy.set_policy(default_policy) + + def test_policy_scope(self): + with mp_policy.policy_scope('infer_float32_vars'): + self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars') + with mp_policy.policy_scope('infer'): + self.assertEqual(mp_policy.global_policy().name, 'infer') + self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars') + self.assertEqual(mp_policy.global_policy().name, 'infer') + +if 
__name__ == '__main__': + test.main() diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index bb01048db9a..89652a0a4b8 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -307,7 +307,8 @@ class Layer(base_layer.Layer): use_resource=None, synchronization=vs.VariableSynchronization.AUTO, aggregation=vs.VariableAggregation.NONE, - partitioner=None): + partitioner=None, + **kwargs): """Adds a new variable to the layer, or gets an existing one; returns it. Arguments: @@ -342,6 +343,7 @@ class Layer(base_layer.Layer): `tf.variable_axis_size_partitioner`. For more details, see the documentation of `tf.get_variable` and the "Variable Partitioners and Sharding" section of the API guide. + **kwargs: Additional keyword arguments. Returns: The created variable. Usually either a `Variable` or `ResourceVariable` @@ -354,6 +356,9 @@ class Layer(base_layer.Layer): ValueError: When trainable has been set to True with synchronization set as `ON_READ`. """ + for kwarg in kwargs: + if kwarg != 'experimental_autocast': + raise TypeError('Unknown keyword argument:', kwarg) if self._keras_style: return super(Layer, self).add_weight( name=name, @@ -366,7 +371,8 @@ class Layer(base_layer.Layer): use_resource=use_resource, synchronization=vs.VariableSynchronization.AUTO, aggregation=vs.VariableAggregation.NONE, - partitioner=partitioner) + partitioner=partitioner, + **kwargs) if synchronization == vs.VariableSynchronization.ON_READ: if trainable: @@ -433,7 +439,8 @@ class Layer(base_layer.Layer): use_resource=use_resource, synchronization=synchronization, aggregation=aggregation, - getter=vs.get_variable) + getter=vs.get_variable, + **kwargs) if regularizer: if (ops.executing_eagerly_outside_functions() diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl index 07cea1143de..7d797f44862 100644 --- a/tensorflow/python/tools/api/generator/api_init_files.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files.bzl @@ -90,6 +90,8 @@ KERAS_API_INIT_FILES = [ "keras/layers/experimental/__init__.py", "keras/losses/__init__.py", "keras/metrics/__init__.py", + "keras/mixed_precision/__init__.py", + "keras/mixed_precision/experimental/__init__.py", "keras/models/__init__.py", "keras/optimizers/__init__.py", "keras/optimizers/schedules/__init__.py", diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl index bed3b364ac7..7b13ad9ab9d 100644 --- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl @@ -114,6 +114,8 @@ KERAS_API_INIT_FILES_V1 = [ "keras/layers/experimental/__init__.py", "keras/losses/__init__.py", "keras/metrics/__init__.py", + "keras/mixed_precision/__init__.py", + "keras/mixed_precision/experimental/__init__.py", "keras/models/__init__.py", "keras/optimizers/__init__.py", "keras/optimizers/schedules/__init__.py", diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt new file mode 100644 index 00000000000..a2af65554f7 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.keras.mixed_precision.experimental.Policy" +tf_class { + is_instance: "" + is_instance: "" + member { + name: 
"default_variable_dtype" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "should_cast_variables" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.pbtxt new file mode 100644 index 00000000000..49fa92cab93 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.pbtxt @@ -0,0 +1,15 @@ +path: "tensorflow.keras.mixed_precision.experimental" +tf_module { + member { + name: "Policy" + mtype: "" + } + member_method { + name: "global_policy" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_policy" + argspec: "args=[\'policy\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.pbtxt new file mode 100644 index 00000000000..e8648afb5f7 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.keras.mixed_precision" +tf_module { + member { + name: "experimental" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.pbtxt index ed996785620..3db6920519a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.pbtxt @@ -56,6 +56,10 @@ tf_module { name: "metrics" mtype: "" } + member { + name: "mixed_precision" + mtype: "" + } member { name: "models" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt index d012bd97efe..f8acd9b54e2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt @@ -117,7 +117,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt index 90a27e5d66a..35aab02ed44 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt @@ -117,7 +117,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', 
\'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt index d653a0cec41..0b44b231c3b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt @@ -117,7 +117,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt index 32f9345ea40..86e0f46e148 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt @@ -117,7 +117,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt index 0e7adfe26b6..2e11690019a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt @@ -117,7 +117,7 @@ tf_class { } member_method { name: "add_weight" - argspec: 
"args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt index 5296597dc55..0be1478ea93 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt @@ -118,7 +118,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt index 5ae9568e642..7a6c6f2f2cd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt @@ -117,7 +117,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt index aa0da6d68ca..4ba326546c6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt @@ -118,7 +118,7 @@ tf_class { } 
member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt index 516f0faea98..753a7965d79 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt @@ -117,7 +117,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt index d92af8f3264..52624add063 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt @@ -116,7 +116,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt index 614643fc994..f412f2bade2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt @@ -116,7 +116,7 @@ tf_class { } member_method { name: "add_weight" - 
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt index 31022d3049e..e0e6f2849ae 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt @@ -116,7 +116,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt index 03bbf39022d..903809b243d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt @@ -114,7 +114,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt index 63a301e3e6e..badd5d7b973 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt @@ -117,7 +117,7 @@ tf_class { } member_method { name: "add_weight" - argspec: 
"args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt index d81a3368ced..4076962adfd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt @@ -117,7 +117,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt index 48d93d503e8..ee591be46fa 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt @@ -117,7 +117,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt index 2f1f1c1e3fd..c837bf4d4f7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt @@ -118,7 +118,7 
@@ -118,7 +118,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt
index bd7549af4c4..72b2c446522 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt
@@ -118,7 +118,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt
index 34b4133d0ca..c84513d0885 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt
@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt
index 2fff2b8606c..269944ee9df 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt
@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
index 95136152775..f2211375ed4 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
index 912f78fac15..ca6e923ebcf 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
index 58d004b3d5d..99b1a4e509b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
@@ -124,7 +124,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
index a7b63a7c2b4..8c3a9d7fe53 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
@@ -129,7 +129,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
index 3f17805af25..ce25e44b17b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
index 055485f3e90..1859a9e388b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
index 23272f44227..baf6278bac9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
@@ -124,7 +124,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
index a9f7e85b148..da4d5c47aae 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
@@ -123,7 +123,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
index ecf43616741..5772213c3e3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt
new file mode 100644
index 00000000000..a2af65554f7
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt
@@ -0,0 +1,21 @@
+path: "tensorflow.keras.mixed_precision.experimental.Policy"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.policy.Policy\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "default_variable_dtype"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "should_cast_variables"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.pbtxt
new file mode 100644
index 00000000000..49fa92cab93
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.pbtxt
@@ -0,0 +1,15 @@
+path: "tensorflow.keras.mixed_precision.experimental"
+tf_module {
+  member {
+    name: "Policy"
+    mtype: "<type \'type\'>"
+  }
+  member_method {
+    name: "global_policy"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_policy"
+    argspec: "args=[\'policy\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.pbtxt
new file mode 100644
index 00000000000..e8648afb5f7
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.pbtxt
@@ -0,0 +1,7 @@
+path: "tensorflow.keras.mixed_precision"
+tf_module {
+  member {
+    name: "experimental"
+    mtype: "<type \'module\'>"
+  }
+}
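The three new golden files above pin down the public surface of the experimental mixed precision API: a Policy class with a name, a default variable dtype, and a should_cast_variables flag, plus module-level global_policy() and set_policy() functions. A minimal usage sketch of that surface follows; the 'infer_float32_vars' policy name and the printed values are assumptions about the experimental implementation at this point, not guarantees from the golden files themselves.

    import tensorflow as tf

    exp = tf.keras.mixed_precision.experimental

    # Construct a policy; 'infer_float32_vars' is an assumed valid policy name.
    policy = exp.Policy('infer_float32_vars')
    print(policy.name)                    # the name the policy was built with
    print(policy.should_cast_variables)   # presumably True: variable reads are cast
    print(policy.default_variable_dtype)  # presumably 'float32'

    # Install it as the global default and read it back.
    exp.set_policy(policy)
    assert exp.global_policy().name == policy.name
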
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.pbtxt
index ed996785620..3db6920519a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.pbtxt
@@ -56,6 +56,10 @@ tf_module {
     name: "metrics"
     mtype: "<type \'module\'>"
   }
+  member {
+    name: "mixed_precision"
+    mtype: "<type \'module\'>"
+  }
   member {
     name: "models"
     mtype: "<type \'module\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt
index 34b4133d0ca..c84513d0885 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt
@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt
index 2fff2b8606c..269944ee9df 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt
@@ -125,7 +125,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
   }
   member_method {
     name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.-dropout-wrapper.pbtxt
index 20e5601740f..6d4b429a982 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.-dropout-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.-dropout-wrapper.pbtxt
@@ -132,7 +132,7 @@ tf_class {
   }
   member_method {
     name: "add_weight"
-    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
"args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.-residual-wrapper.pbtxt index 1904f5c0654..406aacbc229 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.-residual-wrapper.pbtxt @@ -128,7 +128,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt index 58d004b3d5d..99b1a4e509b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt @@ -124,7 +124,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply"