Synchronize stamp token between different resources:
* During bias centering, the global_step on the per-feature accumulators is not updated, so once bias centering is done, we now update the accumulators to have an in-sync global step.
* When a parameter server is restarted, the accumulators need to be synchronized with the stamp token after the checkpoint is reloaded.

PiperOrigin-RevId: 209617008
parent 212d978a2d
commit 792a933b11
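For illustration, a minimal standalone sketch (not the estimator code itself) of the synchronization idea in TF 1.x: a ConditionalAccumulator drops any apply_grad whose local_step is older than its global_step, so the chief re-stamps each accumulator with the current stamp token after bias centering finishes, or after a parameter server restart reloads the checkpointed ensemble stamp. The accumulator name and shape below are placeholders.

import tensorflow as tf

# Placeholder for the ensemble's current stamp token (illustrative only).
stamp_token = tf.placeholder(dtype=tf.int64, shape=[])

# An accumulator standing in for the per-feature stats accumulators.
accumulator = tf.ConditionalAccumulator(
    dtype=tf.float32, shape=[2, 1], shared_name='example_stats_accumulator')

# Gradients applied with a stale local_step are silently ignored, so the
# accumulator's global_step must be kept in sync with the stamp token.
sync_op = accumulator.set_global_step(stamp_token)

with tf.Session() as sess:
  sess.run(sync_op, feed_dict={stamp_token: 5})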
@@ -404,18 +404,21 @@ class _EnsembleGrower(object):
training_ops.append(grow_op)
"""

def __init__(self, tree_ensemble, tree_hparams):
def __init__(self, tree_ensemble, tree_hparams, feature_ids_list):
"""Initializes a grower object.

Args:
tree_ensemble: A TreeEnsemble variable.
tree_hparams: TODO. collections.namedtuple for hyper parameters.
feature_ids_list: a list of lists of feature ids for each bucket size.

Raises:
ValueError: when pruning mode is invalid or pruning is used and no tree
complexity is set.
"""
self._tree_ensemble = tree_ensemble
self._tree_hparams = tree_hparams
self._feature_ids_list = feature_ids_list
# pylint: disable=protected-access
self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str(
tree_hparams.pruning_mode)
@@ -440,14 +443,12 @@ class _EnsembleGrower(object):
"""

@abc.abstractmethod
def grow_tree(self, stats_summaries_list, feature_ids_list,
last_layer_nodes_range):
def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
"""Grows a tree, if ready, based on provided statistics.

Args:
stats_summaries_list: List of stats summary tensors, representing sums of
gradients and hessians for each feature bucket.
feature_ids_list: a list of lists of feature ids for each bucket size.
last_layer_nodes_range: A tensor representing ids of the nodes in the
current layer, to be split.
@@ -455,6 +456,10 @@ class _EnsembleGrower(object):
An op for growing a tree.
"""

def chief_init_op(self):
"""Ops that chief needs to run to initialize the state."""
return control_flow_ops.no_op()

# ============= Helper methods ===========

def _center_bias_fn(self, center_bias_var, mean_gradients, mean_hessians):
@@ -468,7 +473,7 @@ class _EnsembleGrower(object):
return center_bias_var.assign(continue_centering)

def _grow_tree_from_stats_summaries(self, stats_summaries_list,
feature_ids_list, last_layer_nodes_range):
last_layer_nodes_range):
"""Updates ensemble based on the best gains from stats summaries."""
node_ids_per_feature = []
gains_list = []
@@ -476,11 +481,11 @@ class _EnsembleGrower(object):
left_node_contribs_list = []
right_node_contribs_list = []
all_feature_ids = []
assert len(stats_summaries_list) == len(feature_ids_list)
assert len(stats_summaries_list) == len(self._feature_ids_list)

max_splits = _get_max_splits(self._tree_hparams)

for i, feature_ids in enumerate(feature_ids_list):
for i, feature_ids in enumerate(self._feature_ids_list):
(numeric_node_ids_per_feature, numeric_gains_list,
numeric_thresholds_list, numeric_left_node_contribs_list,
numeric_right_node_contribs_list) = (
@@ -516,12 +521,13 @@ class _EnsembleGrower(object):


class _InMemoryEnsembleGrower(_EnsembleGrower):
"""A base class for ensemble growers."""
"""An in-memory ensemble grower."""

def __init__(self, tree_ensemble, tree_hparams):
def __init__(self, tree_ensemble, tree_hparams, feature_ids_list):

super(_InMemoryEnsembleGrower, self).__init__(
tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
tree_ensemble=tree_ensemble, tree_hparams=tree_hparams,
feature_ids_list=feature_ids_list)

def center_bias(self, center_bias_var, gradients, hessians):
# For in memory, we already have a full batch of gradients and hessians,
@@ -531,83 +537,98 @@ class _InMemoryEnsembleGrower(_EnsembleGrower):
mean_heassians = array_ops.expand_dims(math_ops.reduce_mean(hessians, 0), 0)
return self._center_bias_fn(center_bias_var, mean_gradients, mean_heassians)

def grow_tree(self, stats_summaries_list, feature_ids_list,
last_layer_nodes_range):
def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
# For in memory, we already have full data in one batch, so we can grow the
# tree immediately.
return self._grow_tree_from_stats_summaries(
stats_summaries_list, feature_ids_list, last_layer_nodes_range)
stats_summaries_list, last_layer_nodes_range)


class _AccumulatorEnsembleGrower(_EnsembleGrower):
"""A base class for ensemble growers."""
"""An accumulator based ensemble grower."""

def __init__(self, tree_ensemble, tree_hparams, stamp_token,
n_batches_per_layer, bucket_size_list, is_chief):
n_batches_per_layer, bucket_size_list, is_chief, center_bias,
feature_ids_list):
super(_AccumulatorEnsembleGrower, self).__init__(
tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
tree_ensemble=tree_ensemble, tree_hparams=tree_hparams,
feature_ids_list=feature_ids_list)
self._stamp_token = stamp_token
self._n_batches_per_layer = n_batches_per_layer
self._bucket_size_list = bucket_size_list
self._is_chief = is_chief
self._growing_accumulators = []
self._chief_init_ops = []
max_splits = _get_max_splits(self._tree_hparams)
for i, feature_ids in enumerate(self._feature_ids_list):
accumulator = data_flow_ops.ConditionalAccumulator(
dtype=dtypes.float32,
# The stats consist of grads and hessians (the last dimension).
shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
shared_name='numeric_stats_summary_accumulator_' + str(i))
self._chief_init_ops.append(
accumulator.set_global_step(self._stamp_token))
self._growing_accumulators.append(accumulator)
self._center_bias = center_bias
if center_bias:
self._bias_accumulator = data_flow_ops.ConditionalAccumulator(
dtype=dtypes.float32,
# The stats consist of grads and hessians means only.
# TODO(nponomareva): this will change for a multiclass
shape=[2, 1],
shared_name='bias_accumulator')
self._chief_init_ops.append(
self._bias_accumulator.set_global_step(self._stamp_token))

def center_bias(self, center_bias_var, gradients, hessians):
# For not in memory situation, we need to accumulate enough of batches first
# before proceeding with centering bias.

# Create an accumulator.
if not self._center_bias:
raise RuntimeError('center_bias called but bias centering is disabled.')
bias_dependencies = []
bias_accumulator = data_flow_ops.ConditionalAccumulator(
dtype=dtypes.float32,
# The stats consist of grads and hessians means only.
# TODO(nponomareva): this will change for a multiclass
shape=[2, 1],
shared_name='bias_accumulator')

grads_and_hess = array_ops.stack([gradients, hessians], axis=0)
grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1)

apply_grad = bias_accumulator.apply_grad(grads_and_hess, self._stamp_token)
apply_grad = self._bias_accumulator.apply_grad(
grads_and_hess, self._stamp_token)
bias_dependencies.append(apply_grad)

# Center bias if enough batches were processed.
with ops.control_dependencies(bias_dependencies):
if not self._is_chief:
return control_flow_ops.no_op()
def _set_accumulators_stamp():
return control_flow_ops.group(
[acc.set_global_step(self._stamp_token + 1) for acc in
self._growing_accumulators])

def center_bias_from_accumulator():
accumulated = array_ops.unstack(bias_accumulator.take_grad(1), axis=0)
return self._center_bias_fn(center_bias_var,
accumulated = array_ops.unstack(self._bias_accumulator.take_grad(1),
axis=0)
center_bias_op = self._center_bias_fn(
center_bias_var,
array_ops.expand_dims(accumulated[0], 0),
array_ops.expand_dims(accumulated[1], 0))
with ops.control_dependencies([center_bias_op]):
return control_flow_ops.cond(center_bias_var,
control_flow_ops.no_op,
_set_accumulators_stamp)

center_bias_op = control_flow_ops.cond(
math_ops.greater_equal(bias_accumulator.num_accumulated(),
math_ops.greater_equal(self._bias_accumulator.num_accumulated(),
self._n_batches_per_layer),
center_bias_from_accumulator,
control_flow_ops.no_op,
name='wait_until_n_batches_for_bias_accumulated')
return center_bias_op

def grow_tree(self, stats_summaries_list, feature_ids_list,
last_layer_nodes_range):
# For not in memory situation, we need to accumulate enough of batches first
# before proceeding with building a tree layer.
max_splits = _get_max_splits(self._tree_hparams)

# Prepare accumulators.
accumulators = []
def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
dependencies = []
for i, feature_ids in enumerate(feature_ids_list):
for i in range(len(self._feature_ids_list)):
stats_summaries = stats_summaries_list[i]
accumulator = data_flow_ops.ConditionalAccumulator(
dtype=dtypes.float32,
# The stats consist of grads and hessians (the last dimension).
shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
shared_name='numeric_stats_summary_accumulator_' + str(i))
accumulators.append(accumulator)

apply_grad = accumulator.apply_grad(
apply_grad = self._growing_accumulators[i].apply_grad(
array_ops.stack(stats_summaries, axis=0), self._stamp_token)
dependencies.append(apply_grad)
@@ -617,7 +638,8 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
return control_flow_ops.no_op()

min_accumulated = math_ops.reduce_min(
array_ops.stack([acc.num_accumulated() for acc in accumulators]))
array_ops.stack([acc.num_accumulated() for acc in
self._growing_accumulators]))

def grow_tree_from_accumulated_summaries_fn():
"""Updates tree with the best layer from accumulated summaries."""
@@ -625,10 +647,11 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
stats_summaries_list = []
stats_summaries_list = [
array_ops.unstack(accumulator.take_grad(1), axis=0)
for accumulator in accumulators
for accumulator in self._growing_accumulators
]
grow_op = self._grow_tree_from_stats_summaries(
stats_summaries_list, feature_ids_list, last_layer_nodes_range)
stats_summaries_list, last_layer_nodes_range
)
return grow_op

grow_model = control_flow_ops.cond(
@@ -638,6 +661,10 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
name='wait_until_n_batches_accumulated')
return grow_model

def chief_init_op(self):
"""Ops that chief needs to run to initialize the state."""
return control_flow_ops.group(self._chief_init_ops)


def _bt_model_fn(
features,
@@ -683,21 +710,7 @@ def _bt_model_fn(
Raises:
ValueError: mode or params are invalid, or features has the wrong type.
"""
is_single_machine = (config.num_worker_replicas <= 1)
sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
center_bias = tree_hparams.center_bias

if train_in_memory:
assert n_batches_per_layer == 1, (
'When train_in_memory is enabled, input_fn should return the entire '
'dataset as a single batch, and n_batches_per_layer should be set as '
'1.')
if (not config.is_chief or config.num_worker_replicas > 1 or
config.num_ps_replicas > 0):
raise ValueError('train_in_memory is supported only for '
'non-distributed training.')
worker_device = control_flow_ops.no_op().device
train_op = []
with ops.name_scope(name) as name:
# Prepare.
global_step = training_util.get_or_create_global_step()
@@ -724,6 +737,20 @@ def _bt_model_fn(
logits=logits)

# ============== Training graph ==============
center_bias = tree_hparams.center_bias
is_single_machine = (config.num_worker_replicas <= 1)

if train_in_memory:
assert n_batches_per_layer == 1, (
'When train_in_memory is enabled, input_fn should return the entire '
'dataset as a single batch, and n_batches_per_layer should be set as '
'1.')
if (not config.is_chief or config.num_worker_replicas > 1 or
config.num_ps_replicas > 0):
raise ValueError('train_in_memory is supported only for '
'non-distributed training.')
worker_device = control_flow_ops.no_op().device
train_op = []
# Extract input features and set up cache for training.
training_state_cache = None
if train_in_memory:
@@ -742,22 +769,6 @@ def _bt_model_fn(
example_ids = features[example_id_column_name]
training_state_cache = _CacheTrainingStatesUsingHashTable(
example_ids, head.logits_dimension)

# Variable that determines whether bias centering is needed.
center_bias_var = variable_scope.variable(
initial_value=center_bias, name='center_bias_needed', trainable=False)
if is_single_machine:
local_tree_ensemble = tree_ensemble
ensemble_reload = control_flow_ops.no_op()
else:
# Have a local copy of ensemble for the distributed setting.
with ops.device(worker_device):
local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
name=name + '_local', is_local=True)
# TODO(soroush): Do partial updates if this becomes a bottleneck.
ensemble_reload = local_tree_ensemble.deserialize(
*tree_ensemble.serialize())

if training_state_cache:
cached_tree_ids, cached_node_ids, cached_logits = (
training_state_cache.lookup())
@@ -770,13 +781,20 @@ def _bt_model_fn(
array_ops.zeros(
[batch_size, head.logits_dimension], dtype=dtypes.float32))

if is_single_machine:
local_tree_ensemble = tree_ensemble
ensemble_reload = control_flow_ops.no_op()
else:
# Have a local copy of ensemble for the distributed setting.
with ops.device(worker_device):
local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
name=name + '_local', is_local=True)
# TODO(soroush): Do partial updates if this becomes a bottleneck.
ensemble_reload = local_tree_ensemble.deserialize(
*tree_ensemble.serialize())
with ops.control_dependencies([ensemble_reload]):
(stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
last_layer_nodes_range) = local_tree_ensemble.get_states()
summary.scalar('ensemble/num_trees', num_trees)
summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)

partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
tree_ensemble_handle=local_tree_ensemble.resource_handle,
cached_tree_ids=cached_tree_ids,
@@ -785,6 +803,24 @@ def _bt_model_fn(
logits_dimension=head.logits_dimension)
logits = cached_logits + partial_logits

if train_in_memory:
grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams,
feature_ids_list=feature_ids_list)
else:
grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
stamp_token, n_batches_per_layer,
bucket_size_list, config.is_chief,
center_bias=center_bias,
feature_ids_list=feature_ids_list)

summary.scalar('ensemble/num_trees', num_trees)
summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)

# Variable that determines whether bias centering is needed.
center_bias_var = variable_scope.variable(
initial_value=center_bias, name='center_bias_needed', trainable=False,
use_resource=True)
# Create training graph.
def _train_op_fn(loss):
"""Run one training iteration."""
@@ -823,14 +859,7 @@ def _bt_model_fn(
axis=0) for f in feature_ids
]
stats_summaries_list.append(summaries)

if train_in_memory and is_single_machine:
grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams)
else:
grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
stamp_token, n_batches_per_layer,
bucket_size_list, config.is_chief)

if center_bias:
update_model = control_flow_ops.cond(
center_bias_var,
functools.partial(
@@ -840,7 +869,10 @@ def _bt_model_fn(
hessians,
),
functools.partial(grower.grow_tree, stats_summaries_list,
feature_ids_list, last_layer_nodes_range))
last_layer_nodes_range))
else:
update_model = grower.grow_tree(stats_summaries_list,
last_layer_nodes_range)
train_op.append(update_model)

with ops.control_dependencies([update_model]):
@@ -859,10 +891,22 @@ def _bt_model_fn(
estimator_spec = estimator_spec._replace(
training_hooks=estimator_spec.training_hooks +
(_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
tree_hparams.n_trees, tree_hparams.max_depth),))
tree_hparams.n_trees, tree_hparams.max_depth),),
training_chief_hooks=[GrowerInitializationHook(grower.chief_init_op())] +
list(estimator_spec.training_chief_hooks))
return estimator_spec


class GrowerInitializationHook(session_run_hook.SessionRunHook):
"""A SessionRunHook handles initialization of `_EnsembleGrower`."""

def __init__(self, init_op):
self._init_op = init_op

def after_create_session(self, session, coord):
session.run(self._init_op)


def _create_classification_head(n_classes,
weight_column=None,
label_vocabulary=None):
@@ -173,6 +173,26 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
eval_res = est.evaluate(input_fn=input_fn, steps=1)
self.assertAllClose(eval_res['accuracy'], 1.0)

def testTrainTwiceAndEvaluateBinaryClassifier(self):
input_fn = _make_train_input_fn(is_classification=True)

est = boosted_trees.BoostedTreesClassifier(
feature_columns=self._feature_columns,
n_batches_per_layer=1,
n_trees=5,
max_depth=10)

num_steps = 2
# Train for a few steps, and validate final checkpoint.
est.train(input_fn, steps=num_steps)
est.train(input_fn, steps=num_steps)

self._assert_checkpoint(
est.model_dir, global_step=num_steps * 2,
finalized_trees=0, attempted_layers=4)
eval_res = est.evaluate(input_fn=input_fn, steps=1)
self.assertAllClose(eval_res['accuracy'], 1.0)

def testInferBinaryClassifier(self):
train_input_fn = _make_train_input_fn(is_classification=True)
predict_input_fn = numpy_io.numpy_input_fn(