Providing a core estimator interface over a contrib tensorforest.

PiperOrigin-RevId: 208658097
This commit is contained in:
A. Unique TensorFlower 2018-08-14 09:21:03 -07:00 committed by TensorFlower Gardener
parent d7f93284c8
commit b03f732b3f
2 changed files with 563 additions and 64 deletions

View File

@ -22,10 +22,12 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.estimator.canned import head as core_head_lib
from tensorflow.python.estimator.export.export_output import PredictOutput
from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
@ -34,12 +36,12 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
KEYS_NAME = 'keys'
LOSS_NAME = 'rf_training_loss'
TREE_PATHS_PREDICTION_KEY = 'tree_paths'
@ -48,6 +50,11 @@ ALL_SERVING_KEY = 'tensorforest_all'
EPSILON = 0.000001
class ModelBuilderOutputType(object):
MODEL_FN_OPS = 0
ESTIMATOR_SPEC = 1
class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook):
def __init__(self, op_dict):
@ -106,20 +113,34 @@ class TensorForestLossHook(session_run_hook.SessionRunHook):
run_context.request_stop()
def get_default_head(params, weights_name, name=None):
if params.regression:
return head_lib.regression_head(
weight_column_name=weights_name,
label_dimension=params.num_outputs,
enable_centered_bias=False,
head_name=name)
def _get_default_head(params, weights_name, output_type, name=None):
"""Creates a default head based on a type of a problem."""
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
if params.regression:
return head_lib.regression_head(
weight_column_name=weights_name,
label_dimension=params.num_outputs,
enable_centered_bias=False,
head_name=name)
else:
return head_lib.multi_class_head(
params.num_classes,
weight_column_name=weights_name,
enable_centered_bias=False,
head_name=name)
else:
return head_lib.multi_class_head(
params.num_classes,
weight_column_name=weights_name,
enable_centered_bias=False,
head_name=name)
if params.regression:
return core_head_lib._regression_head( # pylint:disable=protected-access
weight_column=weights_name,
label_dimension=params.num_outputs,
name=name,
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
else:
return core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access
n_classes=params.num_classes,
weight_column=weights_name,
name=name,
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
def get_model_fn(params,
graph_builder_class,
@ -135,19 +156,27 @@ def get_model_fn(params,
report_feature_importances=False,
local_eval=False,
head_scope=None,
include_all_in_serving=False):
include_all_in_serving=False,
output_type=ModelBuilderOutputType.MODEL_FN_OPS):
"""Return a model function given a way to construct a graph builder."""
if model_head is None:
model_head = get_default_head(params, weights_name)
model_head = _get_default_head(params, weights_name, output_type)
def _model_fn(features, labels, mode):
"""Function that returns predictions, training loss, and training op."""
if (isinstance(features, ops.Tensor) or
isinstance(features, sparse_tensor.SparseTensor)):
features = {'features': features}
if feature_columns:
features = features.copy()
features.update(layers.transform_features(features, feature_columns))
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
features.update(layers.transform_features(features, feature_columns))
else:
for fc in feature_columns:
tensor = fc_core._transform_features(features, [fc])[fc] # pylint: disable=protected-access
features[fc.name] = tensor
weights = None
if weights_name and weights_name in features:
@ -201,52 +230,95 @@ def get_model_fn(params,
def _train_fn(unused_loss):
return training_graph
model_ops = model_head.create_model_fn_ops(
features=features,
labels=labels,
mode=mode,
train_op_fn=_train_fn,
logits=logits,
scope=head_scope)
# Ops are run in lexigraphical order of their keys. Run the resource
# clean-up op last.
all_handles = graph_builder.get_all_resource_handles()
ops_at_end = {
'9: clean up resources': control_flow_ops.group(
*[resource_variable_ops.destroy_resource_op(handle)
for handle in all_handles])}
'9: clean up resources':
control_flow_ops.group(*[
resource_variable_ops.destroy_resource_op(handle)
for handle in all_handles
])
}
if report_feature_importances:
ops_at_end['1: feature_importances'] = (
graph_builder.feature_importances())
training_hooks.append(TensorForestRunOpAtEndHook(ops_at_end))
training_hooks = [TensorForestRunOpAtEndHook(ops_at_end)]
if early_stopping_rounds:
training_hooks.append(
TensorForestLossHook(
early_stopping_rounds,
early_stopping_loss_threshold=early_stopping_loss_threshold,
loss_op=model_ops.loss))
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
model_ops = model_head.create_model_fn_ops(
features=features,
labels=labels,
mode=mode,
train_op_fn=_train_fn,
logits=logits,
scope=head_scope)
model_ops.training_hooks.extend(training_hooks)
if early_stopping_rounds:
training_hooks.append(
TensorForestLossHook(
early_stopping_rounds,
early_stopping_loss_threshold=early_stopping_loss_threshold,
loss_op=model_ops.loss))
if keys is not None:
model_ops.predictions[keys_name] = keys
model_ops.training_hooks.extend(training_hooks)
if params.inference_tree_paths:
model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
if keys is not None:
model_ops.predictions[keys_name] = keys
model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
if include_all_in_serving:
# In order to serve the variance we need to add the prediction dict
# to output_alternatives dict.
if not model_ops.output_alternatives:
model_ops.output_alternatives = {}
model_ops.output_alternatives[ALL_SERVING_KEY] = (
constants.ProblemType.UNSPECIFIED, model_ops.predictions)
return model_ops
if params.inference_tree_paths:
model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
if include_all_in_serving:
# In order to serve the variance we need to add the prediction dict
# to output_alternatives dict.
if not model_ops.output_alternatives:
model_ops.output_alternatives = {}
model_ops.output_alternatives[ALL_SERVING_KEY] = (
constants.ProblemType.UNSPECIFIED, model_ops.predictions)
return model_ops
else:
# Estimator spec
estimator_spec = model_head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
train_op_fn=_train_fn,
logits=logits)
if early_stopping_rounds:
training_hooks.append(
TensorForestLossHook(
early_stopping_rounds,
early_stopping_loss_threshold=early_stopping_loss_threshold,
loss_op=estimator_spec.loss))
estimator_spec = estimator_spec._replace(
training_hooks=training_hooks + list(estimator_spec.training_hooks))
if keys is not None:
estimator_spec.predictions[keys_name] = keys
if params.inference_tree_paths:
estimator_spec.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
estimator_spec.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
if include_all_in_serving:
outputs = estimator_spec.export_outputs
if not outputs:
outputs = {}
outputs = {ALL_SERVING_KEY: PredictOutput(estimator_spec.predictions)}
print(estimator_spec.export_outputs)
# In order to serve the variance we need to add the prediction dict
# to output_alternatives dict.
estimator_spec = estimator_spec._replace(export_outputs=outputs)
return estimator_spec
return _model_fn
@ -493,8 +565,11 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
params,
graph_builder_class,
device_assigner,
model_head=get_default_head(
params, weight_column, name='head{0}'.format(i)),
model_head=_get_default_head(
params,
weight_column,
name='head{0}'.format(i),
output_type=ModelBuilderOutputType.MODEL_FN_OPS),
weights_name=weight_column,
keys_name=keys_column,
early_stopping_rounds=early_stopping_rounds,
@ -509,3 +584,142 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
class CoreTensorForestEstimator(core_estimator.Estimator):
"""A CORE estimator that can train and evaluate a random forest.
Example:
```python
params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
num_classes=2, num_features=40, num_trees=10, max_nodes=1000)
# Estimator using the default graph builder.
estimator = CoreTensorForestEstimator(params, model_dir=model_dir)
# Or estimator using TrainingLossForest as the graph builder.
estimator = CoreTensorForestEstimator(
params, graph_builder_class=tensor_forest.TrainingLossForest,
model_dir=model_dir)
# Input builders
def input_fn_train: # returns x, y
...
def input_fn_eval: # returns x, y
...
estimator.train(input_fn=input_fn_train)
estimator.evaluate(input_fn=input_fn_eval)
# Predict returns an iterable of dicts.
results = list(estimator.predict(x=x))
prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]
prediction0 = results[0][eval_metrics.INFERENCE_PRED_NAME]
```
"""
def __init__(self,
params,
device_assigner=None,
model_dir=None,
feature_columns=None,
graph_builder_class=tensor_forest.RandomForestGraphs,
config=None,
weight_column=None,
keys_column=None,
feature_engineering_fn=None,
early_stopping_rounds=100,
early_stopping_loss_threshold=0.001,
num_trainers=1,
trainer_id=0,
report_feature_importances=False,
local_eval=False,
version=None,
head=None,
include_all_in_serving=False):
"""Initializes a TensorForestEstimator instance.
Args:
params: ForestHParams object that holds random forest hyperparameters.
These parameters will be passed into `model_fn`.
device_assigner: An `object` instance that controls how trees get
assigned to devices. If `None`, will use
`tensor_forest.RandomForestDeviceAssigner`.
model_dir: Directory to save model parameters, graph, etc. To continue
training a previously saved model, load checkpoints saved to this
directory into an estimator.
feature_columns: An iterable containing all the feature columns used by
the model. All items in the set should be instances of classes derived
from `_FeatureColumn`.
graph_builder_class: An `object` instance that defines how TF graphs for
random forest training and inference are built. By default will use
`tensor_forest.RandomForestGraphs`. Can be overridden by version
kwarg.
config: `RunConfig` object to configure the runtime settings.
weight_column: A string defining feature column name representing
weights. Will be multiplied by the loss of the example. Used to
downweight or boost examples during training.
keys_column: A string naming one of the features to strip out and
pass through into the inference/eval results dict. Useful for
associating specific examples with their prediction.
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
early_stopping_rounds: Allows training to terminate early if the forest is
no longer growing. 100 by default. Set to a Falsy value to disable
the default training hook.
early_stopping_loss_threshold: Percentage (as fraction) that loss must
improve by within early_stopping_rounds steps, otherwise training will
terminate.
num_trainers: Number of training jobs, which will partition trees
among them.
trainer_id: Which trainer this instance is.
report_feature_importances: If True, print out feature importances
during evaluation.
local_eval: If True, don't use a device assigner for eval. This is to
support some common setups where eval is done on a single machine, even
though training might be distributed.
version: Unused.
head: A heads_lib.Head object that calculates losses and such. If None,
one will be automatically created based on params.
include_all_in_serving: if True, allow preparation of the complete
prediction dict including the variance to be exported for serving with
the Servo lib; and it also requires calling export_savedmodel with
default_output_alternative_key=ALL_SERVING_KEY, i.e.
estimator.export_savedmodel(export_dir_base=your_export_dir,
serving_input_fn=your_export_input_fn,
default_output_alternative_key=ALL_SERVING_KEY)
if False, resort to default behavior, i.e. export scores and
probabilities but no variances. In this case
default_output_alternative_key should be None while calling
export_savedmodel().
Note, that due to backward compatibility we cannot always set
include_all_in_serving to True because in this case calling
export_saved_model() without
default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) the
saved_model_export_utils.get_output_alternatives() would raise
ValueError.
Returns:
A `TensorForestEstimator` instance.
"""
super(CoreTensorForestEstimator, self).__init__(
model_fn=get_model_fn(
params.fill(),
graph_builder_class,
device_assigner,
feature_columns=feature_columns,
model_head=head,
weights_name=weight_column,
keys_name=keys_column,
early_stopping_rounds=early_stopping_rounds,
early_stopping_loss_threshold=early_stopping_loss_threshold,
num_trainers=num_trainers,
trainer_id=trainer_id,
report_feature_importances=report_feature_importances,
local_eval=local_eval,
include_all_in_serving=include_all_in_serving,
output_type=ModelBuilderOutputType.ESTIMATOR_SPEC),
model_dir=model_dir,
config=config)

