Synchronize stamp token between different resources:
* During bias centering, the global_step on the per-feature accumulators is not updated, so once bias centering is done we now update the accumulators to an in-sync global step.
* When a parameter server is restarted, the accumulators need to be re-synchronized with the stamp token after the checkpoint is reloaded.

PiperOrigin-RevId: 209617008
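Both fixes rest on the same primitive: a ConditionalAccumulator only accepts gradients whose local step is at least its global step, so the chief can invalidate stale stats by re-stamping the accumulator with the ensemble's stamp token. The snippet below is a minimal, hypothetical sketch of that idea in TF 1.x; the stamp_token value, the shape, and the names are invented for illustration and this is not the estimator code itself.

import tensorflow as tf

# Hypothetical stand-ins: in the estimator these come from the tree ensemble
# resource and the boosted-trees model_fn.
stamp_token = 3
acc = tf.ConditionalAccumulator(
    dtype=tf.float32, shape=[2, 1], shared_name='bias_accumulator')

# Workers tag the stats they apply with the stamp token they observed.
apply_op = acc.apply_grad(tf.ones([2, 1]), local_step=stamp_token)

# Chief-only re-sync: after bias centering finishes, or after a parameter
# server restart reloads a checkpoint, bump the accumulator's global_step to
# the current stamp token so grads tagged with an older stamp are dropped.
sync_op = acc.set_global_step(stamp_token)

with tf.Session() as sess:
  sess.run(sync_op)
  sess.run(apply_op)
  print(sess.run(acc.num_accumulated()))  # 1 once the stamps agree

In the change below, the corresponding set_global_step calls are collected into _AccumulatorEnsembleGrower._chief_init_ops, exposed via chief_init_op(), and run by GrowerInitializationHook.after_create_session, which is what brings the accumulators back in sync after a restart.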
parent 212d978a2d
commit 792a933b11
@@ -404,18 +404,21 @@ class _EnsembleGrower(object):
     training_ops.append(grow_op)
   """
 
-  def __init__(self, tree_ensemble, tree_hparams):
+  def __init__(self, tree_ensemble, tree_hparams, feature_ids_list):
     """Initializes a grower object.
 
     Args:
       tree_ensemble: A TreeEnsemble variable.
       tree_hparams: TODO. collections.namedtuple for hyper parameters.
+      feature_ids_list: a list of lists of feature ids for each bucket size.
 
     Raises:
       ValueError: when pruning mode is invalid or pruning is used and no tree
         complexity is set.
     """
     self._tree_ensemble = tree_ensemble
     self._tree_hparams = tree_hparams
+    self._feature_ids_list = feature_ids_list
     # pylint: disable=protected-access
     self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str(
         tree_hparams.pruning_mode)
@@ -440,14 +443,12 @@ class _EnsembleGrower(object):
     """
 
   @abc.abstractmethod
-  def grow_tree(self, stats_summaries_list, feature_ids_list,
-                last_layer_nodes_range):
+  def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
     """Grows a tree, if ready, based on provided statistics.
 
     Args:
       stats_summaries_list: List of stats summary tensors, representing sums of
         gradients and hessians for each feature bucket.
-      feature_ids_list: a list of lists of feature ids for each bucket size.
       last_layer_nodes_range: A tensor representing ids of the nodes in the
         current layer, to be split.
 
@@ -455,6 +456,10 @@ class _EnsembleGrower(object):
       An op for growing a tree.
     """
 
+  def chief_init_op(self):
+    """Ops that chief needs to run to initialize the state."""
+    return control_flow_ops.no_op()
+
   # ============= Helper methods ===========
 
   def _center_bias_fn(self, center_bias_var, mean_gradients, mean_hessians):
@@ -468,7 +473,7 @@ class _EnsembleGrower(object):
     return center_bias_var.assign(continue_centering)
 
   def _grow_tree_from_stats_summaries(self, stats_summaries_list,
-                                      feature_ids_list, last_layer_nodes_range):
+                                      last_layer_nodes_range):
     """Updates ensemble based on the best gains from stats summaries."""
     node_ids_per_feature = []
     gains_list = []
@@ -476,11 +481,11 @@ class _EnsembleGrower(object):
     left_node_contribs_list = []
     right_node_contribs_list = []
     all_feature_ids = []
-    assert len(stats_summaries_list) == len(feature_ids_list)
+    assert len(stats_summaries_list) == len(self._feature_ids_list)
 
     max_splits = _get_max_splits(self._tree_hparams)
 
-    for i, feature_ids in enumerate(feature_ids_list):
+    for i, feature_ids in enumerate(self._feature_ids_list):
       (numeric_node_ids_per_feature, numeric_gains_list,
        numeric_thresholds_list, numeric_left_node_contribs_list,
        numeric_right_node_contribs_list) = (
@@ -516,12 +521,13 @@ class _EnsembleGrower(object):
 
 
 class _InMemoryEnsembleGrower(_EnsembleGrower):
-  """A base class for ensemble growers."""
+  """An in-memory ensemble grower."""
 
-  def __init__(self, tree_ensemble, tree_hparams):
+  def __init__(self, tree_ensemble, tree_hparams, feature_ids_list):
 
     super(_InMemoryEnsembleGrower, self).__init__(
-        tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+        tree_ensemble=tree_ensemble, tree_hparams=tree_hparams,
+        feature_ids_list=feature_ids_list)
 
   def center_bias(self, center_bias_var, gradients, hessians):
     # For in memory, we already have a full batch of gradients and hessians,
@@ -531,83 +537,98 @@ class _InMemoryEnsembleGrower(_EnsembleGrower):
     mean_heassians = array_ops.expand_dims(math_ops.reduce_mean(hessians, 0), 0)
     return self._center_bias_fn(center_bias_var, mean_gradients, mean_heassians)
 
-  def grow_tree(self, stats_summaries_list, feature_ids_list,
-                last_layer_nodes_range):
+  def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
     # For in memory, we already have full data in one batch, so we can grow the
     # tree immediately.
     return self._grow_tree_from_stats_summaries(
-        stats_summaries_list, feature_ids_list, last_layer_nodes_range)
+        stats_summaries_list, last_layer_nodes_range)
 
 
 class _AccumulatorEnsembleGrower(_EnsembleGrower):
-  """A base class for ensemble growers."""
+  """An accumulator based ensemble grower."""
 
   def __init__(self, tree_ensemble, tree_hparams, stamp_token,
-               n_batches_per_layer, bucket_size_list, is_chief):
+               n_batches_per_layer, bucket_size_list, is_chief, center_bias,
+               feature_ids_list):
     super(_AccumulatorEnsembleGrower, self).__init__(
-        tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+        tree_ensemble=tree_ensemble, tree_hparams=tree_hparams,
+        feature_ids_list=feature_ids_list)
     self._stamp_token = stamp_token
     self._n_batches_per_layer = n_batches_per_layer
     self._bucket_size_list = bucket_size_list
     self._is_chief = is_chief
+    self._growing_accumulators = []
+    self._chief_init_ops = []
+    max_splits = _get_max_splits(self._tree_hparams)
+    for i, feature_ids in enumerate(self._feature_ids_list):
+      accumulator = data_flow_ops.ConditionalAccumulator(
+          dtype=dtypes.float32,
+          # The stats consist of grads and hessians (the last dimension).
+          shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
+          shared_name='numeric_stats_summary_accumulator_' + str(i))
+      self._chief_init_ops.append(
+          accumulator.set_global_step(self._stamp_token))
+      self._growing_accumulators.append(accumulator)
+    self._center_bias = center_bias
+    if center_bias:
+      self._bias_accumulator = data_flow_ops.ConditionalAccumulator(
+          dtype=dtypes.float32,
+          # The stats consist of grads and hessians means only.
+          # TODO(nponomareva): this will change for a multiclass
+          shape=[2, 1],
+          shared_name='bias_accumulator')
+      self._chief_init_ops.append(
+          self._bias_accumulator.set_global_step(self._stamp_token))
 
   def center_bias(self, center_bias_var, gradients, hessians):
     # For not in memory situation, we need to accumulate enough of batches first
     # before proceeding with centering bias.
 
     # Create an accumulator.
+    if not self._center_bias:
+      raise RuntimeError('center_bias called but bias centering is disabled.')
     bias_dependencies = []
-    bias_accumulator = data_flow_ops.ConditionalAccumulator(
-        dtype=dtypes.float32,
-        # The stats consist of grads and hessians means only.
-        # TODO(nponomareva): this will change for a multiclass
-        shape=[2, 1],
-        shared_name='bias_accumulator')
-
     grads_and_hess = array_ops.stack([gradients, hessians], axis=0)
     grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1)
 
-    apply_grad = bias_accumulator.apply_grad(grads_and_hess, self._stamp_token)
+    apply_grad = self._bias_accumulator.apply_grad(
+        grads_and_hess, self._stamp_token)
     bias_dependencies.append(apply_grad)
 
     # Center bias if enough batches were processed.
     with ops.control_dependencies(bias_dependencies):
       if not self._is_chief:
         return control_flow_ops.no_op()
+      def _set_accumulators_stamp():
+        return control_flow_ops.group(
+            [acc.set_global_step(self._stamp_token + 1) for acc in
+             self._growing_accumulators])
 
       def center_bias_from_accumulator():
-        accumulated = array_ops.unstack(bias_accumulator.take_grad(1), axis=0)
-        return self._center_bias_fn(center_bias_var,
+        accumulated = array_ops.unstack(self._bias_accumulator.take_grad(1),
+                                        axis=0)
+        center_bias_op = self._center_bias_fn(
+            center_bias_var,
             array_ops.expand_dims(accumulated[0], 0),
             array_ops.expand_dims(accumulated[1], 0))
+        with ops.control_dependencies([center_bias_op]):
+          return control_flow_ops.cond(center_bias_var,
+                                       control_flow_ops.no_op,
+                                       _set_accumulators_stamp)
 
       center_bias_op = control_flow_ops.cond(
-          math_ops.greater_equal(bias_accumulator.num_accumulated(),
+          math_ops.greater_equal(self._bias_accumulator.num_accumulated(),
                                  self._n_batches_per_layer),
           center_bias_from_accumulator,
           control_flow_ops.no_op,
           name='wait_until_n_batches_for_bias_accumulated')
       return center_bias_op
 
-  def grow_tree(self, stats_summaries_list, feature_ids_list,
-                last_layer_nodes_range):
-    # For not in memory situation, we need to accumulate enough of batches first
-    # before proceeding with building a tree layer.
-    max_splits = _get_max_splits(self._tree_hparams)
-
-    # Prepare accumulators.
-    accumulators = []
+  def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
     dependencies = []
-    for i, feature_ids in enumerate(feature_ids_list):
+    for i in range(len(self._feature_ids_list)):
       stats_summaries = stats_summaries_list[i]
-      accumulator = data_flow_ops.ConditionalAccumulator(
-          dtype=dtypes.float32,
-          # The stats consist of grads and hessians (the last dimension).
-          shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
-          shared_name='numeric_stats_summary_accumulator_' + str(i))
-      accumulators.append(accumulator)
-
-      apply_grad = accumulator.apply_grad(
+      apply_grad = self._growing_accumulators[i].apply_grad(
           array_ops.stack(stats_summaries, axis=0), self._stamp_token)
       dependencies.append(apply_grad)
 
@@ -617,7 +638,8 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
       return control_flow_ops.no_op()
 
     min_accumulated = math_ops.reduce_min(
-        array_ops.stack([acc.num_accumulated() for acc in accumulators]))
+        array_ops.stack([acc.num_accumulated() for acc in
+                         self._growing_accumulators]))
 
     def grow_tree_from_accumulated_summaries_fn():
       """Updates tree with the best layer from accumulated summaries."""
@@ -625,10 +647,11 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
       stats_summaries_list = []
       stats_summaries_list = [
           array_ops.unstack(accumulator.take_grad(1), axis=0)
-          for accumulator in accumulators
+          for accumulator in self._growing_accumulators
       ]
       grow_op = self._grow_tree_from_stats_summaries(
-          stats_summaries_list, feature_ids_list, last_layer_nodes_range)
+          stats_summaries_list, last_layer_nodes_range
+      )
       return grow_op
 
     grow_model = control_flow_ops.cond(
@@ -638,6 +661,10 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
         name='wait_until_n_batches_accumulated')
     return grow_model
 
+  def chief_init_op(self):
+    """Ops that chief needs to run to initialize the state."""
+    return control_flow_ops.group(self._chief_init_ops)
+
 
 def _bt_model_fn(
     features,
@@ -683,21 +710,7 @@ def _bt_model_fn(
   Raises:
     ValueError: mode or params are invalid, or features has the wrong type.
   """
-  is_single_machine = (config.num_worker_replicas <= 1)
   sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
-  center_bias = tree_hparams.center_bias
-
-  if train_in_memory:
-    assert n_batches_per_layer == 1, (
-        'When train_in_memory is enabled, input_fn should return the entire '
-        'dataset as a single batch, and n_batches_per_layer should be set as '
-        '1.')
-    if (not config.is_chief or config.num_worker_replicas > 1 or
-        config.num_ps_replicas > 0):
-      raise ValueError('train_in_memory is supported only for '
-                       'non-distributed training.')
-  worker_device = control_flow_ops.no_op().device
-  train_op = []
   with ops.name_scope(name) as name:
     # Prepare.
     global_step = training_util.get_or_create_global_step()
@@ -724,6 +737,20 @@ def _bt_model_fn(
           logits=logits)
 
     # ============== Training graph ==============
+    center_bias = tree_hparams.center_bias
+    is_single_machine = (config.num_worker_replicas <= 1)
+
+    if train_in_memory:
+      assert n_batches_per_layer == 1, (
+          'When train_in_memory is enabled, input_fn should return the entire '
+          'dataset as a single batch, and n_batches_per_layer should be set as '
+          '1.')
+      if (not config.is_chief or config.num_worker_replicas > 1 or
+          config.num_ps_replicas > 0):
+        raise ValueError('train_in_memory is supported only for '
+                         'non-distributed training.')
+    worker_device = control_flow_ops.no_op().device
+    train_op = []
     # Extract input features and set up cache for training.
     training_state_cache = None
     if train_in_memory:
@@ -742,22 +769,6 @@ def _bt_model_fn(
         example_ids = features[example_id_column_name]
         training_state_cache = _CacheTrainingStatesUsingHashTable(
             example_ids, head.logits_dimension)
-
-    # Variable that determines whether bias centering is needed.
-    center_bias_var = variable_scope.variable(
-        initial_value=center_bias, name='center_bias_needed', trainable=False)
-    if is_single_machine:
-      local_tree_ensemble = tree_ensemble
-      ensemble_reload = control_flow_ops.no_op()
-    else:
-      # Have a local copy of ensemble for the distributed setting.
-      with ops.device(worker_device):
-        local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
-            name=name + '_local', is_local=True)
-        # TODO(soroush): Do partial updates if this becomes a bottleneck.
-        ensemble_reload = local_tree_ensemble.deserialize(
-            *tree_ensemble.serialize())
-
     if training_state_cache:
       cached_tree_ids, cached_node_ids, cached_logits = (
           training_state_cache.lookup())
@@ -770,13 +781,20 @@ def _bt_model_fn(
           array_ops.zeros(
               [batch_size, head.logits_dimension], dtype=dtypes.float32))
 
+    if is_single_machine:
+      local_tree_ensemble = tree_ensemble
+      ensemble_reload = control_flow_ops.no_op()
+    else:
+      # Have a local copy of ensemble for the distributed setting.
+      with ops.device(worker_device):
+        local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
+            name=name + '_local', is_local=True)
+        # TODO(soroush): Do partial updates if this becomes a bottleneck.
+        ensemble_reload = local_tree_ensemble.deserialize(
+            *tree_ensemble.serialize())
     with ops.control_dependencies([ensemble_reload]):
       (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
        last_layer_nodes_range) = local_tree_ensemble.get_states()
-      summary.scalar('ensemble/num_trees', num_trees)
-      summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
-      summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
-
       partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
           tree_ensemble_handle=local_tree_ensemble.resource_handle,
           cached_tree_ids=cached_tree_ids,
@@ -785,6 +803,24 @@ def _bt_model_fn(
           logits_dimension=head.logits_dimension)
       logits = cached_logits + partial_logits
 
+      if train_in_memory:
+        grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams,
+                                         feature_ids_list=feature_ids_list)
+      else:
+        grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
+                                            stamp_token, n_batches_per_layer,
+                                            bucket_size_list, config.is_chief,
+                                            center_bias=center_bias,
+                                            feature_ids_list=feature_ids_list)
+
+      summary.scalar('ensemble/num_trees', num_trees)
+      summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
+      summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
+
+      # Variable that determines whether bias centering is needed.
+      center_bias_var = variable_scope.variable(
+          initial_value=center_bias, name='center_bias_needed', trainable=False,
+          use_resource=True)
     # Create training graph.
     def _train_op_fn(loss):
       """Run one training iteration."""
@@ -823,14 +859,7 @@ def _bt_model_fn(
                 axis=0) for f in feature_ids
         ]
         stats_summaries_list.append(summaries)
-
-      if train_in_memory and is_single_machine:
-        grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams)
-      else:
-        grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
-                                            stamp_token, n_batches_per_layer,
-                                            bucket_size_list, config.is_chief)
-
+      if center_bias:
        update_model = control_flow_ops.cond(
            center_bias_var,
            functools.partial(
@@ -840,7 +869,10 @@ def _bt_model_fn(
                hessians,
            ),
            functools.partial(grower.grow_tree, stats_summaries_list,
-                             feature_ids_list, last_layer_nodes_range))
+                             last_layer_nodes_range))
+      else:
+        update_model = grower.grow_tree(stats_summaries_list,
+                                        last_layer_nodes_range)
       train_op.append(update_model)
 
       with ops.control_dependencies([update_model]):
@@ -859,10 +891,22 @@ def _bt_model_fn(
       estimator_spec = estimator_spec._replace(
           training_hooks=estimator_spec.training_hooks +
           (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
-                               tree_hparams.n_trees, tree_hparams.max_depth),))
+                               tree_hparams.n_trees, tree_hparams.max_depth),),
+          training_chief_hooks=[GrowerInitializationHook(grower.chief_init_op())] +
+          list(estimator_spec.training_chief_hooks))
     return estimator_spec
 
 
+class GrowerInitializationHook(session_run_hook.SessionRunHook):
+  """A SessionRunHook handles initialization of `_EnsembleGrower`."""
+
+  def __init__(self, init_op):
+    self._init_op = init_op
+
+  def after_create_session(self, session, coord):
+    session.run(self._init_op)
+
+
 def _create_classification_head(n_classes,
                                 weight_column=None,
                                 label_vocabulary=None):
@@ -173,6 +173,26 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
     eval_res = est.evaluate(input_fn=input_fn, steps=1)
     self.assertAllClose(eval_res['accuracy'], 1.0)
 
+  def testTrainTwiceAndEvaluateBinaryClassifier(self):
+    input_fn = _make_train_input_fn(is_classification=True)
+
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=5,
+        max_depth=10)
+
+    num_steps = 2
+    # Train for a few steps, and validate final checkpoint.
+    est.train(input_fn, steps=num_steps)
+    est.train(input_fn, steps=num_steps)
+
+    self._assert_checkpoint(
+        est.model_dir, global_step=num_steps * 2,
+        finalized_trees=0, attempted_layers=4)
+    eval_res = est.evaluate(input_fn=input_fn, steps=1)
+    self.assertAllClose(eval_res['accuracy'], 1.0)
+
   def testInferBinaryClassifier(self):
     train_input_fn = _make_train_input_fn(is_classification=True)
     predict_input_fn = numpy_io.numpy_input_fn(