Providing a core estimator interface over a contrib tensorforest.
PiperOrigin-RevId: 208658097
This commit is contained in:
parent
d7f93284c8
commit
b03f732b3f
@ -22,10 +22,12 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
|
||||
|
||||
from tensorflow.contrib.tensor_forest.client import eval_metrics
|
||||
from tensorflow.contrib.tensor_forest.python import tensor_forest
|
||||
|
||||
from tensorflow.python.estimator import estimator as core_estimator
|
||||
from tensorflow.python.estimator.canned import head as core_head_lib
|
||||
from tensorflow.python.estimator.export.export_output import PredictOutput
|
||||
from tensorflow.python.feature_column import feature_column as fc_core
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -34,12 +36,12 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops.losses import losses
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.summary import summary
|
||||
from tensorflow.python.training import session_run_hook
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
|
||||
KEYS_NAME = 'keys'
|
||||
LOSS_NAME = 'rf_training_loss'
|
||||
TREE_PATHS_PREDICTION_KEY = 'tree_paths'
|
||||
@ -48,6 +50,11 @@ ALL_SERVING_KEY = 'tensorforest_all'
|
||||
EPSILON = 0.000001
|
||||
|
||||
|
||||
class ModelBuilderOutputType(object):
|
||||
MODEL_FN_OPS = 0
|
||||
ESTIMATOR_SPEC = 1
|
||||
|
||||
|
||||
class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook):
|
||||
|
||||
def __init__(self, op_dict):
|
||||
@ -106,20 +113,34 @@ class TensorForestLossHook(session_run_hook.SessionRunHook):
|
||||
run_context.request_stop()
|
||||
|
||||
|
||||
def get_default_head(params, weights_name, name=None):
|
||||
if params.regression:
|
||||
return head_lib.regression_head(
|
||||
weight_column_name=weights_name,
|
||||
label_dimension=params.num_outputs,
|
||||
enable_centered_bias=False,
|
||||
head_name=name)
|
||||
def _get_default_head(params, weights_name, output_type, name=None):
|
||||
"""Creates a default head based on a type of a problem."""
|
||||
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
|
||||
if params.regression:
|
||||
return head_lib.regression_head(
|
||||
weight_column_name=weights_name,
|
||||
label_dimension=params.num_outputs,
|
||||
enable_centered_bias=False,
|
||||
head_name=name)
|
||||
else:
|
||||
return head_lib.multi_class_head(
|
||||
params.num_classes,
|
||||
weight_column_name=weights_name,
|
||||
enable_centered_bias=False,
|
||||
head_name=name)
|
||||
else:
|
||||
return head_lib.multi_class_head(
|
||||
params.num_classes,
|
||||
weight_column_name=weights_name,
|
||||
enable_centered_bias=False,
|
||||
head_name=name)
|
||||
|
||||
if params.regression:
|
||||
return core_head_lib._regression_head( # pylint:disable=protected-access
|
||||
weight_column=weights_name,
|
||||
label_dimension=params.num_outputs,
|
||||
name=name,
|
||||
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
|
||||
else:
|
||||
return core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access
|
||||
n_classes=params.num_classes,
|
||||
weight_column=weights_name,
|
||||
name=name,
|
||||
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
|
||||
|
||||
def get_model_fn(params,
|
||||
graph_builder_class,
|
||||
@ -135,19 +156,27 @@ def get_model_fn(params,
|
||||
report_feature_importances=False,
|
||||
local_eval=False,
|
||||
head_scope=None,
|
||||
include_all_in_serving=False):
|
||||
include_all_in_serving=False,
|
||||
output_type=ModelBuilderOutputType.MODEL_FN_OPS):
|
||||
"""Return a model function given a way to construct a graph builder."""
|
||||
if model_head is None:
|
||||
model_head = get_default_head(params, weights_name)
|
||||
model_head = _get_default_head(params, weights_name, output_type)
|
||||
|
||||
def _model_fn(features, labels, mode):
|
||||
"""Function that returns predictions, training loss, and training op."""
|
||||
|
||||
if (isinstance(features, ops.Tensor) or
|
||||
isinstance(features, sparse_tensor.SparseTensor)):
|
||||
features = {'features': features}
|
||||
if feature_columns:
|
||||
features = features.copy()
|
||||
features.update(layers.transform_features(features, feature_columns))
|
||||
|
||||
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
|
||||
features.update(layers.transform_features(features, feature_columns))
|
||||
else:
|
||||
for fc in feature_columns:
|
||||
tensor = fc_core._transform_features(features, [fc])[fc] # pylint: disable=protected-access
|
||||
features[fc.name] = tensor
|
||||
|
||||
weights = None
|
||||
if weights_name and weights_name in features:
|
||||
@ -201,52 +230,95 @@ def get_model_fn(params,
|
||||
def _train_fn(unused_loss):
|
||||
return training_graph
|
||||
|
||||
model_ops = model_head.create_model_fn_ops(
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
train_op_fn=_train_fn,
|
||||
logits=logits,
|
||||
scope=head_scope)
|
||||
|
||||
# Ops are run in lexigraphical order of their keys. Run the resource
|
||||
# clean-up op last.
|
||||
all_handles = graph_builder.get_all_resource_handles()
|
||||
ops_at_end = {
|
||||
'9: clean up resources': control_flow_ops.group(
|
||||
*[resource_variable_ops.destroy_resource_op(handle)
|
||||
for handle in all_handles])}
|
||||
'9: clean up resources':
|
||||
control_flow_ops.group(*[
|
||||
resource_variable_ops.destroy_resource_op(handle)
|
||||
for handle in all_handles
|
||||
])
|
||||
}
|
||||
|
||||
if report_feature_importances:
|
||||
ops_at_end['1: feature_importances'] = (
|
||||
graph_builder.feature_importances())
|
||||
|
||||
training_hooks.append(TensorForestRunOpAtEndHook(ops_at_end))
|
||||
training_hooks = [TensorForestRunOpAtEndHook(ops_at_end)]
|
||||
|
||||
if early_stopping_rounds:
|
||||
training_hooks.append(
|
||||
TensorForestLossHook(
|
||||
early_stopping_rounds,
|
||||
early_stopping_loss_threshold=early_stopping_loss_threshold,
|
||||
loss_op=model_ops.loss))
|
||||
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
|
||||
model_ops = model_head.create_model_fn_ops(
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
train_op_fn=_train_fn,
|
||||
logits=logits,
|
||||
scope=head_scope)
|
||||
|
||||
model_ops.training_hooks.extend(training_hooks)
|
||||
if early_stopping_rounds:
|
||||
training_hooks.append(
|
||||
TensorForestLossHook(
|
||||
early_stopping_rounds,
|
||||
early_stopping_loss_threshold=early_stopping_loss_threshold,
|
||||
loss_op=model_ops.loss))
|
||||
|
||||
if keys is not None:
|
||||
model_ops.predictions[keys_name] = keys
|
||||
model_ops.training_hooks.extend(training_hooks)
|
||||
|
||||
if params.inference_tree_paths:
|
||||
model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
|
||||
if keys is not None:
|
||||
model_ops.predictions[keys_name] = keys
|
||||
|
||||
model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
|
||||
if include_all_in_serving:
|
||||
# In order to serve the variance we need to add the prediction dict
|
||||
# to output_alternatives dict.
|
||||
if not model_ops.output_alternatives:
|
||||
model_ops.output_alternatives = {}
|
||||
model_ops.output_alternatives[ALL_SERVING_KEY] = (
|
||||
constants.ProblemType.UNSPECIFIED, model_ops.predictions)
|
||||
return model_ops
|
||||
if params.inference_tree_paths:
|
||||
model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
|
||||
|
||||
model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
|
||||
|
||||
if include_all_in_serving:
|
||||
# In order to serve the variance we need to add the prediction dict
|
||||
# to output_alternatives dict.
|
||||
if not model_ops.output_alternatives:
|
||||
model_ops.output_alternatives = {}
|
||||
model_ops.output_alternatives[ALL_SERVING_KEY] = (
|
||||
constants.ProblemType.UNSPECIFIED, model_ops.predictions)
|
||||
|
||||
return model_ops
|
||||
|
||||
else:
|
||||
# Estimator spec
|
||||
estimator_spec = model_head.create_estimator_spec(
|
||||
features=features,
|
||||
mode=mode,
|
||||
labels=labels,
|
||||
train_op_fn=_train_fn,
|
||||
logits=logits)
|
||||
|
||||
if early_stopping_rounds:
|
||||
training_hooks.append(
|
||||
TensorForestLossHook(
|
||||
early_stopping_rounds,
|
||||
early_stopping_loss_threshold=early_stopping_loss_threshold,
|
||||
loss_op=estimator_spec.loss))
|
||||
|
||||
estimator_spec = estimator_spec._replace(
|
||||
training_hooks=training_hooks + list(estimator_spec.training_hooks))
|
||||
if keys is not None:
|
||||
estimator_spec.predictions[keys_name] = keys
|
||||
if params.inference_tree_paths:
|
||||
estimator_spec.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
|
||||
estimator_spec.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
|
||||
|
||||
if include_all_in_serving:
|
||||
outputs = estimator_spec.export_outputs
|
||||
if not outputs:
|
||||
outputs = {}
|
||||
outputs = {ALL_SERVING_KEY: PredictOutput(estimator_spec.predictions)}
|
||||
print(estimator_spec.export_outputs)
|
||||
# In order to serve the variance we need to add the prediction dict
|
||||
# to output_alternatives dict.
|
||||
estimator_spec = estimator_spec._replace(export_outputs=outputs)
|
||||
|
||||
return estimator_spec
|
||||
|
||||
return _model_fn
|
||||
|
||||
@ -493,8 +565,11 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
|
||||
params,
|
||||
graph_builder_class,
|
||||
device_assigner,
|
||||
model_head=get_default_head(
|
||||
params, weight_column, name='head{0}'.format(i)),
|
||||
model_head=_get_default_head(
|
||||
params,
|
||||
weight_column,
|
||||
name='head{0}'.format(i),
|
||||
output_type=ModelBuilderOutputType.MODEL_FN_OPS),
|
||||
weights_name=weight_column,
|
||||
keys_name=keys_column,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
@ -509,3 +584,142 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
feature_engineering_fn=feature_engineering_fn)
|
||||
|
||||
|
||||
class CoreTensorForestEstimator(core_estimator.Estimator):
|
||||
"""A CORE estimator that can train and evaluate a random forest.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
|
||||
num_classes=2, num_features=40, num_trees=10, max_nodes=1000)
|
||||
|
||||
# Estimator using the default graph builder.
|
||||
estimator = CoreTensorForestEstimator(params, model_dir=model_dir)
|
||||
|
||||
# Or estimator using TrainingLossForest as the graph builder.
|
||||
estimator = CoreTensorForestEstimator(
|
||||
params, graph_builder_class=tensor_forest.TrainingLossForest,
|
||||
model_dir=model_dir)
|
||||
|
||||
# Input builders
|
||||
def input_fn_train: # returns x, y
|
||||
...
|
||||
def input_fn_eval: # returns x, y
|
||||
...
|
||||
estimator.train(input_fn=input_fn_train)
|
||||
estimator.evaluate(input_fn=input_fn_eval)
|
||||
|
||||
# Predict returns an iterable of dicts.
|
||||
results = list(estimator.predict(x=x))
|
||||
prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]
|
||||
prediction0 = results[0][eval_metrics.INFERENCE_PRED_NAME]
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
params,
|
||||
device_assigner=None,
|
||||
model_dir=None,
|
||||
feature_columns=None,
|
||||
graph_builder_class=tensor_forest.RandomForestGraphs,
|
||||
config=None,
|
||||
weight_column=None,
|
||||
keys_column=None,
|
||||
feature_engineering_fn=None,
|
||||
early_stopping_rounds=100,
|
||||
early_stopping_loss_threshold=0.001,
|
||||
num_trainers=1,
|
||||
trainer_id=0,
|
||||
report_feature_importances=False,
|
||||
local_eval=False,
|
||||
version=None,
|
||||
head=None,
|
||||
include_all_in_serving=False):
|
||||
"""Initializes a TensorForestEstimator instance.
|
||||
|
||||
Args:
|
||||
params: ForestHParams object that holds random forest hyperparameters.
|
||||
These parameters will be passed into `model_fn`.
|
||||
device_assigner: An `object` instance that controls how trees get
|
||||
assigned to devices. If `None`, will use
|
||||
`tensor_forest.RandomForestDeviceAssigner`.
|
||||
model_dir: Directory to save model parameters, graph, etc. To continue
|
||||
training a previously saved model, load checkpoints saved to this
|
||||
directory into an estimator.
|
||||
feature_columns: An iterable containing all the feature columns used by
|
||||
the model. All items in the set should be instances of classes derived
|
||||
from `_FeatureColumn`.
|
||||
graph_builder_class: An `object` instance that defines how TF graphs for
|
||||
random forest training and inference are built. By default will use
|
||||
`tensor_forest.RandomForestGraphs`. Can be overridden by version
|
||||
kwarg.
|
||||
config: `RunConfig` object to configure the runtime settings.
|
||||
weight_column: A string defining feature column name representing
|
||||
weights. Will be multiplied by the loss of the example. Used to
|
||||
downweight or boost examples during training.
|
||||
keys_column: A string naming one of the features to strip out and
|
||||
pass through into the inference/eval results dict. Useful for
|
||||
associating specific examples with their prediction.
|
||||
feature_engineering_fn: Feature engineering function. Takes features and
|
||||
labels which are the output of `input_fn` and returns features and
|
||||
labels which will be fed into the model.
|
||||
early_stopping_rounds: Allows training to terminate early if the forest is
|
||||
no longer growing. 100 by default. Set to a Falsy value to disable
|
||||
the default training hook.
|
||||
early_stopping_loss_threshold: Percentage (as fraction) that loss must
|
||||
improve by within early_stopping_rounds steps, otherwise training will
|
||||
terminate.
|
||||
num_trainers: Number of training jobs, which will partition trees
|
||||
among them.
|
||||
trainer_id: Which trainer this instance is.
|
||||
report_feature_importances: If True, print out feature importances
|
||||
during evaluation.
|
||||
local_eval: If True, don't use a device assigner for eval. This is to
|
||||
support some common setups where eval is done on a single machine, even
|
||||
though training might be distributed.
|
||||
version: Unused.
|
||||
head: A heads_lib.Head object that calculates losses and such. If None,
|
||||
one will be automatically created based on params.
|
||||
include_all_in_serving: if True, allow preparation of the complete
|
||||
prediction dict including the variance to be exported for serving with
|
||||
the Servo lib; and it also requires calling export_savedmodel with
|
||||
default_output_alternative_key=ALL_SERVING_KEY, i.e.
|
||||
estimator.export_savedmodel(export_dir_base=your_export_dir,
|
||||
serving_input_fn=your_export_input_fn,
|
||||
default_output_alternative_key=ALL_SERVING_KEY)
|
||||
if False, resort to default behavior, i.e. export scores and
|
||||
probabilities but no variances. In this case
|
||||
default_output_alternative_key should be None while calling
|
||||
export_savedmodel().
|
||||
Note, that due to backward compatibility we cannot always set
|
||||
include_all_in_serving to True because in this case calling
|
||||
export_saved_model() without
|
||||
default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) the
|
||||
saved_model_export_utils.get_output_alternatives() would raise
|
||||
ValueError.
|
||||
|
||||
Returns:
|
||||
A `TensorForestEstimator` instance.
|
||||
"""
|
||||
|
||||
super(CoreTensorForestEstimator, self).__init__(
|
||||
model_fn=get_model_fn(
|
||||
params.fill(),
|
||||
graph_builder_class,
|
||||
device_assigner,
|
||||
feature_columns=feature_columns,
|
||||
model_head=head,
|
||||
weights_name=weight_column,
|
||||
keys_name=keys_column,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
early_stopping_loss_threshold=early_stopping_loss_threshold,
|
||||
num_trainers=num_trainers,
|
||||
trainer_id=trainer_id,
|
||||
report_feature_importances=report_feature_importances,
|
||||
local_eval=local_eval,
|
||||
include_all_in_serving=include_all_in_serving,
|
||||
output_type=ModelBuilderOutputType.ESTIMATOR_SPEC),
|
||||
model_dir=model_dir,
|
||||
config=config)
|
||||
|
@ -23,7 +23,39 @@ import numpy as np
|
||||
from tensorflow.contrib.learn.python.learn.datasets import base
|
||||
from tensorflow.contrib.tensor_forest.client import random_forest
|
||||
from tensorflow.contrib.tensor_forest.python import tensor_forest
|
||||
from tensorflow.python.estimator.canned import head as head_lib
|
||||
from tensorflow.python.estimator.inputs import numpy_io
|
||||
from tensorflow.python.feature_column import feature_column_lib as core_feature_column
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops.losses import losses
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import checkpoint_utils
|
||||
|
||||
|
||||
def _get_classification_input_fns():
|
||||
iris = base.load_iris()
|
||||
data = iris.data.astype(np.float32)
|
||||
labels = iris.target.astype(np.int32)
|
||||
|
||||
train_input_fn = numpy_io.numpy_input_fn(
|
||||
x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False)
|
||||
|
||||
predict_input_fn = numpy_io.numpy_input_fn(
|
||||
x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False)
|
||||
return train_input_fn, predict_input_fn
|
||||
|
||||
|
||||
def _get_regression_input_fns():
|
||||
boston = base.load_boston()
|
||||
data = boston.data.astype(np.float32)
|
||||
labels = boston.target.astype(np.int32)
|
||||
|
||||
train_input_fn = numpy_io.numpy_input_fn(
|
||||
x=data, y=labels, batch_size=506, num_epochs=None, shuffle=False)
|
||||
|
||||
predict_input_fn = numpy_io.numpy_input_fn(
|
||||
x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False)
|
||||
return train_input_fn, predict_input_fn
|
||||
|
||||
|
||||
class TensorForestTrainerTests(test.TestCase):
|
||||
@ -39,18 +71,22 @@ class TensorForestTrainerTests(test.TestCase):
|
||||
inference_tree_paths=True)
|
||||
classifier = random_forest.TensorForestEstimator(hparams.fill())
|
||||
|
||||
iris = base.load_iris()
|
||||
data = iris.data.astype(np.float32)
|
||||
labels = iris.target.astype(np.int32)
|
||||
input_fn, predict_input_fn = _get_classification_input_fns()
|
||||
classifier.fit(input_fn=input_fn, steps=100)
|
||||
res = classifier.evaluate(input_fn=input_fn, steps=10)
|
||||
|
||||
classifier.fit(x=data, y=labels, steps=100, batch_size=50)
|
||||
classifier.evaluate(x=data, y=labels, steps=10)
|
||||
self.assertEqual(1.0, res['accuracy'])
|
||||
self.assertAllClose(0.55144483, res['loss'])
|
||||
|
||||
predictions = list(classifier.predict(input_fn=predict_input_fn))
|
||||
self.assertAllClose([[0.576117, 0.211942, 0.211942]],
|
||||
[pred['probabilities'] for pred in predictions])
|
||||
|
||||
def testRegression(self):
|
||||
"""Tests multi-class classification using matrix data as input."""
|
||||
"""Tests regression using matrix data as input."""
|
||||
|
||||
hparams = tensor_forest.ForestHParams(
|
||||
num_trees=3,
|
||||
num_trees=5,
|
||||
max_nodes=1000,
|
||||
num_classes=1,
|
||||
num_features=13,
|
||||
@ -59,12 +95,261 @@ class TensorForestTrainerTests(test.TestCase):
|
||||
|
||||
regressor = random_forest.TensorForestEstimator(hparams.fill())
|
||||
|
||||
boston = base.load_boston()
|
||||
data = boston.data.astype(np.float32)
|
||||
labels = boston.target.astype(np.int32)
|
||||
input_fn, predict_input_fn = _get_regression_input_fns()
|
||||
|
||||
regressor.fit(x=data, y=labels, steps=100, batch_size=50)
|
||||
regressor.evaluate(x=data, y=labels, steps=10)
|
||||
regressor.fit(input_fn=input_fn, steps=100)
|
||||
res = regressor.evaluate(input_fn=input_fn, steps=10)
|
||||
self.assertGreaterEqual(0.1, res['loss'])
|
||||
|
||||
predictions = list(regressor.predict(input_fn=predict_input_fn))
|
||||
self.assertAllClose([24.], [pred['scores'] for pred in predictions], atol=1)
|
||||
|
||||
def testAdditionalOutputs(self):
|
||||
"""Tests multi-class classification using matrix data as input."""
|
||||
hparams = tensor_forest.ForestHParams(
|
||||
num_trees=1,
|
||||
max_nodes=100,
|
||||
num_classes=3,
|
||||
num_features=4,
|
||||
split_after_samples=20,
|
||||
inference_tree_paths=True)
|
||||
classifier = random_forest.TensorForestEstimator(
|
||||
hparams.fill(), keys_column='keys', include_all_in_serving=True)
|
||||
|
||||
iris = base.load_iris()
|
||||
data = iris.data.astype(np.float32)
|
||||
labels = iris.target.astype(np.int32)
|
||||
|
||||
input_fn = numpy_io.numpy_input_fn(
|
||||
x={
|
||||
'x': data,
|
||||
'keys': np.arange(len(iris.data)).reshape(150, 1)
|
||||
},
|
||||
y=labels,
|
||||
batch_size=10,
|
||||
num_epochs=1,
|
||||
shuffle=False)
|
||||
|
||||
classifier.fit(input_fn=input_fn, steps=100)
|
||||
predictions = list(classifier.predict(input_fn=input_fn))
|
||||
# Check that there is a key column, tree paths and var.
|
||||
for pred in predictions:
|
||||
self.assertTrue('keys' in pred)
|
||||
self.assertTrue('tree_paths' in pred)
|
||||
self.assertTrue('prediction_variance' in pred)
|
||||
|
||||
def _assert_checkpoint(self, model_dir, global_step):
|
||||
reader = checkpoint_utils.load_checkpoint(model_dir)
|
||||
self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
|
||||
|
||||
def testEarlyStopping(self):
|
||||
"""Tests multi-class classification using matrix data as input."""
|
||||
hparams = tensor_forest.ForestHParams(
|
||||
num_trees=100,
|
||||
max_nodes=10000,
|
||||
num_classes=3,
|
||||
num_features=4,
|
||||
split_after_samples=20,
|
||||
inference_tree_paths=True)
|
||||
classifier = random_forest.TensorForestEstimator(
|
||||
hparams.fill(),
|
||||
# Set a crazy threshold - 30% loss change.
|
||||
early_stopping_loss_threshold=0.3,
|
||||
early_stopping_rounds=2)
|
||||
|
||||
input_fn, _ = _get_classification_input_fns()
|
||||
classifier.fit(input_fn=input_fn, steps=100)
|
||||
|
||||
# We stopped early.
|
||||
self._assert_checkpoint(classifier.model_dir, global_step=5)
|
||||
|
||||
|
||||
class CoreTensorForestTests(test.TestCase):
|
||||
|
||||
def testTrainEvaluateInferDoesNotThrowErrorForClassifier(self):
|
||||
head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
|
||||
n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
|
||||
|
||||
hparams = tensor_forest.ForestHParams(
|
||||
num_trees=3,
|
||||
max_nodes=1000,
|
||||
num_classes=3,
|
||||
num_features=4,
|
||||
split_after_samples=20,
|
||||
inference_tree_paths=True)
|
||||
|
||||
est = random_forest.CoreTensorForestEstimator(hparams.fill(), head=head_fn)
|
||||
|
||||
input_fn, predict_input_fn = _get_classification_input_fns()
|
||||
|
||||
est.train(input_fn=input_fn, steps=100)
|
||||
res = est.evaluate(input_fn=input_fn, steps=1)
|
||||
|
||||
self.assertEqual(1.0, res['accuracy'])
|
||||
self.assertAllClose(0.55144483, res['loss'])
|
||||
|
||||
predictions = list(est.predict(input_fn=predict_input_fn))
|
||||
self.assertAllClose([[0.576117, 0.211942, 0.211942]],
|
||||
[pred['probabilities'] for pred in predictions])
|
||||
|
||||
def testRegression(self):
|
||||
"""Tests regression using matrix data as input."""
|
||||
head_fn = head_lib._regression_head(
|
||||
label_dimension=1,
|
||||
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
|
||||
|
||||
hparams = tensor_forest.ForestHParams(
|
||||
num_trees=5,
|
||||
max_nodes=1000,
|
||||
num_classes=1,
|
||||
num_features=13,
|
||||
regression=True,
|
||||
split_after_samples=20)
|
||||
|
||||
regressor = random_forest.CoreTensorForestEstimator(
|
||||
hparams.fill(), head=head_fn)
|
||||
|
||||
input_fn, predict_input_fn = _get_regression_input_fns()
|
||||
|
||||
regressor.train(input_fn=input_fn, steps=100)
|
||||
res = regressor.evaluate(input_fn=input_fn, steps=10)
|
||||
self.assertGreaterEqual(0.1, res['loss'])
|
||||
|
||||
predictions = list(regressor.predict(input_fn=predict_input_fn))
|
||||
self.assertAllClose(
|
||||
[[24.]], [pred['predictions'] for pred in predictions], atol=1)
|
||||
|
||||
def testWithFeatureColumns(self):
|
||||
head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
|
||||
n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
|
||||
|
||||
hparams = tensor_forest.ForestHParams(
|
||||
num_trees=3,
|
||||
max_nodes=1000,
|
||||
num_classes=3,
|
||||
num_features=4,
|
||||
split_after_samples=20,
|
||||
inference_tree_paths=True)
|
||||
|
||||
est = random_forest.CoreTensorForestEstimator(
|
||||
hparams.fill(),
|
||||
head=head_fn,
|
||||
feature_columns=[core_feature_column.numeric_column('x')])
|
||||
|
||||
iris = base.load_iris()
|
||||
data = {'x': iris.data.astype(np.float32)}
|
||||
labels = iris.target.astype(np.int32)
|
||||
|
||||
input_fn = numpy_io.numpy_input_fn(
|
||||
x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False)
|
||||
|
||||
est.train(input_fn=input_fn, steps=100)
|
||||
res = est.evaluate(input_fn=input_fn, steps=1)
|
||||
|
||||
self.assertEqual(1.0, res['accuracy'])
|
||||
self.assertAllClose(0.55144483, res['loss'])
|
||||
|
||||
def testAutofillsClassificationHead(self):
|
||||
hparams = tensor_forest.ForestHParams(
|
||||
num_trees=3,
|
||||
max_nodes=1000,
|
||||
num_classes=3,
|
||||
num_features=4,
|
||||
split_after_samples=20,
|
||||
inference_tree_paths=True)
|
||||
|
||||
est = random_forest.CoreTensorForestEstimator(hparams.fill())
|
||||
|
||||
input_fn, _ = _get_classification_input_fns()
|
||||
|
||||
est.train(input_fn=input_fn, steps=100)
|
||||
res = est.evaluate(input_fn=input_fn, steps=1)
|
||||
|
||||
self.assertEqual(1.0, res['accuracy'])
|
||||
self.assertAllClose(0.55144483, res['loss'])
|
||||
|
||||
def testAutofillsRegressionHead(self):
|
||||
hparams = tensor_forest.ForestHParams(
|
||||
num_trees=5,
|
||||
max_nodes=1000,
|
||||
num_classes=1,
|
||||
num_features=13,
|
||||
regression=True,
|
||||
split_after_samples=20)
|
||||
|
||||
regressor = random_forest.CoreTensorForestEstimator(hparams.fill())
|
||||
|
||||
input_fn, predict_input_fn = _get_regression_input_fns()
|
||||
|
||||
regressor.train(input_fn=input_fn, steps=100)
|
||||
res = regressor.evaluate(input_fn=input_fn, steps=10)
|
||||
self.assertGreaterEqual(0.1, res['loss'])
|
||||
|
||||
predictions = list(regressor.predict(input_fn=predict_input_fn))
|
||||
self.assertAllClose(
|
||||
[[24.]], [pred['predictions'] for pred in predictions], atol=1)
|
||||
|
||||
def testAdditionalOutputs(self):
|
||||
"""Tests multi-class classification using matrix data as input."""
|
||||
hparams = tensor_forest.ForestHParams(
|
||||
num_trees=1,
|
||||
max_nodes=100,
|
||||
num_classes=3,
|
||||
num_features=4,
|
||||
split_after_samples=20,
|
||||
inference_tree_paths=True)
|
||||
classifier = random_forest.CoreTensorForestEstimator(
|
||||
hparams.fill(), keys_column='keys', include_all_in_serving=True)
|
||||
|
||||
iris = base.load_iris()
|
||||
data = iris.data.astype(np.float32)
|
||||
labels = iris.target.astype(np.int32)
|
||||
|
||||
input_fn = numpy_io.numpy_input_fn(
|
||||
x={
|
||||
'x': data,
|
||||
'keys': np.arange(len(iris.data)).reshape(150, 1)
|
||||
},
|
||||
y=labels,
|
||||
batch_size=10,
|
||||
num_epochs=1,
|
||||
shuffle=False)
|
||||
|
||||
classifier.train(input_fn=input_fn, steps=100)
|
||||
predictions = list(classifier.predict(input_fn=input_fn))
|
||||
# Check that there is a key column, tree paths and var.
|
||||
for pred in predictions:
|
||||
self.assertTrue('keys' in pred)
|
||||
self.assertTrue('tree_paths' in pred)
|
||||
self.assertTrue('prediction_variance' in pred)
|
||||
|
||||
def _assert_checkpoint(self, model_dir, global_step):
|
||||
reader = checkpoint_utils.load_checkpoint(model_dir)
|
||||
self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
|
||||
|
||||
def testEarlyStopping(self):
|
||||
head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
|
||||
n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
|
||||
|
||||
hparams = tensor_forest.ForestHParams(
|
||||
num_trees=3,
|
||||
max_nodes=1000,
|
||||
num_classes=3,
|
||||
num_features=4,
|
||||
split_after_samples=20,
|
||||
inference_tree_paths=True)
|
||||
|
||||
est = random_forest.CoreTensorForestEstimator(
|
||||
hparams.fill(),
|
||||
head=head_fn,
|
||||
# Set a crazy threshold - 30% loss change.
|
||||
early_stopping_loss_threshold=0.3,
|
||||
early_stopping_rounds=2)
|
||||
|
||||
input_fn, _ = _get_classification_input_fns()
|
||||
est.train(input_fn=input_fn, steps=100)
|
||||
# We stopped early.
|
||||
self._assert_checkpoint(est.model_dir, global_step=5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user