Add support for sample_weights for Keras + tf.distribute codepaths.
PiperOrigin-RevId: 251066253
parent 270305c6d5
commit 98a0c57c44
@@ -215,6 +215,14 @@ def get_model():
  return model


def get_sample_weights_model():
  x = keras.layers.Input(shape=(1,), name='input')
  y = keras.layers.Dense(
      1, kernel_initializer='ones', bias_initializer='zeros', name='dense')(x)
  model = keras.Model(x, y)
  return model


def get_dataset(distribution):
  inputs = np.zeros((10, 3), dtype=np.float32)
  targets = np.zeros((10, 4), dtype=np.float32)
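Worth noting, since the hand-computed losses in the tests below rely on it: with a 1x1 'ones' kernel and a zero bias, this single-unit model is the identity map, so its untrained predictions equal its inputs. A minimal sketch using the public tf.keras API (rather than this file's internal imports):

```python
# Minimal sketch, assuming a standard TensorFlow install; mirrors
# get_sample_weights_model above via the public tf.keras API.
import numpy as np
import tensorflow as tf

x = tf.keras.layers.Input(shape=(1,), name='input')
y = tf.keras.layers.Dense(
    1, kernel_initializer='ones', bias_initializer='zeros', name='dense')(x)
model = tf.keras.Model(x, y)

# With a 1x1 'ones' kernel and zero bias the model computes y = x, so
# predictions equal inputs before any training step.
print(model.predict(np.array([[0.], [1.], [2.], [3.]], np.float32)))
# -> [[0.], [1.], [2.], [3.]]
```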
@@ -262,8 +270,13 @@ def multi_input_output_model():
  return model


# TODO(josh11b): Add combinations.one_device_strategy_gpu once it works with
# TestDistributionStrategyWithCallbacks.test_callbacks_in_predict.
strategies_minus_default_minus_tpu = [
    strategy_combinations.one_device_strategy,
    strategy_combinations.one_device_strategy_gpu,
    strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
    strategy_combinations.mirrored_strategy_with_two_gpus
]

strategies_minus_tpu = [
    strategy_combinations.default_strategy,
    strategy_combinations.one_device_strategy,
@@ -673,29 +686,58 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
    model.predict(inputs, batch_size=8)

  @combinations.generate(
      combinations.combine(
          distribution=strategies_minus_tpu,
          mode=['graph'],
          cloning=[True, False]))
  def test_numpy_with_sample_weights(self, distribution, cloning):
    with self.cached_session():
      with distribution.scope():
        # TODO(b/130808953): Re-enable the V1 optimizer after iterations is
        # mirrored.
        optimizer_fn = (
            rmsprop.RMSPropOptimizer
            if cloning else gradient_descent_keras.SGD)
        optimizer = optimizer_fn(learning_rate=0.001)
        model = get_model()
        loss = 'mse'
        model.compile(optimizer, loss, cloning=cloning)
      combinations.combine(distribution=strategies_minus_tpu,
                           mode=['graph', 'eager']))
  def test_numpy_with_sample_weights(self, distribution):
    with self.cached_session(), distribution.scope():
      model = get_sample_weights_model()
      optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      model.compile(optimizer, loss)

      inputs = np.zeros((20, 3), np.float32)
      targets = np.zeros((20, 4), np.float32)
      sample_weights = np.ones((20), np.float32)
      inputs = np.array([[0], [1], [2], [3]], np.float32)
      targets = np.array([[2], [4], [6], [8]], np.float32)
      sample_weights = np.array([0.25, 0.5, 0.75, 1], np.float32)

      model.fit(inputs, targets, sample_weight=sample_weights, epochs=1,
                steps_per_epoch=2, verbose=1)
      result = model.evaluate(inputs, targets, batch_size=2,
                              sample_weight=sample_weights, verbose=1)
      # The per-sample loss is multiplied by the corresponding sample weight.
      # The average of these weighted losses is the return value of the
      # `evaluate` call. For example, in the test above the average weighted
      # loss is calculated in the following manner:
      # batch_1 = (((2-0)^2) * 0.25 + ((4-1)^2) * 0.5) / 2 = 5.5 / 2 = 2.75
      # batch_2 = (((6-2)^2) * 0.75 + ((8-3)^2) * 1) / 2 = 37 / 2 = 18.5
      # final result = (batch_1 + batch_2) / 2 = 10.625.
      # The first division is by the number of samples in each batch; the
      # second is by the number of steps/batches the loss is aggregated over.
      self.assertAllClose(result, 10.625)

      # We now test without passing sample_weights:
      # batch_1 = (((2-0)^2) + ((4-1)^2)) / 2 = 13 / 2 = 6.5
      # batch_2 = (((6-2)^2) + ((8-3)^2)) / 2 = 41 / 2 = 20.5
      # final result = (batch_1 + batch_2) / 2 = 27 / 2 = 13.5
      result = model.evaluate(inputs, targets, batch_size=2, verbose=1)
      self.assertAllClose(result, 13.5)
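The arithmetic in the comments above can be checked end to end with plain NumPy. A minimal sketch, with the predictions hard-coded to what the untrained identity model produces (the inputs themselves):

```python
# Standalone check of the weighted/unweighted loss arithmetic above.
import numpy as np

preds = np.array([0., 1., 2., 3.])      # identity-model outputs = inputs
targets = np.array([2., 4., 6., 8.])
weights = np.array([0.25, 0.5, 0.75, 1.])

sq_err = (targets - preds) ** 2          # [4, 9, 16, 25]
batches = [(0, 2), (2, 4)]               # batch_size=2 over 4 samples
# Per-batch mean of (weighted) squared errors, then mean over batches.
weighted = np.mean([np.mean(sq_err[a:b] * weights[a:b]) for a, b in batches])
unweighted = np.mean([np.mean(sq_err[a:b]) for a, b in batches])
print(weighted, unweighted)  # -> 10.625 13.5
```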

  @combinations.generate(
      combinations.combine(distribution=strategies_minus_default_minus_tpu,
                           mode=['eager']))
  def test_numpy_with_sample_weights_eager_with_cloning(self, distribution):
    with self.cached_session(), distribution.scope():
      model = get_sample_weights_model()
      optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      model.compile(optimizer, loss, cloning=True)

      inputs = np.array([[0], [1], [2], [3]], np.float32)
      targets = np.array([[2], [4], [6], [8]], np.float32)
      sample_weights = np.array([0.25, 0.5, 0.75, 1], np.float32)

      with self.assertRaisesRegexp(NotImplementedError,
                                   '`sample_weight` is not supported when '
                                   'using tf.distribute.Strategy in '):
        model.evaluate(inputs, targets, batch_size=2,
                       sample_weight=sample_weights, verbose=1)

  @combinations.generate(all_strategy_combinations_plus_cloning())
  def test_flatten_predict_outputs(self, distribution, cloning):
@@ -1162,34 +1204,6 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
      model.evaluate(dataset, steps=2, verbose=1)
      model.predict(get_predict_dataset(distribution), steps=2)

  @combinations.generate(
      combinations.times(strategy_minus_tpu_combinations(),
                         combinations.combine(cloning=[True, False])))
  def test_dataset_with_sample_weights(self, distribution, cloning):
    with self.cached_session():
      with distribution.scope():
        model = get_model()
        # TODO(b/130808953): Re-enable the V1 optimizer after iterations is
        # mirrored.
        optimizer_fn = (
            rmsprop.RMSPropOptimizer
            if cloning else gradient_descent_keras.SGD)
        optimizer = optimizer_fn(learning_rate=0.001)
        loss = 'mse'
        model.compile(optimizer, loss, cloning=cloning)

        inputs = np.zeros((10, 3), np.float32)
        targets = np.zeros((10, 4), np.float32)
        sample_weights = np.ones((10), np.float32)
        dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets,
                                                          sample_weights))
        dataset = dataset.repeat()
        dataset = dataset.batch(10)

        model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
        model.evaluate(dataset, steps=2, verbose=1)
        model.predict(dataset, steps=2)

  @combinations.generate(
      combinations.combine(
          distribution=[
@@ -1483,6 +1497,61 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
          atol=1e-4,
          rtol=1e-4)

  @combinations.generate(
      combinations.combine(distribution=strategies_minus_tpu,
                           mode=['graph', 'eager']))
  def test_dataset_with_sample_weights(self, distribution):
    with self.cached_session(), distribution.scope():
      model = get_sample_weights_model()
      optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      model.compile(optimizer, loss)

      inputs = np.array([[0], [1], [2], [3]], np.float32)
      targets = np.array([[2], [4], [6], [8]], np.float32)
      sample_weights = np.array([0.25, 0.5, 0.75, 1], np.float32)
      ds = dataset_ops.Dataset.from_tensor_slices((inputs, targets,
                                                   sample_weights)).batch(2)
      result = model.evaluate(ds, verbose=1)
      # The per-sample loss is multiplied by the corresponding sample weight.
      # The average of these weighted losses is the return value of the
      # `evaluate` call. For example, in the test above the average weighted
      # loss is calculated in the following manner:
      # batch_1 = (((2-0)^2) * 0.25 + ((4-1)^2) * 0.5) / 2 = 5.5 / 2 = 2.75
      # batch_2 = (((6-2)^2) * 0.75 + ((8-3)^2) * 1) / 2 = 37 / 2 = 18.5
      # final result = (batch_1 + batch_2) / 2 = 10.625.
      # The first division is by the number of samples in each batch; the
      # second is by the number of steps/batches the loss is aggregated over.
      self.assertAllClose(result, 10.625)

      # We now test without passing sample_weights:
      # batch_1 = (((2-0)^2) + ((4-1)^2)) / 2 = 13 / 2 = 6.5
      # batch_2 = (((6-2)^2) + ((8-3)^2)) / 2 = 41 / 2 = 20.5
      # final result = (batch_1 + batch_2) / 2 = 27 / 2 = 13.5
      ds = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).batch(2)
      result = model.evaluate(ds, verbose=1)
      self.assertAllClose(result, 13.5)
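For readers unfamiliar with the three-tuple convention used here: Keras unpacks each dataset element as (inputs, targets, sample_weights). A small sketch using the public tf.data API (the printed spec in the comment is approximate):

```python
# Sketch of how the three-tuple dataset above is structured; Keras
# interprets each batched element as (inputs, targets, sample_weights).
import numpy as np
import tensorflow as tf

inputs = np.array([[0], [1], [2], [3]], np.float32)
targets = np.array([[2], [4], [6], [8]], np.float32)
sample_weights = np.array([0.25, 0.5, 0.75, 1], np.float32)

ds = tf.data.Dataset.from_tensor_slices(
    (inputs, targets, sample_weights)).batch(2)
print(ds.element_spec)
# Roughly: (TensorSpec([None, 1], tf.float32),
#           TensorSpec([None, 1], tf.float32),
#           TensorSpec([None], tf.float32))
```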
  @combinations.generate(
      combinations.combine(distribution=strategies_minus_default_minus_tpu,
                           mode=['eager']))
  def test_dataset_with_sample_weights_eager_with_cloning(self, distribution):
    with self.cached_session(), distribution.scope():
      model = get_sample_weights_model()
      optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      model.compile(optimizer, loss, cloning=True)

      inputs = np.array([[0], [1], [2], [3]], np.float32)
      targets = np.array([[2], [4], [6], [8]], np.float32)
      sample_weights = np.array([0.25, 0.5, 0.75, 1], np.float32)
      ds = dataset_ops.Dataset.from_tensor_slices((inputs, targets,
                                                   sample_weights)).batch(2)

      with self.assertRaisesRegexp(NotImplementedError,
                                   '`sample_weight` is not supported when '
                                   'using tf.distribute.Strategy in '):
        model.evaluate(ds, verbose=1)

class TestRegularizerLoss(test.TestCase, parameterized.TestCase):
  class IdentityRegularizer(keras.regularizers.Regularizer):

@@ -596,12 +596,13 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
  if mode == ModeKeys.PREDICT:
    sample_weights = []
    targets = []
  elif not is_distributing_by_cloning(model):
    sample_weights = None  # b/129503665
  else:
    sample_weights = [
        None for _ in range(len(model.outputs) * strategy.num_replicas_in_sync)
    ]
  elif sample_weights is not None and is_distributing_by_cloning(model):
    if context.executing_eagerly() and not model._compile_distribution:
      raise NotImplementedError('`sample_weight` is not supported when using '
                                'tf.distribute.Strategy in eager mode and '
                                'cloning=True.')
    sample_weights = flatten_per_replica_values(strategy, sample_weights)

  ins = [inputs, targets, sample_weights]
  return tuple(ins)

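The control flow above is easier to follow as a pure function. A hedged sketch of the new branching, where by_cloning, eager, and compile_distribution stand in for the internal checks (is_distributing_by_cloning, context.executing_eagerly(), model._compile_distribution) and flatten stands in for flatten_per_replica_values; these stubs are illustrative, not the real API:

```python
# Hedged sketch of the new sample_weights branch in _prepare_feed_values.
def resolve_sample_weights(sample_weights, mode, by_cloning, eager,
                           compile_distribution, flatten):
  if mode == 'predict':
    return []  # predict feeds no targets or sample weights
  if sample_weights is not None and by_cloning:
    if eager and not compile_distribution:
      raise NotImplementedError(
          '`sample_weight` is not supported when using '
          'tf.distribute.Strategy in eager mode and cloning=True.')
    # In TF this is flatten_per_replica_values(strategy, sample_weights):
    # unwrap the per-replica containers into one flat list.
    sample_weights = flatten(sample_weights)
  return sample_weights

# No weights passed, non-cloning path: weights pass through as None.
print(resolve_sample_weights(None, 'train', False, False, False, list))
```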
@@ -835,25 +836,36 @@ def _make_replica_execution_function(model, mode):
  return func


def _make_execution_function_with_cloning(model, mode):
  """Clones or re-uses models to run one step of distributed model execution."""
def _make_replicated_models_with_cloning(model, mode):
  """Build models on each replica."""
  strategy = model._distribution_strategy

  distributed_model = get_distributed_model(model, mode)
  # If the distributed model for a particular `mode` is already built, use the
  # `_distributed_function` on that distributed model.
  if distributed_model:
    return distributed_model._distributed_function

  # If distributed_model is not built, create one for `mode`.
  if model._compile_distribution:
    clone_model_on_replicas(model, strategy, mode)
  else:
    _build_distributed_network(model, strategy, mode)

  # We've just created the distributed model, so `distributed_model` should
  # not be None.

def _make_execution_function_with_cloning(model, mode):
  """Clones or re-uses models to run one step of distributed model execution."""
  distributed_model = get_distributed_model(model, mode)
  # TODO(b/134069401): Create a cache for the distributed model and exec
  # function that incorporates additional attributes, beyond just the mode, as
  # part of the cache key.
  # If the distributed model for a particular `mode` is already built, use the
  # `_distributed_function` on that distributed model.
  # If you have updated the sample_weight_mode on the model, then you will need
  # to recompile metrics and recreate the execution function. This is indicated
  # by the `_recompile_exec_function` property.
  if (distributed_model and hasattr(distributed_model, '_distribution_function')
      and not (hasattr(distributed_model, '_recompile_exec_function') and
               distributed_model._recompile_exec_function)):
    return distributed_model._distributed_function

  if not distributed_model:
    _make_replicated_models_with_cloning(model, mode)
    distributed_model = get_distributed_model(model, mode)
  assert distributed_model

  # Also create an execution function on that distributed model.
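The caching scheme above (cache the expensive execution function on the distributed model, and honor a dirty flag that forces a rebuild when the sample_weight_mode changes) follows a generic pattern. A self-contained sketch; the class and build step are illustrative, not TF internals:

```python
# Generic sketch of the cache-plus-invalidation-flag pattern used above.
class DistributedModelStub:
  def __init__(self, build_fn):
    self._build_fn = build_fn            # expensive function factory
    self._distributed_function = None    # cache slot
    self._recompile_exec_function = False  # dirty flag

  def execution_function(self):
    # Reuse the cached function unless the dirty flag forces a rebuild.
    if (self._distributed_function is not None
        and not self._recompile_exec_function):
      return self._distributed_function
    self._distributed_function = self._build_fn()  # expensive build step
    self._recompile_exec_function = False          # reset the dirty flag
    return self._distributed_function

stub = DistributedModelStub(lambda: (lambda *a: 'ran step'))
fn = stub.execution_function()
stub._recompile_exec_function = True   # e.g. sample_weight_mode changed
assert stub.execution_function() is not fn  # rebuilt after invalidation
```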
@@ -865,6 +877,7 @@ def _make_execution_function_with_cloning(model, mode):
  # We cache the distributed execution function on the model since creating
  # distributed models and execution functions are expensive.
  distributed_model._distributed_function = distributed_function
  distributed_model._recompile_exec_function = False
  return distributed_function

@@ -1088,3 +1101,24 @@ def filter_distributed_callbacks(callbacks_list):
  return [
      callback for callback in callbacks_list if not callback._chief_worker_only
  ]  # pylint: disable=protected-access


def _update_sample_weight_modes(model, mode, sample_weights):
  """Update sample_weight_mode of the distributed model."""
  if is_distributing_by_cloning(model):
    distributed_model = get_distributed_model(model, mode)
    if not distributed_model:
      _make_replicated_models_with_cloning(model, mode)
      distributed_model = get_distributed_model(model, mode)
    distributed_model._recompile_exec_function = any(
        [e.sample_weights_mismatch() for e in model._training_endpoints])

    if sample_weights:
      distributed_models = flatten_per_replica_values(
          model._distribution_strategy, distributed_model)
      # sample_weights is a tuple of 1 list where the number of elements in the
      # list is equal to the number of replicas in sync.
      sample_weights = sample_weights[0]
      if sample_weights and None not in sample_weights:
        for m, sw in zip(distributed_models, sample_weights):
          m._update_sample_weight_modes(sample_weights=[sw])
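A tiny illustration of the shape described in the comment above, with placeholder strings in place of real per-replica tensors (two replicas in sync assumed):

```python
# sample_weights arrives as a 1-tuple holding one entry per replica; each
# per-replica model is then given its own single-element weights list.
sample_weights = (['sw_replica_0', 'sw_replica_1'],)   # tuple of 1 list
distributed_models = ['model_replica_0', 'model_replica_1']

per_replica = sample_weights[0]
if per_replica and None not in per_replica:
  for m, sw in zip(distributed_models, per_replica):
    print(m, '<-', [sw])   # each replica model sees [its own weight]
```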
@@ -557,11 +557,17 @@ def _make_execution_function(model, mode):


def _update_sample_weight_mode(model, mode, inputs):
  """Updates the sample_weight_mode of a given model."""
  if not model._distribution_strategy and mode != ModeKeys.PREDICT:
    # `inputs` is the model's inputs + targets + sample_weights +
    # learning phase placeholder if specified. To update the sample_weight_mode
    # we need to determine if the user has passed sample weights as part of the
    # input.
  # Add a quick return to prevent us from calling model._feed_targets, which
  # accesses certain model properties that may not be set in `PREDICT` mode.
  if mode == ModeKeys.PREDICT:
    return

  sample_weights = None
  # `inputs` is the model's inputs + targets + sample_weights +
  # learning phase placeholder if specified. To update the sample_weight_mode
  # we need to determine if the user has passed sample weights as part of the
  # input.
  if not callable(inputs):
    sample_weights = inputs[len(model._feed_inputs) + len(model._feed_targets):]
    has_learning_phase_pl = (mode == ModeKeys.TRAIN and
                             not isinstance(K.symbolic_learning_phase(), int))
@@ -569,6 +575,12 @@ def _update_sample_weight_mode(model, mode, inputs):
      sample_weights = sample_weights[:-1]
    model._update_sample_weight_modes(sample_weights=sample_weights)

  # Call the DistributionStrategy-specific function to update the
  # sample_weight_mode on the model.
  if model._distribution_strategy:
    distributed_training_utils._update_sample_weight_modes(model, mode,
                                                           sample_weights)

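The slicing in this function is easiest to see on a concrete flat feed list: the feed is inputs + targets + sample_weights, with a trailing learning-phase value in TRAIN mode. A hedged sketch with placeholder values (the counts are hypothetical):

```python
# The weights are recovered by counting off the first two groups and then
# dropping the trailing learning-phase value. Placeholders, not tensors.
feed = ['x0', 'x1', 'y0', 'sw0', 1]   # 2 inputs, 1 target, 1 weight, phase
num_feed_inputs, num_feed_targets = 2, 1
has_learning_phase_pl = True          # TRAIN mode with a symbolic phase

sample_weights = feed[num_feed_inputs + num_feed_targets:]
if has_learning_phase_pl:
  sample_weights = sample_weights[:-1]  # strip the learning-phase value
print(sample_weights)  # -> ['sw0']
```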
# For backwards compatibility for internal users of these loops.
fit_loop = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
test_loop = functools.partial(