Synchronize the stamp token between different resources:

* During bias centering, the global_step on the per-feature accumulators is not updated, so once bias centering is done, we now update the accumulators to have an in-sync global step (see the sketch below).
* When a parameter server is restarted, the accumulators need to be synchronized with the stamp token after the checkpoint is reloaded.
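
For context, a minimal sketch of the primitive involved (not part of this change; it assumes the TF 1.x session API and the documented ConditionalAccumulator behavior of silently dropping gradients whose local_step is older than the accumulator's global step):

    import tensorflow as tf

    # Stand-in for one per-feature stats accumulator; names are illustrative.
    acc = tf.ConditionalAccumulator(
        dtype=tf.float32, shape=[1], shared_name='demo_stats_accumulator')

    stale = acc.apply_grad([1.0], local_step=0)   # carries an old stamp
    resync = acc.set_global_step(2)               # what the chief init op does
    fresh = acc.apply_grad([1.0], local_step=2)   # stamp matches after re-sync

    with tf.Session() as sess:
      sess.run(resync)                            # now expects local_step >= 2
      sess.run(stale)                             # silently dropped: 0 < 2
      sess.run(fresh)                             # accepted
      print(sess.run(acc.num_accumulated()))      # -> 1

Advancing the accumulator's global step this way is what keeps late stats that still carry an old stamp token from being mixed into the next layer.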

PiperOrigin-RevId: 209617008
commit 792a933b11
parent 212d978a2d
Author: A. Unique TensorFlower
Date:   2018-08-21 10:25:05 -07:00
Committed by: TensorFlower Gardener

2 changed files with 166 additions and 102 deletions


@@ -404,18 +404,21 @@ class _EnsembleGrower(object):
       training_ops.append(grow_op)
   """
 
-  def __init__(self, tree_ensemble, tree_hparams):
+  def __init__(self, tree_ensemble, tree_hparams, feature_ids_list):
     """Initializes a grower object.
 
     Args:
       tree_ensemble: A TreeEnsemble variable.
       tree_hparams: TODO. collections.namedtuple for hyper parameters.
+      feature_ids_list: a list of lists of feature ids for each bucket size.
 
     Raises:
       ValueError: when pruning mode is invalid or pruning is used and no tree
         complexity is set.
     """
     self._tree_ensemble = tree_ensemble
     self._tree_hparams = tree_hparams
+    self._feature_ids_list = feature_ids_list
     # pylint: disable=protected-access
     self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str(
         tree_hparams.pruning_mode)
@@ -440,14 +443,12 @@ class _EnsembleGrower(object):
     """
 
   @abc.abstractmethod
-  def grow_tree(self, stats_summaries_list, feature_ids_list,
-                last_layer_nodes_range):
+  def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
     """Grows a tree, if ready, based on provided statistics.
 
     Args:
       stats_summaries_list: List of stats summary tensors, representing sums of
         gradients and hessians for each feature bucket.
-      feature_ids_list: a list of lists of feature ids for each bucket size.
       last_layer_nodes_range: A tensor representing ids of the nodes in the
         current layer, to be split.
@@ -455,6 +456,10 @@ class _EnsembleGrower(object):
       An op for growing a tree.
     """
 
+  def chief_init_op(self):
+    """Ops that chief needs to run to initialize the state."""
+    return control_flow_ops.no_op()
+
   # ============= Helper methods ===========
 
   def _center_bias_fn(self, center_bias_var, mean_gradients, mean_hessians):
@@ -468,7 +473,7 @@ class _EnsembleGrower(object):
     return center_bias_var.assign(continue_centering)
 
   def _grow_tree_from_stats_summaries(self, stats_summaries_list,
-                                      feature_ids_list, last_layer_nodes_range):
+                                      last_layer_nodes_range):
     """Updates ensemble based on the best gains from stats summaries."""
     node_ids_per_feature = []
     gains_list = []
@@ -476,11 +481,11 @@ class _EnsembleGrower(object):
     left_node_contribs_list = []
     right_node_contribs_list = []
     all_feature_ids = []
-    assert len(stats_summaries_list) == len(feature_ids_list)
+    assert len(stats_summaries_list) == len(self._feature_ids_list)
 
     max_splits = _get_max_splits(self._tree_hparams)
 
-    for i, feature_ids in enumerate(feature_ids_list):
+    for i, feature_ids in enumerate(self._feature_ids_list):
       (numeric_node_ids_per_feature, numeric_gains_list,
        numeric_thresholds_list, numeric_left_node_contribs_list,
        numeric_right_node_contribs_list) = (
@@ -516,12 +521,13 @@ class _EnsembleGrower(object):
 
 class _InMemoryEnsembleGrower(_EnsembleGrower):
-  """A base class for ensemble growers."""
+  """An in-memory ensemble grower."""
 
-  def __init__(self, tree_ensemble, tree_hparams):
+  def __init__(self, tree_ensemble, tree_hparams, feature_ids_list):
     super(_InMemoryEnsembleGrower, self).__init__(
-        tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+        tree_ensemble=tree_ensemble, tree_hparams=tree_hparams,
+        feature_ids_list=feature_ids_list)
 
   def center_bias(self, center_bias_var, gradients, hessians):
     # For in memory, we already have a full batch of gradients and hessians,
@@ -531,83 +537,98 @@ class _InMemoryEnsembleGrower(_EnsembleGrower):
     mean_heassians = array_ops.expand_dims(math_ops.reduce_mean(hessians, 0), 0)
     return self._center_bias_fn(center_bias_var, mean_gradients, mean_heassians)
 
-  def grow_tree(self, stats_summaries_list, feature_ids_list,
-                last_layer_nodes_range):
+  def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
     # For in memory, we already have full data in one batch, so we can grow the
     # tree immediately.
     return self._grow_tree_from_stats_summaries(
-        stats_summaries_list, feature_ids_list, last_layer_nodes_range)
+        stats_summaries_list, last_layer_nodes_range)
 
 
 class _AccumulatorEnsembleGrower(_EnsembleGrower):
-  """A base class for ensemble growers."""
+  """An accumulator based ensemble grower."""
 
   def __init__(self, tree_ensemble, tree_hparams, stamp_token,
-               n_batches_per_layer, bucket_size_list, is_chief):
+               n_batches_per_layer, bucket_size_list, is_chief, center_bias,
+               feature_ids_list):
     super(_AccumulatorEnsembleGrower, self).__init__(
-        tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+        tree_ensemble=tree_ensemble, tree_hparams=tree_hparams,
+        feature_ids_list=feature_ids_list)
     self._stamp_token = stamp_token
     self._n_batches_per_layer = n_batches_per_layer
     self._bucket_size_list = bucket_size_list
     self._is_chief = is_chief
+    self._growing_accumulators = []
+    self._chief_init_ops = []
+    max_splits = _get_max_splits(self._tree_hparams)
+    for i, feature_ids in enumerate(self._feature_ids_list):
+      accumulator = data_flow_ops.ConditionalAccumulator(
+          dtype=dtypes.float32,
+          # The stats consist of grads and hessians (the last dimension).
+          shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
+          shared_name='numeric_stats_summary_accumulator_' + str(i))
+      self._chief_init_ops.append(
+          accumulator.set_global_step(self._stamp_token))
+      self._growing_accumulators.append(accumulator)
+    self._center_bias = center_bias
+    if center_bias:
+      self._bias_accumulator = data_flow_ops.ConditionalAccumulator(
+          dtype=dtypes.float32,
+          # The stats consist of grads and hessians means only.
+          # TODO(nponomareva): this will change for a multiclass
+          shape=[2, 1],
+          shared_name='bias_accumulator')
+      self._chief_init_ops.append(
+          self._bias_accumulator.set_global_step(self._stamp_token))
 
   def center_bias(self, center_bias_var, gradients, hessians):
     # For not in memory situation, we need to accumulate enough of batches first
     # before proceeding with centering bias.
 
     # Create an accumulator.
+    if not self._center_bias:
+      raise RuntimeError('center_bias called but bias centering is disabled.')
     bias_dependencies = []
-    bias_accumulator = data_flow_ops.ConditionalAccumulator(
-        dtype=dtypes.float32,
-        # The stats consist of grads and hessians means only.
-        # TODO(nponomareva): this will change for a multiclass
-        shape=[2, 1],
-        shared_name='bias_accumulator')
     grads_and_hess = array_ops.stack([gradients, hessians], axis=0)
     grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1)
-    apply_grad = bias_accumulator.apply_grad(grads_and_hess, self._stamp_token)
+    apply_grad = self._bias_accumulator.apply_grad(
+        grads_and_hess, self._stamp_token)
     bias_dependencies.append(apply_grad)
 
     # Center bias if enough batches were processed.
     with ops.control_dependencies(bias_dependencies):
       if not self._is_chief:
         return control_flow_ops.no_op()
 
+      def _set_accumulators_stamp():
+        return control_flow_ops.group(
+            [acc.set_global_step(self._stamp_token + 1) for acc in
+             self._growing_accumulators])
+
       def center_bias_from_accumulator():
-        accumulated = array_ops.unstack(bias_accumulator.take_grad(1), axis=0)
-        return self._center_bias_fn(center_bias_var,
-                                    array_ops.expand_dims(accumulated[0], 0),
-                                    array_ops.expand_dims(accumulated[1], 0))
+        accumulated = array_ops.unstack(self._bias_accumulator.take_grad(1),
+                                        axis=0)
+        center_bias_op = self._center_bias_fn(
+            center_bias_var,
+            array_ops.expand_dims(accumulated[0], 0),
+            array_ops.expand_dims(accumulated[1], 0))
+        with ops.control_dependencies([center_bias_op]):
+          return control_flow_ops.cond(center_bias_var,
+                                       control_flow_ops.no_op,
+                                       _set_accumulators_stamp)
 
       center_bias_op = control_flow_ops.cond(
-          math_ops.greater_equal(bias_accumulator.num_accumulated(),
+          math_ops.greater_equal(self._bias_accumulator.num_accumulated(),
                                  self._n_batches_per_layer),
           center_bias_from_accumulator,
           control_flow_ops.no_op,
           name='wait_until_n_batches_for_bias_accumulated')
       return center_bias_op
 
-  def grow_tree(self, stats_summaries_list, feature_ids_list,
-                last_layer_nodes_range):
-    # For not in memory situation, we need to accumulate enough of batches first
-    # before proceeding with building a tree layer.
-    max_splits = _get_max_splits(self._tree_hparams)
-
-    # Prepare accumulators.
-    accumulators = []
+  def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
     dependencies = []
-    for i, feature_ids in enumerate(feature_ids_list):
+    for i in range(len(self._feature_ids_list)):
       stats_summaries = stats_summaries_list[i]
-      accumulator = data_flow_ops.ConditionalAccumulator(
-          dtype=dtypes.float32,
-          # The stats consist of grads and hessians (the last dimension).
-          shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
-          shared_name='numeric_stats_summary_accumulator_' + str(i))
-      accumulators.append(accumulator)
-      apply_grad = accumulator.apply_grad(
+      apply_grad = self._growing_accumulators[i].apply_grad(
           array_ops.stack(stats_summaries, axis=0), self._stamp_token)
       dependencies.append(apply_grad)
@@ -617,7 +638,8 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
         return control_flow_ops.no_op()
 
       min_accumulated = math_ops.reduce_min(
-          array_ops.stack([acc.num_accumulated() for acc in accumulators]))
+          array_ops.stack([acc.num_accumulated() for acc in
+                           self._growing_accumulators]))
 
       def grow_tree_from_accumulated_summaries_fn():
         """Updates tree with the best layer from accumulated summaries."""
@@ -625,10 +647,11 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
         stats_summaries_list = []
         stats_summaries_list = [
             array_ops.unstack(accumulator.take_grad(1), axis=0)
-            for accumulator in accumulators
+            for accumulator in self._growing_accumulators
         ]
         grow_op = self._grow_tree_from_stats_summaries(
-            stats_summaries_list, feature_ids_list, last_layer_nodes_range)
+            stats_summaries_list, last_layer_nodes_range
+        )
         return grow_op
 
       grow_model = control_flow_ops.cond(
@@ -638,6 +661,10 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
           name='wait_until_n_batches_accumulated')
       return grow_model
 
+  def chief_init_op(self):
+    """Ops that chief needs to run to initialize the state."""
+    return control_flow_ops.group(self._chief_init_ops)
+
 
 def _bt_model_fn(
     features,
@@ -683,21 +710,7 @@ def _bt_model_fn(
   Raises:
     ValueError: mode or params are invalid, or features has the wrong type.
   """
-  is_single_machine = (config.num_worker_replicas <= 1)
   sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
-  center_bias = tree_hparams.center_bias
-  if train_in_memory:
-    assert n_batches_per_layer == 1, (
-        'When train_in_memory is enabled, input_fn should return the entire '
-        'dataset as a single batch, and n_batches_per_layer should be set as '
-        '1.')
-    if (not config.is_chief or config.num_worker_replicas > 1 or
-        config.num_ps_replicas > 0):
-      raise ValueError('train_in_memory is supported only for '
-                       'non-distributed training.')
-  worker_device = control_flow_ops.no_op().device
-  train_op = []
-
   with ops.name_scope(name) as name:
     # Prepare.
     global_step = training_util.get_or_create_global_step()
@@ -724,6 +737,20 @@ def _bt_model_fn(
         logits=logits)
 
     # ============== Training graph ==============
+    center_bias = tree_hparams.center_bias
+    is_single_machine = (config.num_worker_replicas <= 1)
+
+    if train_in_memory:
+      assert n_batches_per_layer == 1, (
+          'When train_in_memory is enabled, input_fn should return the entire '
+          'dataset as a single batch, and n_batches_per_layer should be set as '
+          '1.')
+      if (not config.is_chief or config.num_worker_replicas > 1 or
+          config.num_ps_replicas > 0):
+        raise ValueError('train_in_memory is supported only for '
+                         'non-distributed training.')
+    worker_device = control_flow_ops.no_op().device
+    train_op = []
+
     # Extract input features and set up cache for training.
     training_state_cache = None
     if train_in_memory:
@@ -742,22 +769,6 @@ def _bt_model_fn(
       example_ids = features[example_id_column_name]
       training_state_cache = _CacheTrainingStatesUsingHashTable(
           example_ids, head.logits_dimension)
 
-    # Variable that determines whether bias centering is needed.
-    center_bias_var = variable_scope.variable(
-        initial_value=center_bias, name='center_bias_needed', trainable=False)
-    if is_single_machine:
-      local_tree_ensemble = tree_ensemble
-      ensemble_reload = control_flow_ops.no_op()
-    else:
-      # Have a local copy of ensemble for the distributed setting.
-      with ops.device(worker_device):
-        local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
-            name=name + '_local', is_local=True)
-      # TODO(soroush): Do partial updates if this becomes a bottleneck.
-      ensemble_reload = local_tree_ensemble.deserialize(
-          *tree_ensemble.serialize())
-
     if training_state_cache:
       cached_tree_ids, cached_node_ids, cached_logits = (
           training_state_cache.lookup())
@@ -770,13 +781,20 @@ def _bt_model_fn(
           array_ops.zeros(
               [batch_size, head.logits_dimension], dtype=dtypes.float32))
 
+    if is_single_machine:
+      local_tree_ensemble = tree_ensemble
+      ensemble_reload = control_flow_ops.no_op()
+    else:
+      # Have a local copy of ensemble for the distributed setting.
+      with ops.device(worker_device):
+        local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
+            name=name + '_local', is_local=True)
+      # TODO(soroush): Do partial updates if this becomes a bottleneck.
+      ensemble_reload = local_tree_ensemble.deserialize(
+          *tree_ensemble.serialize())
+
     with ops.control_dependencies([ensemble_reload]):
       (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
        last_layer_nodes_range) = local_tree_ensemble.get_states()
-      summary.scalar('ensemble/num_trees', num_trees)
-      summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
-      summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
 
     partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
         tree_ensemble_handle=local_tree_ensemble.resource_handle,
         cached_tree_ids=cached_tree_ids,
@@ -785,6 +803,24 @@ def _bt_model_fn(
         logits_dimension=head.logits_dimension)
     logits = cached_logits + partial_logits
 
+    if train_in_memory:
+      grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams,
+                                       feature_ids_list=feature_ids_list)
+    else:
+      grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
+                                          stamp_token, n_batches_per_layer,
+                                          bucket_size_list, config.is_chief,
+                                          center_bias=center_bias,
+                                          feature_ids_list=feature_ids_list)
+
+    summary.scalar('ensemble/num_trees', num_trees)
+    summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
+    summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
+
+    # Variable that determines whether bias centering is needed.
+    center_bias_var = variable_scope.variable(
+        initial_value=center_bias, name='center_bias_needed', trainable=False,
+        use_resource=True)
+
     # Create training graph.
     def _train_op_fn(loss):
       """Run one training iteration."""
@@ -823,14 +859,7 @@ def _bt_model_fn(
                 axis=0) for f in feature_ids
         ]
         stats_summaries_list.append(summaries)
 
-      if train_in_memory and is_single_machine:
-        grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams)
-      else:
-        grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
-                                            stamp_token, n_batches_per_layer,
-                                            bucket_size_list, config.is_chief)
-
+      if center_bias:
         update_model = control_flow_ops.cond(
             center_bias_var,
             functools.partial(
@@ -840,7 +869,10 @@ def _bt_model_fn(
                 hessians,
             ),
             functools.partial(grower.grow_tree, stats_summaries_list,
-                              feature_ids_list, last_layer_nodes_range))
+                              last_layer_nodes_range))
+      else:
+        update_model = grower.grow_tree(stats_summaries_list,
+                                        last_layer_nodes_range)
 
       train_op.append(update_model)
       with ops.control_dependencies([update_model]):
@ -859,10 +891,22 @@ def _bt_model_fn(
estimator_spec = estimator_spec._replace( estimator_spec = estimator_spec._replace(
training_hooks=estimator_spec.training_hooks + training_hooks=estimator_spec.training_hooks +
(_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers, (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
tree_hparams.n_trees, tree_hparams.max_depth),)) tree_hparams.n_trees, tree_hparams.max_depth),),
training_chief_hooks=[GrowerInitializationHook(grower.chief_init_op())] +
list(estimator_spec.training_chief_hooks))
return estimator_spec return estimator_spec
class GrowerInitializationHook(session_run_hook.SessionRunHook):
"""A SessionRunHook handles initialization of `_EnsembleGrower`."""
def __init__(self, init_op):
self._init_op = init_op
def after_create_session(self, session, coord):
session.run(self._init_op)
def _create_classification_head(n_classes, def _create_classification_head(n_classes,
weight_column=None, weight_column=None,
label_vocabulary=None): label_vocabulary=None):

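For context, the GrowerInitializationHook added above is the standard SessionRunHook pattern: after_create_session runs once whenever a session is (re)created, including after a parameter server or chief restart, which is what re-stamps the accumulators after a checkpoint reload. A standalone sketch (hypothetical names; assumes the TF 1.x tf.train API):

    import tensorflow as tf

    class RunAfterSessionCreation(tf.train.SessionRunHook):
      """Runs `init_op` each time a new session is created."""

      def __init__(self, init_op):
        self._init_op = init_op

      def after_create_session(self, session, coord):
        del coord  # unused
        session.run(self._init_op)

    stamp = tf.Variable(0, name='stamp_token')
    hook = RunAfterSessionCreation(stamp.assign(7))
    with tf.train.SingularMonitoredSession(hooks=[hook]) as sess:
      print(sess.run(stamp))  # -> 7: the hook ran before the first step

The estimator wires this in through training_chief_hooks, so only the chief re-stamps the accumulators; workers keep applying stats with the stamp token they read from the ensemble.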

@@ -173,6 +173,26 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
     eval_res = est.evaluate(input_fn=input_fn, steps=1)
     self.assertAllClose(eval_res['accuracy'], 1.0)
 
+  def testTrainTwiceAndEvaluateBinaryClassifier(self):
+    input_fn = _make_train_input_fn(is_classification=True)
+
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=5,
+        max_depth=10)
+
+    num_steps = 2
+    # Train for a few steps, and validate final checkpoint.
+    est.train(input_fn, steps=num_steps)
+    est.train(input_fn, steps=num_steps)
+
+    self._assert_checkpoint(
+        est.model_dir, global_step=num_steps * 2,
+        finalized_trees=0, attempted_layers=4)
+
+    eval_res = est.evaluate(input_fn=input_fn, steps=1)
+    self.assertAllClose(eval_res['accuracy'], 1.0)
+
   def testInferBinaryClassifier(self):
     train_input_fn = _make_train_input_fn(is_classification=True)
     predict_input_fn = numpy_io.numpy_input_fn(