Synchronize stamp token between different resources:

* During bias centering, the global_step on the per-feature accumulators is not updated, so once bias centering is done we now advance the accumulators so that their global steps are back in sync with the stamp token.
* When a parameter server is restarted, the accumulators need to be synchronized with the stamp token after the checkpoint is reloaded (see the sketch below).
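
The sketch below is a minimal, self-contained illustration (TF 1.x graph mode; the names are illustrative, not the boosted-trees code itself) of why the accumulators' global_step has to follow the stamp token: tf.ConditionalAccumulator silently drops any gradient whose local_step is older than the accumulator's global_step, and set_global_step is what re-aligns the accumulator once the token advances.

import tensorflow as tf

acc = tf.ConditionalAccumulator(
    dtype=tf.float32, shape=[2], shared_name='stats_accumulator')

# Workers tag each applied gradient with the current stamp token.
apply_old = acc.apply_grad([1.0, 1.0], local_step=0)
apply_new = acc.apply_grad([2.0, 2.0], local_step=1)

# The chief advances the accumulator to the new stamp token, e.g. once bias
# centering finishes or after a checkpoint reload on a restarted parameter
# server.
sync = acc.set_global_step(1)

with tf.Session() as sess:
  sess.run(sync)       # the accumulator now expects local_step >= 1
  sess.run(apply_old)  # stale gradient: silently dropped
  sess.run(apply_new)  # accepted
  print(sess.run(acc.num_accumulated()))  # -> 1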

PiperOrigin-RevId: 209617008
A. Unique TensorFlower 2018-08-21 10:25:05 -07:00 committed by TensorFlower Gardener
parent 212d978a2d
commit 792a933b11
2 changed files with 166 additions and 102 deletions


@@ -404,18 +404,21 @@ class _EnsembleGrower(object):
training_ops.append(grow_op)
"""
def __init__(self, tree_ensemble, tree_hparams):
def __init__(self, tree_ensemble, tree_hparams, feature_ids_list):
"""Initializes a grower object.
Args:
tree_ensemble: A TreeEnsemble variable.
tree_hparams: TODO. collections.namedtuple for hyper parameters.
feature_ids_list: a list of lists of feature ids for each bucket size.
Raises:
ValueError: when pruning mode is invalid or pruning is used and no tree
complexity is set.
"""
self._tree_ensemble = tree_ensemble
self._tree_hparams = tree_hparams
self._feature_ids_list = feature_ids_list
# pylint: disable=protected-access
self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str(
tree_hparams.pruning_mode)
@@ -440,14 +443,12 @@ class _EnsembleGrower(object):
"""
@abc.abstractmethod
def grow_tree(self, stats_summaries_list, feature_ids_list,
last_layer_nodes_range):
def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
"""Grows a tree, if ready, based on provided statistics.
Args:
stats_summaries_list: List of stats summary tensors, representing sums of
gradients and hessians for each feature bucket.
feature_ids_list: a list of lists of feature ids for each bucket size.
last_layer_nodes_range: A tensor representing ids of the nodes in the
current layer, to be split.
@@ -455,6 +456,10 @@ class _EnsembleGrower(object):
An op for growing a tree.
"""
def chief_init_op(self):
"""Ops that chief needs to run to initialize the state."""
return control_flow_ops.no_op()
# ============= Helper methods ===========
def _center_bias_fn(self, center_bias_var, mean_gradients, mean_hessians):
@@ -468,7 +473,7 @@ class _EnsembleGrower(object):
return center_bias_var.assign(continue_centering)
def _grow_tree_from_stats_summaries(self, stats_summaries_list,
feature_ids_list, last_layer_nodes_range):
last_layer_nodes_range):
"""Updates ensemble based on the best gains from stats summaries."""
node_ids_per_feature = []
gains_list = []
@@ -476,11 +481,11 @@ class _EnsembleGrower(object):
left_node_contribs_list = []
right_node_contribs_list = []
all_feature_ids = []
assert len(stats_summaries_list) == len(feature_ids_list)
assert len(stats_summaries_list) == len(self._feature_ids_list)
max_splits = _get_max_splits(self._tree_hparams)
for i, feature_ids in enumerate(feature_ids_list):
for i, feature_ids in enumerate(self._feature_ids_list):
(numeric_node_ids_per_feature, numeric_gains_list,
numeric_thresholds_list, numeric_left_node_contribs_list,
numeric_right_node_contribs_list) = (
@@ -516,12 +521,13 @@ class _EnsembleGrower(object):
class _InMemoryEnsembleGrower(_EnsembleGrower):
"""A base class for ensemble growers."""
"""An in-memory ensemble grower."""
def __init__(self, tree_ensemble, tree_hparams):
def __init__(self, tree_ensemble, tree_hparams, feature_ids_list):
super(_InMemoryEnsembleGrower, self).__init__(
tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
tree_ensemble=tree_ensemble, tree_hparams=tree_hparams,
feature_ids_list=feature_ids_list)
def center_bias(self, center_bias_var, gradients, hessians):
# For in memory, we already have a full batch of gradients and hessians,
@@ -531,83 +537,98 @@ class _InMemoryEnsembleGrower(_EnsembleGrower):
mean_hessians = array_ops.expand_dims(math_ops.reduce_mean(hessians, 0), 0)
return self._center_bias_fn(center_bias_var, mean_gradients, mean_hessians)
def grow_tree(self, stats_summaries_list, feature_ids_list,
last_layer_nodes_range):
def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
# For in memory, we already have full data in one batch, so we can grow the
# tree immediately.
return self._grow_tree_from_stats_summaries(
stats_summaries_list, feature_ids_list, last_layer_nodes_range)
stats_summaries_list, last_layer_nodes_range)
class _AccumulatorEnsembleGrower(_EnsembleGrower):
"""A base class for ensemble growers."""
"""An accumulator based ensemble grower."""
def __init__(self, tree_ensemble, tree_hparams, stamp_token,
n_batches_per_layer, bucket_size_list, is_chief):
n_batches_per_layer, bucket_size_list, is_chief, center_bias,
feature_ids_list):
super(_AccumulatorEnsembleGrower, self).__init__(
tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
tree_ensemble=tree_ensemble, tree_hparams=tree_hparams,
feature_ids_list=feature_ids_list)
self._stamp_token = stamp_token
self._n_batches_per_layer = n_batches_per_layer
self._bucket_size_list = bucket_size_list
self._is_chief = is_chief
self._growing_accumulators = []
self._chief_init_ops = []
max_splits = _get_max_splits(self._tree_hparams)
for i, feature_ids in enumerate(self._feature_ids_list):
accumulator = data_flow_ops.ConditionalAccumulator(
dtype=dtypes.float32,
# The stats consist of grads and hessians (the last dimension).
shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
shared_name='numeric_stats_summary_accumulator_' + str(i))
self._chief_init_ops.append(
accumulator.set_global_step(self._stamp_token))
self._growing_accumulators.append(accumulator)
self._center_bias = center_bias
if center_bias:
self._bias_accumulator = data_flow_ops.ConditionalAccumulator(
dtype=dtypes.float32,
# The stats consist of grads and hessians means only.
# TODO(nponomareva): this will change for a multiclass
shape=[2, 1],
shared_name='bias_accumulator')
self._chief_init_ops.append(
self._bias_accumulator.set_global_step(self._stamp_token))
def center_bias(self, center_bias_var, gradients, hessians):
# When not training in memory, we need to accumulate enough batches first
# before proceeding with bias centering.
# Create an accumulator.
if not self._center_bias:
raise RuntimeError('center_bias called but bias centering is disabled.')
bias_dependencies = []
bias_accumulator = data_flow_ops.ConditionalAccumulator(
dtype=dtypes.float32,
# The stats consist of grads and hessians means only.
# TODO(nponomareva): this will change for a multiclass
shape=[2, 1],
shared_name='bias_accumulator')
grads_and_hess = array_ops.stack([gradients, hessians], axis=0)
grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1)
apply_grad = bias_accumulator.apply_grad(grads_and_hess, self._stamp_token)
apply_grad = self._bias_accumulator.apply_grad(
grads_and_hess, self._stamp_token)
bias_dependencies.append(apply_grad)
# Center bias if enough batches were processed.
with ops.control_dependencies(bias_dependencies):
if not self._is_chief:
return control_flow_ops.no_op()
def _set_accumulators_stamp():
return control_flow_ops.group(
[acc.set_global_step(self._stamp_token + 1) for acc in
self._growing_accumulators])
def center_bias_from_accumulator():
accumulated = array_ops.unstack(bias_accumulator.take_grad(1), axis=0)
return self._center_bias_fn(center_bias_var,
array_ops.expand_dims(accumulated[0], 0),
array_ops.expand_dims(accumulated[1], 0))
accumulated = array_ops.unstack(self._bias_accumulator.take_grad(1),
axis=0)
center_bias_op = self._center_bias_fn(
center_bias_var,
array_ops.expand_dims(accumulated[0], 0),
array_ops.expand_dims(accumulated[1], 0))
with ops.control_dependencies([center_bias_op]):
return control_flow_ops.cond(center_bias_var,
control_flow_ops.no_op,
_set_accumulators_stamp)
center_bias_op = control_flow_ops.cond(
math_ops.greater_equal(bias_accumulator.num_accumulated(),
math_ops.greater_equal(self._bias_accumulator.num_accumulated(),
self._n_batches_per_layer),
center_bias_from_accumulator,
control_flow_ops.no_op,
name='wait_until_n_batches_for_bias_accumulated')
return center_bias_op
def grow_tree(self, stats_summaries_list, feature_ids_list,
last_layer_nodes_range):
# For not in memory situation, we need to accumulate enough of batches first
# before proceeding with building a tree layer.
max_splits = _get_max_splits(self._tree_hparams)
# Prepare accumulators.
accumulators = []
def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
dependencies = []
for i, feature_ids in enumerate(feature_ids_list):
for i in range(len(self._feature_ids_list)):
stats_summaries = stats_summaries_list[i]
accumulator = data_flow_ops.ConditionalAccumulator(
dtype=dtypes.float32,
# The stats consist of grads and hessians (the last dimension).
shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
shared_name='numeric_stats_summary_accumulator_' + str(i))
accumulators.append(accumulator)
apply_grad = accumulator.apply_grad(
apply_grad = self._growing_accumulators[i].apply_grad(
array_ops.stack(stats_summaries, axis=0), self._stamp_token)
dependencies.append(apply_grad)
@@ -617,7 +638,8 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
return control_flow_ops.no_op()
min_accumulated = math_ops.reduce_min(
array_ops.stack([acc.num_accumulated() for acc in accumulators]))
array_ops.stack([acc.num_accumulated() for acc in
self._growing_accumulators]))
def grow_tree_from_accumulated_summaries_fn():
"""Updates tree with the best layer from accumulated summaries."""
@@ -625,10 +647,11 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
stats_summaries_list = []
stats_summaries_list = [
array_ops.unstack(accumulator.take_grad(1), axis=0)
for accumulator in accumulators
for accumulator in self._growing_accumulators
]
grow_op = self._grow_tree_from_stats_summaries(
stats_summaries_list, feature_ids_list, last_layer_nodes_range)
stats_summaries_list, last_layer_nodes_range
)
return grow_op
grow_model = control_flow_ops.cond(
@@ -638,6 +661,10 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
name='wait_until_n_batches_accumulated')
return grow_model
def chief_init_op(self):
"""Ops that chief needs to run to initialize the state."""
return control_flow_ops.group(self._chief_init_ops)
def _bt_model_fn(
features,
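
As a side note, here is a minimal standalone sketch (TF 1.x, illustrative names, not the estimator code) of the accumulate-then-grow pattern used in grow_tree above: each batch applies its stats summary to the accumulator, and the layer is only grown once n_batches_per_layer batches have arrived.

import tensorflow as tf

n_batches_per_layer = 3

acc = tf.ConditionalAccumulator(
    dtype=tf.float32, shape=[4], shared_name='layer_stats')
# Every batch contributes one stats summary, stamped with the current token.
apply_op = acc.apply_grad(tf.ones([4]), local_step=0)

def _grow_layer():
  # take_grad(1) returns the mean of everything accumulated so far and
  # bumps the accumulator's global_step by one.
  return tf.reduce_sum(acc.take_grad(1))

def _keep_waiting():
  return tf.constant(0.0)

maybe_grow = tf.cond(
    tf.greater_equal(acc.num_accumulated(), n_batches_per_layer),
    _grow_layer, _keep_waiting, name='wait_until_n_batches_accumulated')

with tf.Session() as sess:
  for _ in range(n_batches_per_layer):
    sess.run(apply_op)
  sess.run(maybe_grow)  # enough batches accumulated, so the layer is grown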
@@ -683,21 +710,7 @@ def _bt_model_fn(
Raises:
ValueError: mode or params are invalid, or features has the wrong type.
"""
is_single_machine = (config.num_worker_replicas <= 1)
sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
center_bias = tree_hparams.center_bias
if train_in_memory:
assert n_batches_per_layer == 1, (
'When train_in_memory is enabled, input_fn should return the entire '
'dataset as a single batch, and n_batches_per_layer should be set as '
'1.')
if (not config.is_chief or config.num_worker_replicas > 1 or
config.num_ps_replicas > 0):
raise ValueError('train_in_memory is supported only for '
'non-distributed training.')
worker_device = control_flow_ops.no_op().device
train_op = []
with ops.name_scope(name) as name:
# Prepare.
global_step = training_util.get_or_create_global_step()
@@ -724,6 +737,20 @@ def _bt_model_fn(
logits=logits)
# ============== Training graph ==============
center_bias = tree_hparams.center_bias
is_single_machine = (config.num_worker_replicas <= 1)
if train_in_memory:
assert n_batches_per_layer == 1, (
'When train_in_memory is enabled, input_fn should return the entire '
'dataset as a single batch, and n_batches_per_layer should be set as '
'1.')
if (not config.is_chief or config.num_worker_replicas > 1 or
config.num_ps_replicas > 0):
raise ValueError('train_in_memory is supported only for '
'non-distributed training.')
worker_device = control_flow_ops.no_op().device
train_op = []
# Extract input features and set up cache for training.
training_state_cache = None
if train_in_memory:
@@ -742,22 +769,6 @@ def _bt_model_fn(
example_ids = features[example_id_column_name]
training_state_cache = _CacheTrainingStatesUsingHashTable(
example_ids, head.logits_dimension)
# Variable that determines whether bias centering is needed.
center_bias_var = variable_scope.variable(
initial_value=center_bias, name='center_bias_needed', trainable=False)
if is_single_machine:
local_tree_ensemble = tree_ensemble
ensemble_reload = control_flow_ops.no_op()
else:
# Have a local copy of ensemble for the distributed setting.
with ops.device(worker_device):
local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
name=name + '_local', is_local=True)
# TODO(soroush): Do partial updates if this becomes a bottleneck.
ensemble_reload = local_tree_ensemble.deserialize(
*tree_ensemble.serialize())
if training_state_cache:
cached_tree_ids, cached_node_ids, cached_logits = (
training_state_cache.lookup())
@@ -770,21 +781,46 @@ def _bt_model_fn(
array_ops.zeros(
[batch_size, head.logits_dimension], dtype=dtypes.float32))
if is_single_machine:
local_tree_ensemble = tree_ensemble
ensemble_reload = control_flow_ops.no_op()
else:
# Have a local copy of ensemble for the distributed setting.
with ops.device(worker_device):
local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
name=name + '_local', is_local=True)
# TODO(soroush): Do partial updates if this becomes a bottleneck.
ensemble_reload = local_tree_ensemble.deserialize(
*tree_ensemble.serialize())
with ops.control_dependencies([ensemble_reload]):
(stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
last_layer_nodes_range) = local_tree_ensemble.get_states()
summary.scalar('ensemble/num_trees', num_trees)
summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
tree_ensemble_handle=local_tree_ensemble.resource_handle,
cached_tree_ids=cached_tree_ids,
cached_node_ids=cached_node_ids,
bucketized_features=input_feature_list,
logits_dimension=head.logits_dimension)
logits = cached_logits + partial_logits
logits = cached_logits + partial_logits
if train_in_memory:
grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams,
feature_ids_list=feature_ids_list)
else:
grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
stamp_token, n_batches_per_layer,
bucket_size_list, config.is_chief,
center_bias=center_bias,
feature_ids_list=feature_ids_list)
summary.scalar('ensemble/num_trees', num_trees)
summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
# Variable that determines whether bias centering is needed.
center_bias_var = variable_scope.variable(
initial_value=center_bias, name='center_bias_needed', trainable=False,
use_resource=True)
# Create training graph.
def _train_op_fn(loss):
"""Run one training iteration."""
@@ -823,24 +859,20 @@ def _bt_model_fn(
axis=0) for f in feature_ids
]
stats_summaries_list.append(summaries)
if train_in_memory and is_single_machine:
grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams)
if center_bias:
update_model = control_flow_ops.cond(
center_bias_var,
functools.partial(
grower.center_bias,
center_bias_var,
gradients,
hessians,
),
functools.partial(grower.grow_tree, stats_summaries_list,
last_layer_nodes_range))
else:
grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
stamp_token, n_batches_per_layer,
bucket_size_list, config.is_chief)
update_model = control_flow_ops.cond(
center_bias_var,
functools.partial(
grower.center_bias,
center_bias_var,
gradients,
hessians,
),
functools.partial(grower.grow_tree, stats_summaries_list,
feature_ids_list, last_layer_nodes_range))
update_model = grower.grow_tree(stats_summaries_list,
last_layer_nodes_range)
train_op.append(update_model)
with ops.control_dependencies([update_model]):
@@ -859,10 +891,22 @@ def _bt_model_fn(
estimator_spec = estimator_spec._replace(
training_hooks=estimator_spec.training_hooks +
(_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
tree_hparams.n_trees, tree_hparams.max_depth),))
tree_hparams.n_trees, tree_hparams.max_depth),),
training_chief_hooks=[GrowerInitializationHook(grower.chief_init_op())] +
list(estimator_spec.training_chief_hooks))
return estimator_spec
class GrowerInitializationHook(session_run_hook.SessionRunHook):
"""A SessionRunHook handles initialization of `_EnsembleGrower`."""
def __init__(self, init_op):
self._init_op = init_op
def after_create_session(self, session, coord):
session.run(self._init_op)
def _create_classification_head(n_classes,
weight_column=None,
label_vocabulary=None):
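
For context, the following is a rough sketch (standard TF 1.x tf.train / tf.estimator APIs; the wiring is paraphrased, not the exact estimator code) of how a chief-only hook like GrowerInitializationHook runs the grower's chief_init_op once per created session, so that a restarted parameter server gets its accumulators re-stamped from the reloaded checkpoint.

import tensorflow as tf

class _InitOpHook(tf.train.SessionRunHook):
  """Runs a given initialization op right after the session is created."""

  def __init__(self, init_op):
    self._init_op = init_op

  def after_create_session(self, session, coord):
    # Runs once per (re)created session, which is exactly when accumulators
    # on a restarted parameter server need to catch up with the stamp token.
    session.run(self._init_op)

# Hypothetical wiring, mirroring the change above: only the chief runs the
# grouped set_global_step ops returned by grower.chief_init_op().
#
# estimator_spec = estimator_spec._replace(
#     training_chief_hooks=[_InitOpHook(grower.chief_init_op())] +
#     list(estimator_spec.training_chief_hooks))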


@@ -173,6 +173,26 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
eval_res = est.evaluate(input_fn=input_fn, steps=1)
self.assertAllClose(eval_res['accuracy'], 1.0)
def testTrainTwiceAndEvaluateBinaryClassifier(self):
input_fn = _make_train_input_fn(is_classification=True)
est = boosted_trees.BoostedTreesClassifier(
feature_columns=self._feature_columns,
n_batches_per_layer=1,
n_trees=5,
max_depth=10)
num_steps = 2
# Train for a few steps, and validate final checkpoint.
est.train(input_fn, steps=num_steps)
est.train(input_fn, steps=num_steps)
self._assert_checkpoint(
est.model_dir, global_step=num_steps * 2,
finalized_trees=0, attempted_layers=4)
eval_res = est.evaluate(input_fn=input_fn, steps=1)
self.assertAllClose(eval_res['accuracy'], 1.0)
def testInferBinaryClassifier(self):
train_input_fn = _make_train_input_fn(is_classification=True)
predict_input_fn = numpy_io.numpy_input_fn(