View File

@ -23,7 +23,39 @@ import numpy as np
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.contrib.tensor_forest.client import random_forest
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column_lib as core_feature_column
from tensorflow.python.framework import ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_utils
def _get_classification_input_fns():
iris = base.load_iris()
data = iris.data.astype(np.float32)
labels = iris.target.astype(np.int32)
train_input_fn = numpy_io.numpy_input_fn(
x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False)
predict_input_fn = numpy_io.numpy_input_fn(
x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False)
return train_input_fn, predict_input_fn
def _get_regression_input_fns():
boston = base.load_boston()
data = boston.data.astype(np.float32)
labels = boston.target.astype(np.int32)
train_input_fn = numpy_io.numpy_input_fn(
x=data, y=labels, batch_size=506, num_epochs=None, shuffle=False)
predict_input_fn = numpy_io.numpy_input_fn(
x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False)
return train_input_fn, predict_input_fn
class TensorForestTrainerTests(test.TestCase):
@ -39,18 +71,22 @@ class TensorForestTrainerTests(test.TestCase):
inference_tree_paths=True)
classifier = random_forest.TensorForestEstimator(hparams.fill())
iris = base.load_iris()
data = iris.data.astype(np.float32)
labels = iris.target.astype(np.int32)
input_fn, predict_input_fn = _get_classification_input_fns()
classifier.fit(input_fn=input_fn, steps=100)
res = classifier.evaluate(input_fn=input_fn, steps=10)
classifier.fit(x=data, y=labels, steps=100, batch_size=50)
classifier.evaluate(x=data, y=labels, steps=10)
self.assertEqual(1.0, res['accuracy'])
self.assertAllClose(0.55144483, res['loss'])
predictions = list(classifier.predict(input_fn=predict_input_fn))
self.assertAllClose([[0.576117, 0.211942, 0.211942]],
[pred['probabilities'] for pred in predictions])
def testRegression(self):
"""Tests multi-class classification using matrix data as input."""
"""Tests regression using matrix data as input."""
hparams = tensor_forest.ForestHParams(
num_trees=3,
num_trees=5,
max_nodes=1000,
num_classes=1,
num_features=13,
@ -59,12 +95,261 @@ class TensorForestTrainerTests(test.TestCase):
regressor = random_forest.TensorForestEstimator(hparams.fill())
boston = base.load_boston()
data = boston.data.astype(np.float32)
labels = boston.target.astype(np.int32)
input_fn, predict_input_fn = _get_regression_input_fns()
regressor.fit(x=data, y=labels, steps=100, batch_size=50)
regressor.evaluate(x=data, y=labels, steps=10)
regressor.fit(input_fn=input_fn, steps=100)
res = regressor.evaluate(input_fn=input_fn, steps=10)
self.assertGreaterEqual(0.1, res['loss'])
predictions = list(regressor.predict(input_fn=predict_input_fn))
self.assertAllClose([24.], [pred['scores'] for pred in predictions], atol=1)
def testAdditionalOutputs(self):
"""Tests multi-class classification using matrix data as input."""
hparams = tensor_forest.ForestHParams(
num_trees=1,
max_nodes=100,
num_classes=3,
num_features=4,
split_after_samples=20,
inference_tree_paths=True)
classifier = random_forest.TensorForestEstimator(
hparams.fill(), keys_column='keys', include_all_in_serving=True)
iris = base.load_iris()
data = iris.data.astype(np.float32)
labels = iris.target.astype(np.int32)
input_fn = numpy_io.numpy_input_fn(
x={
'x': data,
'keys': np.arange(len(iris.data)).reshape(150, 1)
},
y=labels,
batch_size=10,
num_epochs=1,
shuffle=False)
classifier.fit(input_fn=input_fn, steps=100)
predictions = list(classifier.predict(input_fn=input_fn))
# Check that there is a key column, tree paths and var.
for pred in predictions:
self.assertTrue('keys' in pred)
self.assertTrue('tree_paths' in pred)
self.assertTrue('prediction_variance' in pred)
def _assert_checkpoint(self, model_dir, global_step):
reader = checkpoint_utils.load_checkpoint(model_dir)
self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
def testEarlyStopping(self):
"""Tests multi-class classification using matrix data as input."""
hparams = tensor_forest.ForestHParams(
num_trees=100,
max_nodes=10000,
num_classes=3,
num_features=4,
split_after_samples=20,
inference_tree_paths=True)
classifier = random_forest.TensorForestEstimator(
hparams.fill(),
# Set a crazy threshold - 30% loss change.
early_stopping_loss_threshold=0.3,
early_stopping_rounds=2)
input_fn, _ = _get_classification_input_fns()
classifier.fit(input_fn=input_fn, steps=100)
# We stopped early.
self._assert_checkpoint(classifier.model_dir, global_step=5)
class CoreTensorForestTests(test.TestCase):
def testTrainEvaluateInferDoesNotThrowErrorForClassifier(self):
head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
hparams = tensor_forest.ForestHParams(
num_trees=3,
max_nodes=1000,
num_classes=3,
num_features=4,
split_after_samples=20,
inference_tree_paths=True)
est = random_forest.CoreTensorForestEstimator(hparams.fill(), head=head_fn)
input_fn, predict_input_fn = _get_classification_input_fns()
est.train(input_fn=input_fn, steps=100)
res = est.evaluate(input_fn=input_fn, steps=1)
self.assertEqual(1.0, res['accuracy'])
self.assertAllClose(0.55144483, res['loss'])
predictions = list(est.predict(input_fn=predict_input_fn))
self.assertAllClose([[0.576117, 0.211942, 0.211942]],
[pred['probabilities'] for pred in predictions])
def testRegression(self):
"""Tests regression using matrix data as input."""
head_fn = head_lib._regression_head(
label_dimension=1,
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
hparams = tensor_forest.ForestHParams(
num_trees=5,
max_nodes=1000,
num_classes=1,
num_features=13,
regression=True,
split_after_samples=20)
regressor = random_forest.CoreTensorForestEstimator(
hparams.fill(), head=head_fn)
input_fn, predict_input_fn = _get_regression_input_fns()
regressor.train(input_fn=input_fn, steps=100)
res = regressor.evaluate(input_fn=input_fn, steps=10)
self.assertGreaterEqual(0.1, res['loss'])
predictions = list(regressor.predict(input_fn=predict_input_fn))
self.assertAllClose(
[[24.]], [pred['predictions'] for pred in predictions], atol=1)
def testWithFeatureColumns(self):
head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
hparams = tensor_forest.ForestHParams(
num_trees=3,
max_nodes=1000,
num_classes=3,
num_features=4,
split_after_samples=20,
inference_tree_paths=True)
est = random_forest.CoreTensorForestEstimator(
hparams.fill(),
head=head_fn,
feature_columns=[core_feature_column.numeric_column('x')])
iris = base.load_iris()
data = {'x': iris.data.astype(np.float32)}
labels = iris.target.astype(np.int32)
input_fn = numpy_io.numpy_input_fn(
x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False)
est.train(input_fn=input_fn, steps=100)
res = est.evaluate(input_fn=input_fn, steps=1)
self.assertEqual(1.0, res['accuracy'])
self.assertAllClose(0.55144483, res['loss'])
def testAutofillsClassificationHead(self):
hparams = tensor_forest.ForestHParams(
num_trees=3,
max_nodes=1000,
num_classes=3,
num_features=4,
split_after_samples=20,
inference_tree_paths=True)
est = random_forest.CoreTensorForestEstimator(hparams.fill())
input_fn, _ = _get_classification_input_fns()
est.train(input_fn=input_fn, steps=100)
res = est.evaluate(input_fn=input_fn, steps=1)
self.assertEqual(1.0, res['accuracy'])
self.assertAllClose(0.55144483, res['loss'])
def testAutofillsRegressionHead(self):
hparams = tensor_forest.ForestHParams(
num_trees=5,
max_nodes=1000,
num_classes=1,
num_features=13,
regression=True,
split_after_samples=20)
regressor = random_forest.CoreTensorForestEstimator(hparams.fill())
input_fn, predict_input_fn = _get_regression_input_fns()
regressor.train(input_fn=input_fn, steps=100)
res = regressor.evaluate(input_fn=input_fn, steps=10)
self.assertGreaterEqual(0.1, res['loss'])
predictions = list(regressor.predict(input_fn=predict_input_fn))
self.assertAllClose(
[[24.]], [pred['predictions'] for pred in predictions], atol=1)
def testAdditionalOutputs(self):
"""Tests multi-class classification using matrix data as input."""
hparams = tensor_forest.ForestHParams(
num_trees=1,
max_nodes=100,
num_classes=3,
num_features=4,
split_after_samples=20,
inference_tree_paths=True)
classifier = random_forest.CoreTensorForestEstimator(
hparams.fill(), keys_column='keys', include_all_in_serving=True)
iris = base.load_iris()
data = iris.data.astype(np.float32)
labels = iris.target.astype(np.int32)
input_fn = numpy_io.numpy_input_fn(
x={
'x': data,
'keys': np.arange(len(iris.data)).reshape(150, 1)
},
y=labels,
batch_size=10,
num_epochs=1,
shuffle=False)
classifier.train(input_fn=input_fn, steps=100)
predictions = list(classifier.predict(input_fn=input_fn))
# Check that there is a key column, tree paths and var.
for pred in predictions:
self.assertTrue('keys' in pred)
self.assertTrue('tree_paths' in pred)
self.assertTrue('prediction_variance' in pred)
def _assert_checkpoint(self, model_dir, global_step):
reader = checkpoint_utils.load_checkpoint(model_dir)
self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
def testEarlyStopping(self):
head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
hparams = tensor_forest.ForestHParams(
num_trees=3,
max_nodes=1000,
num_classes=3,
num_features=4,
split_after_samples=20,
inference_tree_paths=True)
est = random_forest.CoreTensorForestEstimator(
hparams.fill(),
head=head_fn,
# Set a crazy threshold - 30% loss change.
early_stopping_loss_threshold=0.3,
early_stopping_rounds=2)
input_fn, _ = _get_classification_input_fns()
est.train(input_fn=input_fn, steps=100)
# We stopped early.
self._assert_checkpoint(est.model_dir, global_step=5)
if __name__ == "__main__":