Add support for aggregating batch statistics across devices by using the newly added tf.keras.layers.experimental.SyncBatchNormalization layer.
PiperOrigin-RevId: 292723222
Change-Id: I1c0458ec24c7e712ffa5e12dcf1f5efd6b4ce8ac

parent 319b73c629
commit adf769043f
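As orientation before the diff: the new layer is meant to be created inside a `tf.distribute` strategy scope. The sketch below is a hypothetical usage example assembled from the docstring added in this commit, not code from the commit itself.

```
import tensorflow as tf

# Hypothetical usage sketch (assumes a TF build that ships the experimental layer).
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
      # Aggregates batch statistics across all replicas at each training step;
      # without a strategy it behaves like a regular BatchNormalization layer.
      tf.keras.layers.experimental.SyncBatchNormalization(),
      tf.keras.layers.Dense(1),
  ])
  model.compile(optimizer='sgd', loss='mse')
```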
@@ -35,7 +35,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn

-_NUM_SAMPLES = 64
+_NUM_SAMPLES = 66
_BATCH_SIZE = 32
_RANDOM_SEED = 1337
_NUM_EPOCHS = 2
@@ -60,12 +60,16 @@ class MaybeStrategyScope(object):
    self._scope = None


-def get_model():
+def get_model(sync_batchnorm=False):
  model = keras.Sequential()
  model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,)))
  model.add(keras.layers.Dense(
      10, activation='relu',
      kernel_regularizer=keras.regularizers.l2(1e-4)))
+  if sync_batchnorm:
+    model.add(keras.layers.SyncBatchNormalization())
+  else:
+    model.add(keras.layers.BatchNormalization())
  model.add(keras.layers.Dense(10, activation='relu'))
  model.add(keras.layers.Dense(1))
  return model
@@ -90,10 +94,13 @@ def compute_loss(labels, logits, reg_losses):


def iteration_inside_func(initial_weights, dataset, optimizer_fn,
-                          iteration_type, strategy=None):
+                          iteration_type, strategy=None, sync_batchnorm=None):
  """Helper function to test iterating over data inside a tf.function."""
  with MaybeStrategyScope(strategy):
-    model = get_model()
+    if strategy and sync_batchnorm:
+      model = get_model(sync_batchnorm)
+    else:
+      model = get_model()
    model.set_weights(initial_weights)
    optimizer = optimizer_fn()

@@ -153,10 +160,10 @@ def iteration_inside_func(initial_weights, dataset, optimizer_fn,


def iteration_outside_func(initial_weights, dataset, optimizer_fn,
-                           iteration_type, strategy=None):
+                           iteration_type, strategy=None, sync_batchnorm=None):
  """Helper function to test iterating over data outside a tf.function."""
  with MaybeStrategyScope(strategy):
-    model = get_model()
+    model = get_model(sync_batchnorm=sync_batchnorm)
    model.set_weights(initial_weights)
    optimizer = optimizer_fn()

@@ -223,16 +230,21 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
          optimizer_fn=strategy_combinations.optimizers_v1_and_v2,
          mode=['eager'],
          iteration_type=['iterator', 'dataset'],
-          inside_func=[False, True]
+          inside_func=[False, True],
+          sync_batchnorm=[True, False]
      ))
  def test_dnn_correctness_minus_tpus(self, distribution, optimizer_fn,
-                                      iteration_type, inside_func):
+                                      iteration_type, inside_func,
+                                      sync_batchnorm):
    # TODO(anjs): Identify why this particular V1 optimizer needs a higher tol.
    if 'FtrlV1' in optimizer_fn._name and 'TPU' in type(distribution).__name__:
      self.skipTest('Reduced tolerance of the order of 1e-1 required.')
    self.dnn_correctness(distribution, optimizer_fn, iteration_type,
-                         inside_func)
+                         inside_func, sync_batchnorm)

  def dnn_correctness(self, distribution, optimizer_fn, iteration_type,
-                      inside_func):
+                      inside_func, sync_batchnorm=None):
-    model = get_model()
+    model = get_model(sync_batchnorm)
    initial_weights = model.get_weights()
    dataset = get_data()
    if inside_func:
@@ -241,13 +253,15 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
      iteration_func = iteration_outside_func
    wts_with_ds, loss_with_ds, acc_with_ds = iteration_func(
        initial_weights, dataset, optimizer_fn, iteration_type,
-        strategy=distribution)
+        strategy=distribution, sync_batchnorm=sync_batchnorm)
    wts, loss, acc = iteration_func(initial_weights, dataset, optimizer_fn,
-                                    iteration_type)
+                                    iteration_type,
+                                    sync_batchnorm=sync_batchnorm)

    self.assertAllClose(wts, wts_with_ds, atol=1e-3, rtol=1e-3)
    self.assertAllClose(loss, loss_with_ds, atol=1e-3, rtol=1e-3)
    self.assertAllClose(acc, acc_with_ds, atol=1e-3, rtol=1e-3)


if __name__ == '__main__':
  test.main()
@@ -386,7 +386,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
  def set_up_test_config(self,
                         use_numpy=False,
                         use_validation_data=False,
-                         with_batch_norm=False):
+                         with_batch_norm=None):
    self.use_numpy = use_numpy
    self.use_validation_data = use_validation_data
    self.with_batch_norm = with_batch_norm
@@ -435,7 +435,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
                           use_numpy,
                           use_validation_data,
                           experimental_run_tf_function=None,
-                           with_batch_norm=False,
+                           with_batch_norm=None,
                           is_stateful_model=False,
                           partial_last_batch=None,
                           training_epochs=2):
@@ -503,7 +503,8 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
    # First, special case, for multi-replica distributed training, batch
    # norm is not aggregated globally. So it is expected to have different
    # weights.
-    if (self.with_batch_norm and distribution.num_replicas_in_sync > 1):
+    if (self.with_batch_norm == 'regular' and
+        distribution.num_replicas_in_sync > 1):
      with self.assertRaises(AssertionError):
        compare_results(
            results_with_ds,
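The comment in the last hunk above is the key behavioural point: regular batch normalization computes its statistics per replica, so a multi-replica run is expected to diverge from the single-device run (hence the asserted AssertionError), while the 'sync' variant aggregates over the global batch. A toy NumPy sketch with assumed values, not taken from the tests, shows why per-replica means differ from the global mean:

```
import numpy as np

# Global batch of 4 samples split across 2 replicas (assumed toy values).
replica_0 = np.array([1.0, 2.0])
replica_1 = np.array([10.0, 20.0])

# Regular BatchNormalization under a strategy: each replica normalizes with
# its own local statistics.
local_means = [replica_0.mean(), replica_1.mean()]           # [1.5, 15.0]

# SyncBatchNormalization: statistics are computed over the global batch.
global_mean = np.concatenate([replica_0, replica_1]).mean()  # 8.25

print(local_means, global_mean)
```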
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np

from tensorflow.python import keras
from tensorflow.python.distribute import combinations
+from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.keras.distribute import keras_correctness_test_base
from tensorflow.python.keras.optimizer_v2 import gradient_descent
@@ -43,8 +44,10 @@ class DistributionStrategyCnnCorrectnessTest(
        strides=(4, 4),
        kernel_regularizer=keras.regularizers.l2(1e-4))(
            image)
-    if self.with_batch_norm:
+    if self.with_batch_norm == 'regular':
      c1 = keras.layers.BatchNormalization(name='bn1')(c1)
+    elif self.with_batch_norm == 'sync':
+      c1 = keras.layers.SyncBatchNormalization(name='bn1')(c1)
    c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1)
    logits = keras.layers.Dense(
        10, activation='softmax', name='pred')(
@@ -107,7 +110,22 @@ class DistributionStrategyCnnCorrectnessTest(
        distribution,
        use_numpy,
        use_validation_data,
-        with_batch_norm=True,
+        with_batch_norm='regular',
        experimental_run_tf_function=experimental_run_tf_function)

+  @combinations.generate(
+      keras_correctness_test_base.all_strategy_and_input_config_combinations())
+  def test_cnn_with_sync_batch_norm_correctness(self, distribution, use_numpy,
+                                                use_validation_data,
+                                                experimental_run_tf_function):
+    if not context.executing_eagerly() or not experimental_run_tf_function:
+      self.skipTest('SyncBatchNorm is not enabled in graph mode.')
+
+    self.run_correctness_test(
+        distribution,
+        use_numpy,
+        use_validation_data,
+        with_batch_norm='sync',
+        experimental_run_tf_function=experimental_run_tf_function)

  @combinations.generate(
@@ -134,7 +152,7 @@ class DistributionStrategyCnnCorrectnessTest(
        distribution,
        use_numpy,
        use_validation_data,
-        with_batch_norm=True,
+        with_batch_norm='regular',
        partial_last_batch=True)
@@ -135,6 +135,8 @@ from tensorflow.python.keras.layers.noise import GaussianDropout

# Normalization layers.
from tensorflow.python.keras.layers.normalization import LayerNormalization
+from tensorflow.python.keras.layers.normalization_v2 import SyncBatchNormalization

if tf2.enabled():
  from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization
  from tensorflow.python.keras.layers.normalization import BatchNormalization as BatchNormalizationV1
@@ -652,8 +652,12 @@ class BatchNormalizationBase(Layer):

    return (r, d, out_mean, out_variance)

+  def _calculate_mean_and_var(self, inputs, reduction_axes, keep_dims):
+    return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
+
  def _moments(self, inputs, reduction_axes, keep_dims):
-    mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
+    mean, variance = self._calculate_mean_and_var(inputs, reduction_axes,
+                                                  keep_dims)
    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
    # code as well.
    if self._support_zero_size_input():
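The hunk above turns `_moments` into a template method: the statistics computation is routed through the new `_calculate_mean_and_var` hook, so a subclass can change how mean and variance are obtained without touching the rest of `_moments`. A minimal, self-contained sketch of that pattern follows; the class and helper names are illustrative, not TensorFlow code.

```
class BaseNorm:
  """Base class: orchestration lives in moments(); statistics are a hook."""

  def _calculate_mean_and_var(self, values):
    # Default behaviour: plain local statistics over the values we can see.
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, var

  def moments(self, values):
    # Shared handling (broadcasting, dtype casts, ...) would stay here in the
    # base class; only the statistics hook differs between subclasses.
    return self._calculate_mean_and_var(values)


class SyncedNorm(BaseNorm):
  """Overrides only the hook, analogous to SyncBatchNormalization below."""

  def __init__(self, other_replica_values):
    self._other = list(other_replica_values)  # stand-in for an all-reduce

  def _calculate_mean_and_var(self, values):
    # Aggregate with the other replicas' values, then reuse the base logic.
    return super()._calculate_mean_and_var(list(values) + self._other)


print(BaseNorm().moments([1.0, 2.0]))                # local stats: (1.5, 0.25)
print(SyncedNorm([10.0, 20.0]).moments([1.0, 2.0]))  # stats over all 4 values
```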
@@ -18,10 +18,192 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribution_strategy_context as ds
from tensorflow.python.distribute import reduce_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras.layers import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.experimental.SyncBatchNormalization', v1=[]) # pylint: disable=g-classes-have-attributes
class SyncBatchNormalization(normalization.BatchNormalizationBase):
  r"""Normalize and scale inputs or activations synchronously across replicas.

  Applies batch normalization to activations of the previous layer at each batch
  by synchronizing the global batch statistics across all devices that are
  training the model. For specific details about batch normalization please
  refer to the `tf.keras.layers.BatchNormalization` layer docs.

  If this layer is used when using tf.distribute strategy to train models
  across devices/workers, there will be an allreduce call to aggregate batch
  statistics across all replicas at every training step. Without tf.distribute
  strategy, this layer behaves as a regular `tf.keras.layers.BatchNormalization`
  layer.

  Example usage:
  ```
  strategy = tf.distribute.MirroredStrategy()

  with strategy.scope():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(16))
    model.add(tf.keras.layers.experimental.SyncBatchNormalization())
  ```

  Arguments:
    axis: Integer, the axis that should be normalized
      (typically the features axis).
      For instance, after a `Conv2D` layer with
      `data_format="channels_first"`,
      set `axis=1` in `BatchNormalization`.
    momentum: Momentum for the moving average.
    epsilon: Small float added to variance to avoid dividing by zero.
    center: If True, add offset of `beta` to normalized tensor.
      If False, `beta` is ignored.
    scale: If True, multiply by `gamma`.
      If False, `gamma` is not used.
      When the next layer is linear (also e.g. `nn.relu`),
      this can be disabled since the scaling
      will be done by the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    moving_mean_initializer: Initializer for the moving mean.
    moving_variance_initializer: Initializer for the moving variance.
    beta_regularizer: Optional regularizer for the beta weight.
    gamma_regularizer: Optional regularizer for the gamma weight.
    beta_constraint: Optional constraint for the beta weight.
    gamma_constraint: Optional constraint for the gamma weight.
    renorm: Whether to use Batch Renormalization
      (https://arxiv.org/abs/1702.03275). This adds extra variables during
      training. The inference is the same for either value of this parameter.
    renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
      scalar `Tensors` used to clip the renorm correction. The correction
      `(r, d)` is used as `corrected_value = normalized_value * r + d`, with
      `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
      dmax are set to inf, 0, inf, respectively.
    renorm_momentum: Momentum used to update the moving means and standard
      deviations with renorm. Unlike `momentum`, this affects training
      and should be neither too small (which would add noise) nor too large
      (which would give stale estimates). Note that `momentum` is still applied
      to get the means and variances for inference.
    trainable: Boolean, if `True` the variables will be marked as trainable.

  Call arguments:
    inputs: Input tensor (of any rank).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode.
      - `training=True`: The layer will normalize its inputs using the
        mean and variance of the current batch of inputs.
      - `training=False`: The layer will normalize its inputs using the
        mean and variance of its moving statistics, learned during training.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

  Output shape:
    Same shape as input.

  """

  def __init__(self,
               axis=-1,
               momentum=0.99,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer='zeros',
               gamma_initializer='ones',
               moving_mean_initializer='zeros',
               moving_variance_initializer='ones',
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               renorm=False,
               renorm_clipping=None,
               renorm_momentum=0.99,
               trainable=True,
               adjustment=None,
               name=None,
               **kwargs):

    # Currently we only support aggregating over the global batch size.
    super(SyncBatchNormalization, self).__init__(
        axis=axis,
        momentum=momentum,
        epsilon=epsilon,
        center=center,
        scale=scale,
        beta_initializer=beta_initializer,
        gamma_initializer=gamma_initializer,
        moving_mean_initializer=moving_mean_initializer,
        moving_variance_initializer=moving_variance_initializer,
        beta_regularizer=beta_regularizer,
        gamma_regularizer=gamma_regularizer,
        beta_constraint=beta_constraint,
        gamma_constraint=gamma_constraint,
        renorm=renorm,
        renorm_clipping=renorm_clipping,
        renorm_momentum=renorm_momentum,
        fused=False,
        trainable=trainable,
        virtual_batch_size=None,
        name=name,
        **kwargs)

  def _calculate_mean_and_var(self, x, axes, keep_dims):

    with ops.name_scope('moments', values=[x, axes]):
      # The dynamic range of fp16 is too limited to support the collection of
      # sufficient statistics. As a workaround we simply perform the operations
      # on 32-bit floats before converting the mean and variance back to fp16
      y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
      replica_ctx = ds.get_replica_context()
      if replica_ctx:
        local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
        local_squared_sum = math_ops.reduce_sum(math_ops.square(y), axis=axes,
                                                keepdims=True)
        y_sum, y_squared_sum, global_batch_size = (
            replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, [
                local_sum, local_squared_sum, array_ops.shape_v2(y)[0]]))

        axes_vals = [(array_ops.shape_v2(y))[i] for i in range(1, len(axes))]
        multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals),
                                   dtypes.float32)
        multiplier = multiplier * math_ops.cast(global_batch_size,
                                                dtypes.float32)

        mean = y_sum / multiplier
        y_squared_mean = y_squared_sum / multiplier
        # var = E(x^2) - E(x)^2
        variance = y_squared_mean - math_ops.square(mean)
      else:
        # Compute true mean while keeping the dims for proper broadcasting.
        mean = math_ops.reduce_mean(y, axes, keepdims=True, name='mean')
        # sample variance, not unbiased variance
        # Note: stop_gradient does not change the gradient that gets
        # backpropagated to the mean from the variance calculation,
        # because that gradient is zero
        variance = math_ops.reduce_mean(
            math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
            axes,
            keepdims=True,
            name='variance')
      if not keep_dims:
        mean = array_ops.squeeze(mean, axes)
        variance = array_ops.squeeze(variance, axes)
      if x.dtype == dtypes.float16:
        return (math_ops.cast(mean, dtypes.float16),
                math_ops.cast(variance, dtypes.float16))
      else:
        return (mean, variance)


@keras_export('keras.layers.BatchNormalization', v1=[]) # pylint: disable=missing-docstring
class BatchNormalization(normalization.BatchNormalizationBase):
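In the replica branch above, each replica all-reduces its local sum of `y`, sum of `y**2`, and local batch size, then forms `mean = sum(y) / N` and `variance = sum(y**2) / N - mean**2`. A small NumPy sketch, with toy data simulating two replicas, checks that this recovers the same statistics as computing them over the full global batch directly:

```
import numpy as np

# Two simulated replicas, each holding a shard of the global batch.
shards = [np.array([1.0, 2.0, 3.0]), np.array([10.0, 20.0, 30.0])]

# What the all-reduce produces: statistics summed across replicas.
y_sum = sum(s.sum() for s in shards)
y_squared_sum = sum(np.square(s).sum() for s in shards)
n = sum(s.shape[0] for s in shards)

mean = y_sum / n
variance = y_squared_sum / n - mean ** 2   # var = E(x^2) - E(x)^2

# Reference: statistics over the concatenated global batch.
full = np.concatenate(shards)
assert np.isclose(mean, full.mean())
assert np.isclose(variance, full.var())
print(mean, variance)
```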
@@ -0,0 +1,218 @@
path: "tensorflow.keras.layers.experimental.SyncBatchNormalization"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.layers.normalization_v2.SyncBatchNormalization\'>"
  is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationBase\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "activity_regularizer"
    mtype: "<type \'property\'>"
  }
  member {
    name: "dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "dynamic"
    mtype: "<type \'property\'>"
  }
  member {
    name: "inbound_nodes"
    mtype: "<type \'property\'>"
  }
  member {
    name: "input"
    mtype: "<type \'property\'>"
  }
  member {
    name: "input_mask"
    mtype: "<type \'property\'>"
  }
  member {
    name: "input_shape"
    mtype: "<type \'property\'>"
  }
  member {
    name: "input_spec"
    mtype: "<type \'property\'>"
  }
  member {
    name: "losses"
    mtype: "<type \'property\'>"
  }
  member {
    name: "metrics"
    mtype: "<type \'property\'>"
  }
  member {
    name: "name"
    mtype: "<type \'property\'>"
  }
  member {
    name: "name_scope"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "non_trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "outbound_nodes"
    mtype: "<type \'property\'>"
  }
  member {
    name: "output"
    mtype: "<type \'property\'>"
  }
  member {
    name: "output_mask"
    mtype: "<type \'property\'>"
  }
  member {
    name: "output_shape"
    mtype: "<type \'property\'>"
  }
  member {
    name: "stateful"
    mtype: "<type \'property\'>"
  }
  member {
    name: "submodules"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "trainable_weights"
    mtype: "<type \'property\'>"
  }
  member {
    name: "updates"
    mtype: "<type \'property\'>"
  }
  member {
    name: "variables"
    mtype: "<type \'property\'>"
  }
  member {
    name: "weights"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'beta_constraint\', \'gamma_constraint\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'trainable\', \'adjustment\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'zeros\', \'ones\', \'zeros\', \'ones\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'True\', \'None\', \'None\'], "
  }
  member_method {
    name: "add_loss"
    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_metric"
    argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "add_update"
    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "add_variable"
    argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "add_weight"
    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
  }
  member_method {
    name: "apply"
    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "build"
    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "call"
    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "compute_mask"
    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "compute_output_shape"
    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "compute_output_signature"
    argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "count_params"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_input_at"
    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_input_mask_at"
    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_input_shape_at"
    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_losses_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_output_at"
    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_output_mask_at"
    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_output_shape_at"
    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_updates_for"
    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_weights"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "set_weights"
    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "with_name_scope"
    argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
  }
}
@@ -1,5 +1,9 @@
path: "tensorflow.keras.layers.experimental"
tf_module {
+  member {
+    name: "SyncBatchNormalization"
+    mtype: "<type \'type\'>"
+  }
  member {
    name: "preprocessing"
    mtype: "<type \'module\'>"