diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index 35e8c92aba3..8fa0b3ada94 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -22,10 +22,12 @@ from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib - from tensorflow.contrib.tensor_forest.client import eval_metrics from tensorflow.contrib.tensor_forest.python import tensor_forest - +from tensorflow.python.estimator import estimator as core_estimator +from tensorflow.python.estimator.canned import head as core_head_lib +from tensorflow.python.estimator.export.export_output import PredictOutput +from tensorflow.python.feature_column import feature_column as fc_core from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops @@ -34,12 +36,12 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util - KEYS_NAME = 'keys' LOSS_NAME = 'rf_training_loss' TREE_PATHS_PREDICTION_KEY = 'tree_paths' @@ -48,6 +50,11 @@ ALL_SERVING_KEY = 'tensorforest_all' EPSILON = 0.000001 +class ModelBuilderOutputType(object): + MODEL_FN_OPS = 0 + ESTIMATOR_SPEC = 1 + + class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook): def __init__(self, op_dict): @@ -106,20 +113,34 @@ class TensorForestLossHook(session_run_hook.SessionRunHook): run_context.request_stop() -def get_default_head(params, weights_name, name=None): - if params.regression: - return head_lib.regression_head( - weight_column_name=weights_name, - label_dimension=params.num_outputs, - enable_centered_bias=False, - head_name=name) +def _get_default_head(params, weights_name, output_type, name=None): + """Creates a default head based on a type of a problem.""" + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + if params.regression: + return head_lib.regression_head( + weight_column_name=weights_name, + label_dimension=params.num_outputs, + enable_centered_bias=False, + head_name=name) + else: + return head_lib.multi_class_head( + params.num_classes, + weight_column_name=weights_name, + enable_centered_bias=False, + head_name=name) else: - return head_lib.multi_class_head( - params.num_classes, - weight_column_name=weights_name, - enable_centered_bias=False, - head_name=name) - + if params.regression: + return core_head_lib._regression_head( # pylint:disable=protected-access + weight_column=weights_name, + label_dimension=params.num_outputs, + name=name, + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + else: + return core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access + n_classes=params.num_classes, + weight_column=weights_name, + name=name, + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) def get_model_fn(params, graph_builder_class, @@ -135,19 +156,27 @@ def get_model_fn(params, report_feature_importances=False, local_eval=False, head_scope=None, - include_all_in_serving=False): + include_all_in_serving=False, + output_type=ModelBuilderOutputType.MODEL_FN_OPS): """Return a model function given a way to construct a graph builder.""" if model_head is None: - model_head = get_default_head(params, weights_name) + model_head = _get_default_head(params, weights_name, output_type) def _model_fn(features, labels, mode): """Function that returns predictions, training loss, and training op.""" + if (isinstance(features, ops.Tensor) or isinstance(features, sparse_tensor.SparseTensor)): features = {'features': features} if feature_columns: features = features.copy() - features.update(layers.transform_features(features, feature_columns)) + + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + features.update(layers.transform_features(features, feature_columns)) + else: + for fc in feature_columns: + tensor = fc_core._transform_features(features, [fc])[fc] # pylint: disable=protected-access + features[fc.name] = tensor weights = None if weights_name and weights_name in features: @@ -201,52 +230,95 @@ def get_model_fn(params, def _train_fn(unused_loss): return training_graph - model_ops = model_head.create_model_fn_ops( - features=features, - labels=labels, - mode=mode, - train_op_fn=_train_fn, - logits=logits, - scope=head_scope) # Ops are run in lexigraphical order of their keys. Run the resource # clean-up op last. all_handles = graph_builder.get_all_resource_handles() ops_at_end = { - '9: clean up resources': control_flow_ops.group( - *[resource_variable_ops.destroy_resource_op(handle) - for handle in all_handles])} + '9: clean up resources': + control_flow_ops.group(*[ + resource_variable_ops.destroy_resource_op(handle) + for handle in all_handles + ]) + } if report_feature_importances: ops_at_end['1: feature_importances'] = ( graph_builder.feature_importances()) - training_hooks.append(TensorForestRunOpAtEndHook(ops_at_end)) + training_hooks = [TensorForestRunOpAtEndHook(ops_at_end)] - if early_stopping_rounds: - training_hooks.append( - TensorForestLossHook( - early_stopping_rounds, - early_stopping_loss_threshold=early_stopping_loss_threshold, - loss_op=model_ops.loss)) + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + model_ops = model_head.create_model_fn_ops( + features=features, + labels=labels, + mode=mode, + train_op_fn=_train_fn, + logits=logits, + scope=head_scope) - model_ops.training_hooks.extend(training_hooks) + if early_stopping_rounds: + training_hooks.append( + TensorForestLossHook( + early_stopping_rounds, + early_stopping_loss_threshold=early_stopping_loss_threshold, + loss_op=model_ops.loss)) - if keys is not None: - model_ops.predictions[keys_name] = keys + model_ops.training_hooks.extend(training_hooks) - if params.inference_tree_paths: - model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths + if keys is not None: + model_ops.predictions[keys_name] = keys - model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance - if include_all_in_serving: - # In order to serve the variance we need to add the prediction dict - # to output_alternatives dict. - if not model_ops.output_alternatives: - model_ops.output_alternatives = {} - model_ops.output_alternatives[ALL_SERVING_KEY] = ( - constants.ProblemType.UNSPECIFIED, model_ops.predictions) - return model_ops + if params.inference_tree_paths: + model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths + + model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance + + if include_all_in_serving: + # In order to serve the variance we need to add the prediction dict + # to output_alternatives dict. + if not model_ops.output_alternatives: + model_ops.output_alternatives = {} + model_ops.output_alternatives[ALL_SERVING_KEY] = ( + constants.ProblemType.UNSPECIFIED, model_ops.predictions) + + return model_ops + + else: + # Estimator spec + estimator_spec = model_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_fn, + logits=logits) + + if early_stopping_rounds: + training_hooks.append( + TensorForestLossHook( + early_stopping_rounds, + early_stopping_loss_threshold=early_stopping_loss_threshold, + loss_op=estimator_spec.loss)) + + estimator_spec = estimator_spec._replace( + training_hooks=training_hooks + list(estimator_spec.training_hooks)) + if keys is not None: + estimator_spec.predictions[keys_name] = keys + if params.inference_tree_paths: + estimator_spec.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths + estimator_spec.predictions[VARIANCE_PREDICTION_KEY] = regression_variance + + if include_all_in_serving: + outputs = estimator_spec.export_outputs + if not outputs: + outputs = {} + outputs = {ALL_SERVING_KEY: PredictOutput(estimator_spec.predictions)} + print(estimator_spec.export_outputs) + # In order to serve the variance we need to add the prediction dict + # to output_alternatives dict. + estimator_spec = estimator_spec._replace(export_outputs=outputs) + + return estimator_spec return _model_fn @@ -493,8 +565,11 @@ class MultiForestMultiHeadEstimator(estimator.Estimator): params, graph_builder_class, device_assigner, - model_head=get_default_head( - params, weight_column, name='head{0}'.format(i)), + model_head=_get_default_head( + params, + weight_column, + name='head{0}'.format(i), + output_type=ModelBuilderOutputType.MODEL_FN_OPS), weights_name=weight_column, keys_name=keys_column, early_stopping_rounds=early_stopping_rounds, @@ -509,3 +584,142 @@ class MultiForestMultiHeadEstimator(estimator.Estimator): model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) + + +class CoreTensorForestEstimator(core_estimator.Estimator): + """A CORE estimator that can train and evaluate a random forest. + + Example: + + ```python + params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams( + num_classes=2, num_features=40, num_trees=10, max_nodes=1000) + + # Estimator using the default graph builder. + estimator = CoreTensorForestEstimator(params, model_dir=model_dir) + + # Or estimator using TrainingLossForest as the graph builder. + estimator = CoreTensorForestEstimator( + params, graph_builder_class=tensor_forest.TrainingLossForest, + model_dir=model_dir) + + # Input builders + def input_fn_train: # returns x, y + ... + def input_fn_eval: # returns x, y + ... + estimator.train(input_fn=input_fn_train) + estimator.evaluate(input_fn=input_fn_eval) + + # Predict returns an iterable of dicts. + results = list(estimator.predict(x=x)) + prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME] + prediction0 = results[0][eval_metrics.INFERENCE_PRED_NAME] + ``` + """ + + def __init__(self, + params, + device_assigner=None, + model_dir=None, + feature_columns=None, + graph_builder_class=tensor_forest.RandomForestGraphs, + config=None, + weight_column=None, + keys_column=None, + feature_engineering_fn=None, + early_stopping_rounds=100, + early_stopping_loss_threshold=0.001, + num_trainers=1, + trainer_id=0, + report_feature_importances=False, + local_eval=False, + version=None, + head=None, + include_all_in_serving=False): + """Initializes a TensorForestEstimator instance. + + Args: + params: ForestHParams object that holds random forest hyperparameters. + These parameters will be passed into `model_fn`. + device_assigner: An `object` instance that controls how trees get + assigned to devices. If `None`, will use + `tensor_forest.RandomForestDeviceAssigner`. + model_dir: Directory to save model parameters, graph, etc. To continue + training a previously saved model, load checkpoints saved to this + directory into an estimator. + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `_FeatureColumn`. + graph_builder_class: An `object` instance that defines how TF graphs for + random forest training and inference are built. By default will use + `tensor_forest.RandomForestGraphs`. Can be overridden by version + kwarg. + config: `RunConfig` object to configure the runtime settings. + weight_column: A string defining feature column name representing + weights. Will be multiplied by the loss of the example. Used to + downweight or boost examples during training. + keys_column: A string naming one of the features to strip out and + pass through into the inference/eval results dict. Useful for + associating specific examples with their prediction. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + early_stopping_rounds: Allows training to terminate early if the forest is + no longer growing. 100 by default. Set to a Falsy value to disable + the default training hook. + early_stopping_loss_threshold: Percentage (as fraction) that loss must + improve by within early_stopping_rounds steps, otherwise training will + terminate. + num_trainers: Number of training jobs, which will partition trees + among them. + trainer_id: Which trainer this instance is. + report_feature_importances: If True, print out feature importances + during evaluation. + local_eval: If True, don't use a device assigner for eval. This is to + support some common setups where eval is done on a single machine, even + though training might be distributed. + version: Unused. + head: A heads_lib.Head object that calculates losses and such. If None, + one will be automatically created based on params. + include_all_in_serving: if True, allow preparation of the complete + prediction dict including the variance to be exported for serving with + the Servo lib; and it also requires calling export_savedmodel with + default_output_alternative_key=ALL_SERVING_KEY, i.e. + estimator.export_savedmodel(export_dir_base=your_export_dir, + serving_input_fn=your_export_input_fn, + default_output_alternative_key=ALL_SERVING_KEY) + if False, resort to default behavior, i.e. export scores and + probabilities but no variances. In this case + default_output_alternative_key should be None while calling + export_savedmodel(). + Note, that due to backward compatibility we cannot always set + include_all_in_serving to True because in this case calling + export_saved_model() without + default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) the + saved_model_export_utils.get_output_alternatives() would raise + ValueError. + + Returns: + A `TensorForestEstimator` instance. + """ + + super(CoreTensorForestEstimator, self).__init__( + model_fn=get_model_fn( + params.fill(), + graph_builder_class, + device_assigner, + feature_columns=feature_columns, + model_head=head, + weights_name=weight_column, + keys_name=keys_column, + early_stopping_rounds=early_stopping_rounds, + early_stopping_loss_threshold=early_stopping_loss_threshold, + num_trainers=num_trainers, + trainer_id=trainer_id, + report_feature_importances=report_feature_importances, + local_eval=local_eval, + include_all_in_serving=include_all_in_serving, + output_type=ModelBuilderOutputType.ESTIMATOR_SPEC), + model_dir=model_dir, + config=config) diff --git a/tensorflow/contrib/tensor_forest/client/random_forest_test.py b/tensorflow/contrib/tensor_forest/client/random_forest_test.py index ac42364d257..e951592f853 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest_test.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest_test.py @@ -23,7 +23,39 @@ import numpy as np from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.tensor_forest.client import random_forest from tensorflow.contrib.tensor_forest.python import tensor_forest +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column_lib as core_feature_column +from tensorflow.python.framework import ops +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_utils + + +def _get_classification_input_fns(): + iris = base.load_iris() + data = iris.data.astype(np.float32) + labels = iris.target.astype(np.int32) + + train_input_fn = numpy_io.numpy_input_fn( + x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False) + + predict_input_fn = numpy_io.numpy_input_fn( + x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False) + return train_input_fn, predict_input_fn + + +def _get_regression_input_fns(): + boston = base.load_boston() + data = boston.data.astype(np.float32) + labels = boston.target.astype(np.int32) + + train_input_fn = numpy_io.numpy_input_fn( + x=data, y=labels, batch_size=506, num_epochs=None, shuffle=False) + + predict_input_fn = numpy_io.numpy_input_fn( + x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False) + return train_input_fn, predict_input_fn class TensorForestTrainerTests(test.TestCase): @@ -39,18 +71,22 @@ class TensorForestTrainerTests(test.TestCase): inference_tree_paths=True) classifier = random_forest.TensorForestEstimator(hparams.fill()) - iris = base.load_iris() - data = iris.data.astype(np.float32) - labels = iris.target.astype(np.int32) + input_fn, predict_input_fn = _get_classification_input_fns() + classifier.fit(input_fn=input_fn, steps=100) + res = classifier.evaluate(input_fn=input_fn, steps=10) - classifier.fit(x=data, y=labels, steps=100, batch_size=50) - classifier.evaluate(x=data, y=labels, steps=10) + self.assertEqual(1.0, res['accuracy']) + self.assertAllClose(0.55144483, res['loss']) + + predictions = list(classifier.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0.576117, 0.211942, 0.211942]], + [pred['probabilities'] for pred in predictions]) def testRegression(self): - """Tests multi-class classification using matrix data as input.""" + """Tests regression using matrix data as input.""" hparams = tensor_forest.ForestHParams( - num_trees=3, + num_trees=5, max_nodes=1000, num_classes=1, num_features=13, @@ -59,12 +95,261 @@ class TensorForestTrainerTests(test.TestCase): regressor = random_forest.TensorForestEstimator(hparams.fill()) - boston = base.load_boston() - data = boston.data.astype(np.float32) - labels = boston.target.astype(np.int32) + input_fn, predict_input_fn = _get_regression_input_fns() - regressor.fit(x=data, y=labels, steps=100, batch_size=50) - regressor.evaluate(x=data, y=labels, steps=10) + regressor.fit(input_fn=input_fn, steps=100) + res = regressor.evaluate(input_fn=input_fn, steps=10) + self.assertGreaterEqual(0.1, res['loss']) + + predictions = list(regressor.predict(input_fn=predict_input_fn)) + self.assertAllClose([24.], [pred['scores'] for pred in predictions], atol=1) + + def testAdditionalOutputs(self): + """Tests multi-class classification using matrix data as input.""" + hparams = tensor_forest.ForestHParams( + num_trees=1, + max_nodes=100, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + classifier = random_forest.TensorForestEstimator( + hparams.fill(), keys_column='keys', include_all_in_serving=True) + + iris = base.load_iris() + data = iris.data.astype(np.float32) + labels = iris.target.astype(np.int32) + + input_fn = numpy_io.numpy_input_fn( + x={ + 'x': data, + 'keys': np.arange(len(iris.data)).reshape(150, 1) + }, + y=labels, + batch_size=10, + num_epochs=1, + shuffle=False) + + classifier.fit(input_fn=input_fn, steps=100) + predictions = list(classifier.predict(input_fn=input_fn)) + # Check that there is a key column, tree paths and var. + for pred in predictions: + self.assertTrue('keys' in pred) + self.assertTrue('tree_paths' in pred) + self.assertTrue('prediction_variance' in pred) + + def _assert_checkpoint(self, model_dir, global_step): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + + def testEarlyStopping(self): + """Tests multi-class classification using matrix data as input.""" + hparams = tensor_forest.ForestHParams( + num_trees=100, + max_nodes=10000, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + classifier = random_forest.TensorForestEstimator( + hparams.fill(), + # Set a crazy threshold - 30% loss change. + early_stopping_loss_threshold=0.3, + early_stopping_rounds=2) + + input_fn, _ = _get_classification_input_fns() + classifier.fit(input_fn=input_fn, steps=100) + + # We stopped early. + self._assert_checkpoint(classifier.model_dir, global_step=5) + + +class CoreTensorForestTests(test.TestCase): + + def testTrainEvaluateInferDoesNotThrowErrorForClassifier(self): + head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + hparams = tensor_forest.ForestHParams( + num_trees=3, + max_nodes=1000, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + + est = random_forest.CoreTensorForestEstimator(hparams.fill(), head=head_fn) + + input_fn, predict_input_fn = _get_classification_input_fns() + + est.train(input_fn=input_fn, steps=100) + res = est.evaluate(input_fn=input_fn, steps=1) + + self.assertEqual(1.0, res['accuracy']) + self.assertAllClose(0.55144483, res['loss']) + + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0.576117, 0.211942, 0.211942]], + [pred['probabilities'] for pred in predictions]) + + def testRegression(self): + """Tests regression using matrix data as input.""" + head_fn = head_lib._regression_head( + label_dimension=1, + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + hparams = tensor_forest.ForestHParams( + num_trees=5, + max_nodes=1000, + num_classes=1, + num_features=13, + regression=True, + split_after_samples=20) + + regressor = random_forest.CoreTensorForestEstimator( + hparams.fill(), head=head_fn) + + input_fn, predict_input_fn = _get_regression_input_fns() + + regressor.train(input_fn=input_fn, steps=100) + res = regressor.evaluate(input_fn=input_fn, steps=10) + self.assertGreaterEqual(0.1, res['loss']) + + predictions = list(regressor.predict(input_fn=predict_input_fn)) + self.assertAllClose( + [[24.]], [pred['predictions'] for pred in predictions], atol=1) + + def testWithFeatureColumns(self): + head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + hparams = tensor_forest.ForestHParams( + num_trees=3, + max_nodes=1000, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + + est = random_forest.CoreTensorForestEstimator( + hparams.fill(), + head=head_fn, + feature_columns=[core_feature_column.numeric_column('x')]) + + iris = base.load_iris() + data = {'x': iris.data.astype(np.float32)} + labels = iris.target.astype(np.int32) + + input_fn = numpy_io.numpy_input_fn( + x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False) + + est.train(input_fn=input_fn, steps=100) + res = est.evaluate(input_fn=input_fn, steps=1) + + self.assertEqual(1.0, res['accuracy']) + self.assertAllClose(0.55144483, res['loss']) + + def testAutofillsClassificationHead(self): + hparams = tensor_forest.ForestHParams( + num_trees=3, + max_nodes=1000, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + + est = random_forest.CoreTensorForestEstimator(hparams.fill()) + + input_fn, _ = _get_classification_input_fns() + + est.train(input_fn=input_fn, steps=100) + res = est.evaluate(input_fn=input_fn, steps=1) + + self.assertEqual(1.0, res['accuracy']) + self.assertAllClose(0.55144483, res['loss']) + + def testAutofillsRegressionHead(self): + hparams = tensor_forest.ForestHParams( + num_trees=5, + max_nodes=1000, + num_classes=1, + num_features=13, + regression=True, + split_after_samples=20) + + regressor = random_forest.CoreTensorForestEstimator(hparams.fill()) + + input_fn, predict_input_fn = _get_regression_input_fns() + + regressor.train(input_fn=input_fn, steps=100) + res = regressor.evaluate(input_fn=input_fn, steps=10) + self.assertGreaterEqual(0.1, res['loss']) + + predictions = list(regressor.predict(input_fn=predict_input_fn)) + self.assertAllClose( + [[24.]], [pred['predictions'] for pred in predictions], atol=1) + + def testAdditionalOutputs(self): + """Tests multi-class classification using matrix data as input.""" + hparams = tensor_forest.ForestHParams( + num_trees=1, + max_nodes=100, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + classifier = random_forest.CoreTensorForestEstimator( + hparams.fill(), keys_column='keys', include_all_in_serving=True) + + iris = base.load_iris() + data = iris.data.astype(np.float32) + labels = iris.target.astype(np.int32) + + input_fn = numpy_io.numpy_input_fn( + x={ + 'x': data, + 'keys': np.arange(len(iris.data)).reshape(150, 1) + }, + y=labels, + batch_size=10, + num_epochs=1, + shuffle=False) + + classifier.train(input_fn=input_fn, steps=100) + predictions = list(classifier.predict(input_fn=input_fn)) + # Check that there is a key column, tree paths and var. + for pred in predictions: + self.assertTrue('keys' in pred) + self.assertTrue('tree_paths' in pred) + self.assertTrue('prediction_variance' in pred) + + def _assert_checkpoint(self, model_dir, global_step): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + + def testEarlyStopping(self): + head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + hparams = tensor_forest.ForestHParams( + num_trees=3, + max_nodes=1000, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + + est = random_forest.CoreTensorForestEstimator( + hparams.fill(), + head=head_fn, + # Set a crazy threshold - 30% loss change. + early_stopping_loss_threshold=0.3, + early_stopping_rounds=2) + + input_fn, _ = _get_classification_input_fns() + est.train(input_fn=input_fn, steps=100) + # We stopped early. + self._assert_checkpoint(est.model_dir, global_step=5) if __name__ == "__main__":