Providing a core estimator interface over a contrib tensorforest.
PiperOrigin-RevId: 208658097
This commit is contained in:
parent
d7f93284c8
commit
b03f732b3f
@ -22,10 +22,12 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
|
|||||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
|
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
|
||||||
|
|
||||||
from tensorflow.contrib.tensor_forest.client import eval_metrics
|
from tensorflow.contrib.tensor_forest.client import eval_metrics
|
||||||
from tensorflow.contrib.tensor_forest.python import tensor_forest
|
from tensorflow.contrib.tensor_forest.python import tensor_forest
|
||||||
|
from tensorflow.python.estimator import estimator as core_estimator
|
||||||
|
from tensorflow.python.estimator.canned import head as core_head_lib
|
||||||
|
from tensorflow.python.estimator.export.export_output import PredictOutput
|
||||||
|
from tensorflow.python.feature_column import feature_column as fc_core
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -34,12 +36,12 @@ from tensorflow.python.ops import math_ops
|
|||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
|
from tensorflow.python.ops.losses import losses
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.summary import summary
|
from tensorflow.python.summary import summary
|
||||||
from tensorflow.python.training import session_run_hook
|
from tensorflow.python.training import session_run_hook
|
||||||
from tensorflow.python.training import training_util
|
from tensorflow.python.training import training_util
|
||||||
|
|
||||||
|
|
||||||
KEYS_NAME = 'keys'
|
KEYS_NAME = 'keys'
|
||||||
LOSS_NAME = 'rf_training_loss'
|
LOSS_NAME = 'rf_training_loss'
|
||||||
TREE_PATHS_PREDICTION_KEY = 'tree_paths'
|
TREE_PATHS_PREDICTION_KEY = 'tree_paths'
|
||||||
@ -48,6 +50,11 @@ ALL_SERVING_KEY = 'tensorforest_all'
|
|||||||
EPSILON = 0.000001
|
EPSILON = 0.000001
|
||||||
|
|
||||||
|
|
||||||
|
class ModelBuilderOutputType(object):
|
||||||
|
MODEL_FN_OPS = 0
|
||||||
|
ESTIMATOR_SPEC = 1
|
||||||
|
|
||||||
|
|
||||||
class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook):
|
class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook):
|
||||||
|
|
||||||
def __init__(self, op_dict):
|
def __init__(self, op_dict):
|
||||||
@ -106,20 +113,34 @@ class TensorForestLossHook(session_run_hook.SessionRunHook):
|
|||||||
run_context.request_stop()
|
run_context.request_stop()
|
||||||
|
|
||||||
|
|
||||||
def get_default_head(params, weights_name, name=None):
|
def _get_default_head(params, weights_name, output_type, name=None):
|
||||||
if params.regression:
|
"""Creates a default head based on a type of a problem."""
|
||||||
return head_lib.regression_head(
|
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
|
||||||
weight_column_name=weights_name,
|
if params.regression:
|
||||||
label_dimension=params.num_outputs,
|
return head_lib.regression_head(
|
||||||
enable_centered_bias=False,
|
weight_column_name=weights_name,
|
||||||
head_name=name)
|
label_dimension=params.num_outputs,
|
||||||
|
enable_centered_bias=False,
|
||||||
|
head_name=name)
|
||||||
|
else:
|
||||||
|
return head_lib.multi_class_head(
|
||||||
|
params.num_classes,
|
||||||
|
weight_column_name=weights_name,
|
||||||
|
enable_centered_bias=False,
|
||||||
|
head_name=name)
|
||||||
else:
|
else:
|
||||||
return head_lib.multi_class_head(
|
if params.regression:
|
||||||
params.num_classes,
|
return core_head_lib._regression_head( # pylint:disable=protected-access
|
||||||
weight_column_name=weights_name,
|
weight_column=weights_name,
|
||||||
enable_centered_bias=False,
|
label_dimension=params.num_outputs,
|
||||||
head_name=name)
|
name=name,
|
||||||
|
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
|
||||||
|
else:
|
||||||
|
return core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access
|
||||||
|
n_classes=params.num_classes,
|
||||||
|
weight_column=weights_name,
|
||||||
|
name=name,
|
||||||
|
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
|
||||||
|
|
||||||
def get_model_fn(params,
|
def get_model_fn(params,
|
||||||
graph_builder_class,
|
graph_builder_class,
|
||||||
@ -135,19 +156,27 @@ def get_model_fn(params,
|
|||||||
report_feature_importances=False,
|
report_feature_importances=False,
|
||||||
local_eval=False,
|
local_eval=False,
|
||||||
head_scope=None,
|
head_scope=None,
|
||||||
include_all_in_serving=False):
|
include_all_in_serving=False,
|
||||||
|
output_type=ModelBuilderOutputType.MODEL_FN_OPS):
|
||||||
"""Return a model function given a way to construct a graph builder."""
|
"""Return a model function given a way to construct a graph builder."""
|
||||||
if model_head is None:
|
if model_head is None:
|
||||||
model_head = get_default_head(params, weights_name)
|
model_head = _get_default_head(params, weights_name, output_type)
|
||||||
|
|
||||||
def _model_fn(features, labels, mode):
|
def _model_fn(features, labels, mode):
|
||||||
"""Function that returns predictions, training loss, and training op."""
|
"""Function that returns predictions, training loss, and training op."""
|
||||||
|
|
||||||
if (isinstance(features, ops.Tensor) or
|
if (isinstance(features, ops.Tensor) or
|
||||||
isinstance(features, sparse_tensor.SparseTensor)):
|
isinstance(features, sparse_tensor.SparseTensor)):
|
||||||
features = {'features': features}
|
features = {'features': features}
|
||||||
if feature_columns:
|
if feature_columns:
|
||||||
features = features.copy()
|
features = features.copy()
|
||||||
features.update(layers.transform_features(features, feature_columns))
|
|
||||||
|
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
|
||||||
|
features.update(layers.transform_features(features, feature_columns))
|
||||||
|
else:
|
||||||
|
for fc in feature_columns:
|
||||||
|
tensor = fc_core._transform_features(features, [fc])[fc] # pylint: disable=protected-access
|
||||||
|
features[fc.name] = tensor
|
||||||
|
|
||||||
weights = None
|
weights = None
|
||||||
if weights_name and weights_name in features:
|
if weights_name and weights_name in features:
|
||||||
@ -201,52 +230,95 @@ def get_model_fn(params,
|
|||||||
def _train_fn(unused_loss):
|
def _train_fn(unused_loss):
|
||||||
return training_graph
|
return training_graph
|
||||||
|
|
||||||
model_ops = model_head.create_model_fn_ops(
|
|
||||||
features=features,
|
|
||||||
labels=labels,
|
|
||||||
mode=mode,
|
|
||||||
train_op_fn=_train_fn,
|
|
||||||
logits=logits,
|
|
||||||
scope=head_scope)
|
|
||||||
|
|
||||||
# Ops are run in lexigraphical order of their keys. Run the resource
|
# Ops are run in lexigraphical order of their keys. Run the resource
|
||||||
# clean-up op last.
|
# clean-up op last.
|
||||||
all_handles = graph_builder.get_all_resource_handles()
|
all_handles = graph_builder.get_all_resource_handles()
|
||||||
ops_at_end = {
|
ops_at_end = {
|
||||||
'9: clean up resources': control_flow_ops.group(
|
'9: clean up resources':
|
||||||
*[resource_variable_ops.destroy_resource_op(handle)
|
control_flow_ops.group(*[
|
||||||
for handle in all_handles])}
|
resource_variable_ops.destroy_resource_op(handle)
|
||||||
|
for handle in all_handles
|
||||||
|
])
|
||||||
|
}
|
||||||
|
|
||||||
if report_feature_importances:
|
if report_feature_importances:
|
||||||
ops_at_end['1: feature_importances'] = (
|
ops_at_end['1: feature_importances'] = (
|
||||||
graph_builder.feature_importances())
|
graph_builder.feature_importances())
|
||||||
|
|
||||||
training_hooks.append(TensorForestRunOpAtEndHook(ops_at_end))
|
training_hooks = [TensorForestRunOpAtEndHook(ops_at_end)]
|
||||||
|
|
||||||
if early_stopping_rounds:
|
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
|
||||||
training_hooks.append(
|
model_ops = model_head.create_model_fn_ops(
|
||||||
TensorForestLossHook(
|
features=features,
|
||||||
early_stopping_rounds,
|
labels=labels,
|
||||||
early_stopping_loss_threshold=early_stopping_loss_threshold,
|
mode=mode,
|
||||||
loss_op=model_ops.loss))
|
train_op_fn=_train_fn,
|
||||||
|
logits=logits,
|
||||||
|
scope=head_scope)
|
||||||
|
|
||||||
model_ops.training_hooks.extend(training_hooks)
|
if early_stopping_rounds:
|
||||||
|
training_hooks.append(
|
||||||
|
TensorForestLossHook(
|
||||||
|
early_stopping_rounds,
|
||||||
|
early_stopping_loss_threshold=early_stopping_loss_threshold,
|
||||||
|
loss_op=model_ops.loss))
|
||||||
|
|
||||||
if keys is not None:
|
model_ops.training_hooks.extend(training_hooks)
|
||||||
model_ops.predictions[keys_name] = keys
|
|
||||||
|
|
||||||
if params.inference_tree_paths:
|
if keys is not None:
|
||||||
model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
|
model_ops.predictions[keys_name] = keys
|
||||||
|
|
||||||
model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
|
if params.inference_tree_paths:
|
||||||
if include_all_in_serving:
|
model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
|
||||||
# In order to serve the variance we need to add the prediction dict
|
|
||||||
# to output_alternatives dict.
|
model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
|
||||||
if not model_ops.output_alternatives:
|
|
||||||
model_ops.output_alternatives = {}
|
if include_all_in_serving:
|
||||||
model_ops.output_alternatives[ALL_SERVING_KEY] = (
|
# In order to serve the variance we need to add the prediction dict
|
||||||
constants.ProblemType.UNSPECIFIED, model_ops.predictions)
|
# to output_alternatives dict.
|
||||||
return model_ops
|
if not model_ops.output_alternatives:
|
||||||
|
model_ops.output_alternatives = {}
|
||||||
|
model_ops.output_alternatives[ALL_SERVING_KEY] = (
|
||||||
|
constants.ProblemType.UNSPECIFIED, model_ops.predictions)
|
||||||
|
|
||||||
|
return model_ops
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Estimator spec
|
||||||
|
estimator_spec = model_head.create_estimator_spec(
|
||||||
|
features=features,
|
||||||
|
mode=mode,
|
||||||
|
labels=labels,
|
||||||
|
train_op_fn=_train_fn,
|
||||||
|
logits=logits)
|
||||||
|
|
||||||
|
if early_stopping_rounds:
|
||||||
|
training_hooks.append(
|
||||||
|
TensorForestLossHook(
|
||||||
|
early_stopping_rounds,
|
||||||
|
early_stopping_loss_threshold=early_stopping_loss_threshold,
|
||||||
|
loss_op=estimator_spec.loss))
|
||||||
|
|
||||||
|
estimator_spec = estimator_spec._replace(
|
||||||
|
training_hooks=training_hooks + list(estimator_spec.training_hooks))
|
||||||
|
if keys is not None:
|
||||||
|
estimator_spec.predictions[keys_name] = keys
|
||||||
|
if params.inference_tree_paths:
|
||||||
|
estimator_spec.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
|
||||||
|
estimator_spec.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
|
||||||
|
|
||||||
|
if include_all_in_serving:
|
||||||
|
outputs = estimator_spec.export_outputs
|
||||||
|
if not outputs:
|
||||||
|
outputs = {}
|
||||||
|
outputs = {ALL_SERVING_KEY: PredictOutput(estimator_spec.predictions)}
|
||||||
|
print(estimator_spec.export_outputs)
|
||||||
|
# In order to serve the variance we need to add the prediction dict
|
||||||
|
# to output_alternatives dict.
|
||||||
|
estimator_spec = estimator_spec._replace(export_outputs=outputs)
|
||||||
|
|
||||||
|
return estimator_spec
|
||||||
|
|
||||||
return _model_fn
|
return _model_fn
|
||||||
|
|
||||||
@ -493,8 +565,11 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
|
|||||||
params,
|
params,
|
||||||
graph_builder_class,
|
graph_builder_class,
|
||||||
device_assigner,
|
device_assigner,
|
||||||
model_head=get_default_head(
|
model_head=_get_default_head(
|
||||||
params, weight_column, name='head{0}'.format(i)),
|
params,
|
||||||
|
weight_column,
|
||||||
|
name='head{0}'.format(i),
|
||||||
|
output_type=ModelBuilderOutputType.MODEL_FN_OPS),
|
||||||
weights_name=weight_column,
|
weights_name=weight_column,
|
||||||
keys_name=keys_column,
|
keys_name=keys_column,
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
@ -509,3 +584,142 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
|
|||||||
model_dir=model_dir,
|
model_dir=model_dir,
|
||||||
config=config,
|
config=config,
|
||||||
feature_engineering_fn=feature_engineering_fn)
|
feature_engineering_fn=feature_engineering_fn)
|
||||||
|
|
||||||
|
|
||||||
|
class CoreTensorForestEstimator(core_estimator.Estimator):
|
||||||
|
"""A CORE estimator that can train and evaluate a random forest.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
|
||||||
|
num_classes=2, num_features=40, num_trees=10, max_nodes=1000)
|
||||||
|
|
||||||
|
# Estimator using the default graph builder.
|
||||||
|
estimator = CoreTensorForestEstimator(params, model_dir=model_dir)
|
||||||
|
|
||||||
|
# Or estimator using TrainingLossForest as the graph builder.
|
||||||
|
estimator = CoreTensorForestEstimator(
|
||||||
|
params, graph_builder_class=tensor_forest.TrainingLossForest,
|
||||||
|
model_dir=model_dir)
|
||||||
|
|
||||||
|
# Input builders
|
||||||
|
def input_fn_train: # returns x, y
|
||||||
|
...
|
||||||
|
def input_fn_eval: # returns x, y
|
||||||
|
...
|
||||||
|
estimator.train(input_fn=input_fn_train)
|
||||||
|
estimator.evaluate(input_fn=input_fn_eval)
|
||||||
|
|
||||||
|
# Predict returns an iterable of dicts.
|
||||||
|
results = list(estimator.predict(x=x))
|
||||||
|
prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]
|
||||||
|
prediction0 = results[0][eval_metrics.INFERENCE_PRED_NAME]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
params,
|
||||||
|
device_assigner=None,
|
||||||
|
model_dir=None,
|
||||||
|
feature_columns=None,
|
||||||
|
graph_builder_class=tensor_forest.RandomForestGraphs,
|
||||||
|
config=None,
|
||||||
|
weight_column=None,
|
||||||
|
keys_column=None,
|
||||||
|
feature_engineering_fn=None,
|
||||||
|
early_stopping_rounds=100,
|
||||||
|
early_stopping_loss_threshold=0.001,
|
||||||
|
num_trainers=1,
|
||||||
|
trainer_id=0,
|
||||||
|
report_feature_importances=False,
|
||||||
|
local_eval=False,
|
||||||
|
version=None,
|
||||||
|
head=None,
|
||||||
|
include_all_in_serving=False):
|
||||||
|
"""Initializes a TensorForestEstimator instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: ForestHParams object that holds random forest hyperparameters.
|
||||||
|
These parameters will be passed into `model_fn`.
|
||||||
|
device_assigner: An `object` instance that controls how trees get
|
||||||
|
assigned to devices. If `None`, will use
|
||||||
|
`tensor_forest.RandomForestDeviceAssigner`.
|
||||||
|
model_dir: Directory to save model parameters, graph, etc. To continue
|
||||||
|
training a previously saved model, load checkpoints saved to this
|
||||||
|
directory into an estimator.
|
||||||
|
feature_columns: An iterable containing all the feature columns used by
|
||||||
|
the model. All items in the set should be instances of classes derived
|
||||||
|
from `_FeatureColumn`.
|
||||||
|
graph_builder_class: An `object` instance that defines how TF graphs for
|
||||||
|
random forest training and inference are built. By default will use
|
||||||
|
`tensor_forest.RandomForestGraphs`. Can be overridden by version
|
||||||
|
kwarg.
|
||||||
|
config: `RunConfig` object to configure the runtime settings.
|
||||||
|
weight_column: A string defining feature column name representing
|
||||||
|
weights. Will be multiplied by the loss of the example. Used to
|
||||||
|
downweight or boost examples during training.
|
||||||
|
keys_column: A string naming one of the features to strip out and
|
||||||
|
pass through into the inference/eval results dict. Useful for
|
||||||
|
associating specific examples with their prediction.
|
||||||
|
feature_engineering_fn: Feature engineering function. Takes features and
|
||||||
|
labels which are the output of `input_fn` and returns features and
|
||||||
|
labels which will be fed into the model.
|
||||||
|
early_stopping_rounds: Allows training to terminate early if the forest is
|
||||||
|
no longer growing. 100 by default. Set to a Falsy value to disable
|
||||||
|
the default training hook.
|
||||||
|
early_stopping_loss_threshold: Percentage (as fraction) that loss must
|
||||||
|
improve by within early_stopping_rounds steps, otherwise training will
|
||||||
|
terminate.
|
||||||
|
num_trainers: Number of training jobs, which will partition trees
|
||||||
|
among them.
|
||||||
|
trainer_id: Which trainer this instance is.
|
||||||
|
report_feature_importances: If True, print out feature importances
|
||||||
|
during evaluation.
|
||||||
|
local_eval: If True, don't use a device assigner for eval. This is to
|
||||||
|
support some common setups where eval is done on a single machine, even
|
||||||
|
though training might be distributed.
|
||||||
|
version: Unused.
|
||||||
|
head: A heads_lib.Head object that calculates losses and such. If None,
|
||||||
|
one will be automatically created based on params.
|
||||||
|
include_all_in_serving: if True, allow preparation of the complete
|
||||||
|
prediction dict including the variance to be exported for serving with
|
||||||
|
the Servo lib; and it also requires calling export_savedmodel with
|
||||||
|
default_output_alternative_key=ALL_SERVING_KEY, i.e.
|
||||||
|
estimator.export_savedmodel(export_dir_base=your_export_dir,
|
||||||
|
serving_input_fn=your_export_input_fn,
|
||||||
|
default_output_alternative_key=ALL_SERVING_KEY)
|
||||||
|
if False, resort to default behavior, i.e. export scores and
|
||||||
|
probabilities but no variances. In this case
|
||||||
|
default_output_alternative_key should be None while calling
|
||||||
|
export_savedmodel().
|
||||||
|
Note, that due to backward compatibility we cannot always set
|
||||||
|
include_all_in_serving to True because in this case calling
|
||||||
|
export_saved_model() without
|
||||||
|
default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) the
|
||||||
|
saved_model_export_utils.get_output_alternatives() would raise
|
||||||
|
ValueError.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `TensorForestEstimator` instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
super(CoreTensorForestEstimator, self).__init__(
|
||||||
|
model_fn=get_model_fn(
|
||||||
|
params.fill(),
|
||||||
|
graph_builder_class,
|
||||||
|
device_assigner,
|
||||||
|
feature_columns=feature_columns,
|
||||||
|
model_head=head,
|
||||||
|
weights_name=weight_column,
|
||||||
|
keys_name=keys_column,
|
||||||
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
|
early_stopping_loss_threshold=early_stopping_loss_threshold,
|
||||||
|
num_trainers=num_trainers,
|
||||||
|
trainer_id=trainer_id,
|
||||||
|
report_feature_importances=report_feature_importances,
|
||||||
|
local_eval=local_eval,
|
||||||
|
include_all_in_serving=include_all_in_serving,
|
||||||
|
output_type=ModelBuilderOutputType.ESTIMATOR_SPEC),
|
||||||
|
model_dir=model_dir,
|
||||||
|
config=config)
|
||||||
|
@ -23,7 +23,39 @@ import numpy as np
|
|||||||
from tensorflow.contrib.learn.python.learn.datasets import base
|
from tensorflow.contrib.learn.python.learn.datasets import base
|
||||||
from tensorflow.contrib.tensor_forest.client import random_forest
|
from tensorflow.contrib.tensor_forest.client import random_forest
|
||||||
from tensorflow.contrib.tensor_forest.python import tensor_forest
|
from tensorflow.contrib.tensor_forest.python import tensor_forest
|
||||||
|
from tensorflow.python.estimator.canned import head as head_lib
|
||||||
|
from tensorflow.python.estimator.inputs import numpy_io
|
||||||
|
from tensorflow.python.feature_column import feature_column_lib as core_feature_column
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops.losses import losses
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.training import checkpoint_utils
|
||||||
|
|
||||||
|
|
||||||
|
def _get_classification_input_fns():
|
||||||
|
iris = base.load_iris()
|
||||||
|
data = iris.data.astype(np.float32)
|
||||||
|
labels = iris.target.astype(np.int32)
|
||||||
|
|
||||||
|
train_input_fn = numpy_io.numpy_input_fn(
|
||||||
|
x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False)
|
||||||
|
|
||||||
|
predict_input_fn = numpy_io.numpy_input_fn(
|
||||||
|
x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False)
|
||||||
|
return train_input_fn, predict_input_fn
|
||||||
|
|
||||||
|
|
||||||
|
def _get_regression_input_fns():
|
||||||
|
boston = base.load_boston()
|
||||||
|
data = boston.data.astype(np.float32)
|
||||||
|
labels = boston.target.astype(np.int32)
|
||||||
|
|
||||||
|
train_input_fn = numpy_io.numpy_input_fn(
|
||||||
|
x=data, y=labels, batch_size=506, num_epochs=None, shuffle=False)
|
||||||
|
|
||||||
|
predict_input_fn = numpy_io.numpy_input_fn(
|
||||||
|
x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False)
|
||||||
|
return train_input_fn, predict_input_fn
|
||||||
|
|
||||||
|
|
||||||
class TensorForestTrainerTests(test.TestCase):
|
class TensorForestTrainerTests(test.TestCase):
|
||||||
@ -39,18 +71,22 @@ class TensorForestTrainerTests(test.TestCase):
|
|||||||
inference_tree_paths=True)
|
inference_tree_paths=True)
|
||||||
classifier = random_forest.TensorForestEstimator(hparams.fill())
|
classifier = random_forest.TensorForestEstimator(hparams.fill())
|
||||||
|
|
||||||
iris = base.load_iris()
|
input_fn, predict_input_fn = _get_classification_input_fns()
|
||||||
data = iris.data.astype(np.float32)
|
classifier.fit(input_fn=input_fn, steps=100)
|
||||||
labels = iris.target.astype(np.int32)
|
res = classifier.evaluate(input_fn=input_fn, steps=10)
|
||||||
|
|
||||||
classifier.fit(x=data, y=labels, steps=100, batch_size=50)
|
self.assertEqual(1.0, res['accuracy'])
|
||||||
classifier.evaluate(x=data, y=labels, steps=10)
|
self.assertAllClose(0.55144483, res['loss'])
|
||||||
|
|
||||||
|
predictions = list(classifier.predict(input_fn=predict_input_fn))
|
||||||
|
self.assertAllClose([[0.576117, 0.211942, 0.211942]],
|
||||||
|
[pred['probabilities'] for pred in predictions])
|
||||||
|
|
||||||
def testRegression(self):
|
def testRegression(self):
|
||||||
"""Tests multi-class classification using matrix data as input."""
|
"""Tests regression using matrix data as input."""
|
||||||
|
|
||||||
hparams = tensor_forest.ForestHParams(
|
hparams = tensor_forest.ForestHParams(
|
||||||
num_trees=3,
|
num_trees=5,
|
||||||
max_nodes=1000,
|
max_nodes=1000,
|
||||||
num_classes=1,
|
num_classes=1,
|
||||||
num_features=13,
|
num_features=13,
|
||||||
@ -59,12 +95,261 @@ class TensorForestTrainerTests(test.TestCase):
|
|||||||
|
|
||||||
regressor = random_forest.TensorForestEstimator(hparams.fill())
|
regressor = random_forest.TensorForestEstimator(hparams.fill())
|
||||||
|
|
||||||
boston = base.load_boston()
|
input_fn, predict_input_fn = _get_regression_input_fns()
|
||||||
data = boston.data.astype(np.float32)
|
|
||||||
labels = boston.target.astype(np.int32)
|
|
||||||
|
|
||||||
regressor.fit(x=data, y=labels, steps=100, batch_size=50)
|
regressor.fit(input_fn=input_fn, steps=100)
|
||||||
regressor.evaluate(x=data, y=labels, steps=10)
|
res = regressor.evaluate(input_fn=input_fn, steps=10)
|
||||||
|
self.assertGreaterEqual(0.1, res['loss'])
|
||||||
|
|
||||||
|
predictions = list(regressor.predict(input_fn=predict_input_fn))
|
||||||
|
self.assertAllClose([24.], [pred['scores'] for pred in predictions], atol=1)
|
||||||
|
|
||||||
|
def testAdditionalOutputs(self):
|
||||||
|
"""Tests multi-class classification using matrix data as input."""
|
||||||
|
hparams = tensor_forest.ForestHParams(
|
||||||
|
num_trees=1,
|
||||||
|
max_nodes=100,
|
||||||
|
num_classes=3,
|
||||||
|
num_features=4,
|
||||||
|
split_after_samples=20,
|
||||||
|
inference_tree_paths=True)
|
||||||
|
classifier = random_forest.TensorForestEstimator(
|
||||||
|
hparams.fill(), keys_column='keys', include_all_in_serving=True)
|
||||||
|
|
||||||
|
iris = base.load_iris()
|
||||||
|
data = iris.data.astype(np.float32)
|
||||||
|
labels = iris.target.astype(np.int32)
|
||||||
|
|
||||||
|
input_fn = numpy_io.numpy_input_fn(
|
||||||
|
x={
|
||||||
|
'x': data,
|
||||||
|
'keys': np.arange(len(iris.data)).reshape(150, 1)
|
||||||
|
},
|
||||||
|
y=labels,
|
||||||
|
batch_size=10,
|
||||||
|
num_epochs=1,
|
||||||
|
shuffle=False)
|
||||||
|
|
||||||
|
classifier.fit(input_fn=input_fn, steps=100)
|
||||||
|
predictions = list(classifier.predict(input_fn=input_fn))
|
||||||
|
# Check that there is a key column, tree paths and var.
|
||||||
|
for pred in predictions:
|
||||||
|
self.assertTrue('keys' in pred)
|
||||||
|
self.assertTrue('tree_paths' in pred)
|
||||||
|
self.assertTrue('prediction_variance' in pred)
|
||||||
|
|
||||||
|
def _assert_checkpoint(self, model_dir, global_step):
|
||||||
|
reader = checkpoint_utils.load_checkpoint(model_dir)
|
||||||
|
self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
|
||||||
|
|
||||||
|
def testEarlyStopping(self):
|
||||||
|
"""Tests multi-class classification using matrix data as input."""
|
||||||
|
hparams = tensor_forest.ForestHParams(
|
||||||
|
num_trees=100,
|
||||||
|
max_nodes=10000,
|
||||||
|
num_classes=3,
|
||||||
|
num_features=4,
|
||||||
|
split_after_samples=20,
|
||||||
|
inference_tree_paths=True)
|
||||||
|
classifier = random_forest.TensorForestEstimator(
|
||||||
|
hparams.fill(),
|
||||||
|
# Set a crazy threshold - 30% loss change.
|
||||||
|
early_stopping_loss_threshold=0.3,
|
||||||
|
early_stopping_rounds=2)
|
||||||
|
|
||||||
|
input_fn, _ = _get_classification_input_fns()
|
||||||
|
classifier.fit(input_fn=input_fn, steps=100)
|
||||||
|
|
||||||
|
# We stopped early.
|
||||||
|
self._assert_checkpoint(classifier.model_dir, global_step=5)
|
||||||
|
|
||||||
|
|
||||||
|
class CoreTensorForestTests(test.TestCase):
|
||||||
|
|
||||||
|
def testTrainEvaluateInferDoesNotThrowErrorForClassifier(self):
|
||||||
|
head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
|
||||||
|
n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
|
||||||
|
|
||||||
|
hparams = tensor_forest.ForestHParams(
|
||||||
|
num_trees=3,
|
||||||
|
max_nodes=1000,
|
||||||
|
num_classes=3,
|
||||||
|
num_features=4,
|
||||||
|
split_after_samples=20,
|
||||||
|
inference_tree_paths=True)
|
||||||
|
|
||||||
|
est = random_forest.CoreTensorForestEstimator(hparams.fill(), head=head_fn)
|
||||||
|
|
||||||
|
input_fn, predict_input_fn = _get_classification_input_fns()
|
||||||
|
|
||||||
|
est.train(input_fn=input_fn, steps=100)
|
||||||
|
res = est.evaluate(input_fn=input_fn, steps=1)
|
||||||
|
|
||||||
|
self.assertEqual(1.0, res['accuracy'])
|
||||||
|
self.assertAllClose(0.55144483, res['loss'])
|
||||||
|
|
||||||
|
predictions = list(est.predict(input_fn=predict_input_fn))
|
||||||
|
self.assertAllClose([[0.576117, 0.211942, 0.211942]],
|
||||||
|
[pred['probabilities'] for pred in predictions])
|
||||||
|
|
||||||
|
def testRegression(self):
|
||||||
|
"""Tests regression using matrix data as input."""
|
||||||
|
head_fn = head_lib._regression_head(
|
||||||
|
label_dimension=1,
|
||||||
|
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
|
||||||
|
|
||||||
|
hparams = tensor_forest.ForestHParams(
|
||||||
|
num_trees=5,
|
||||||
|
max_nodes=1000,
|
||||||
|
num_classes=1,
|
||||||
|
num_features=13,
|
||||||
|
regression=True,
|
||||||
|
split_after_samples=20)
|
||||||
|
|
||||||
|
regressor = random_forest.CoreTensorForestEstimator(
|
||||||
|
hparams.fill(), head=head_fn)
|
||||||
|
|
||||||
|
input_fn, predict_input_fn = _get_regression_input_fns()
|
||||||
|
|
||||||
|
regressor.train(input_fn=input_fn, steps=100)
|
||||||
|
res = regressor.evaluate(input_fn=input_fn, steps=10)
|
||||||
|
self.assertGreaterEqual(0.1, res['loss'])
|
||||||
|
|
||||||
|
predictions = list(regressor.predict(input_fn=predict_input_fn))
|
||||||
|
self.assertAllClose(
|
||||||
|
[[24.]], [pred['predictions'] for pred in predictions], atol=1)
|
||||||
|
|
||||||
|
def testWithFeatureColumns(self):
|
||||||
|
head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
|
||||||
|
n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
|
||||||
|
|
||||||
|
hparams = tensor_forest.ForestHParams(
|
||||||
|
num_trees=3,
|
||||||
|
max_nodes=1000,
|
||||||
|
num_classes=3,
|
||||||
|
num_features=4,
|
||||||
|
split_after_samples=20,
|
||||||
|
inference_tree_paths=True)
|
||||||
|
|
||||||
|
est = random_forest.CoreTensorForestEstimator(
|
||||||
|
hparams.fill(),
|
||||||
|
head=head_fn,
|
||||||
|
feature_columns=[core_feature_column.numeric_column('x')])
|
||||||
|
|
||||||
|
iris = base.load_iris()
|
||||||
|
data = {'x': iris.data.astype(np.float32)}
|
||||||
|
labels = iris.target.astype(np.int32)
|
||||||
|
|
||||||
|
input_fn = numpy_io.numpy_input_fn(
|
||||||
|
x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False)
|
||||||
|
|
||||||
|
est.train(input_fn=input_fn, steps=100)
|
||||||
|
res = est.evaluate(input_fn=input_fn, steps=1)
|
||||||
|
|
||||||
|
self.assertEqual(1.0, res['accuracy'])
|
||||||
|
self.assertAllClose(0.55144483, res['loss'])
|
||||||
|
|
||||||
|
def testAutofillsClassificationHead(self):
|
||||||
|
hparams = tensor_forest.ForestHParams(
|
||||||
|
num_trees=3,
|
||||||
|
max_nodes=1000,
|
||||||
|
num_classes=3,
|
||||||
|
num_features=4,
|
||||||
|
split_after_samples=20,
|
||||||
|
inference_tree_paths=True)
|
||||||
|
|
||||||
|
est = random_forest.CoreTensorForestEstimator(hparams.fill())
|
||||||
|
|
||||||
|
input_fn, _ = _get_classification_input_fns()
|
||||||
|
|
||||||
|
est.train(input_fn=input_fn, steps=100)
|
||||||
|
res = est.evaluate(input_fn=input_fn, steps=1)
|
||||||
|
|
||||||
|
self.assertEqual(1.0, res['accuracy'])
|
||||||
|
self.assertAllClose(0.55144483, res['loss'])
|
||||||
|
|
||||||
|
def testAutofillsRegressionHead(self):
|
||||||
|
hparams = tensor_forest.ForestHParams(
|
||||||
|
num_trees=5,
|
||||||
|
max_nodes=1000,
|
||||||
|
num_classes=1,
|
||||||
|
num_features=13,
|
||||||
|
regression=True,
|
||||||
|
split_after_samples=20)
|
||||||
|
|
||||||
|
regressor = random_forest.CoreTensorForestEstimator(hparams.fill())
|
||||||
|
|
||||||
|
input_fn, predict_input_fn = _get_regression_input_fns()
|
||||||
|
|
||||||
|
regressor.train(input_fn=input_fn, steps=100)
|
||||||
|
res = regressor.evaluate(input_fn=input_fn, steps=10)
|
||||||
|
self.assertGreaterEqual(0.1, res['loss'])
|
||||||
|
|
||||||
|
predictions = list(regressor.predict(input_fn=predict_input_fn))
|
||||||
|
self.assertAllClose(
|
||||||
|
[[24.]], [pred['predictions'] for pred in predictions], atol=1)
|
||||||
|
|
||||||
|
def testAdditionalOutputs(self):
|
||||||
|
"""Tests multi-class classification using matrix data as input."""
|
||||||
|
hparams = tensor_forest.ForestHParams(
|
||||||
|
num_trees=1,
|
||||||
|
max_nodes=100,
|
||||||
|
num_classes=3,
|
||||||
|
num_features=4,
|
||||||
|
split_after_samples=20,
|
||||||
|
inference_tree_paths=True)
|
||||||
|
classifier = random_forest.CoreTensorForestEstimator(
|
||||||
|
hparams.fill(), keys_column='keys', include_all_in_serving=True)
|
||||||
|
|
||||||
|
iris = base.load_iris()
|
||||||
|
data = iris.data.astype(np.float32)
|
||||||
|
labels = iris.target.astype(np.int32)
|
||||||
|
|
||||||
|
input_fn = numpy_io.numpy_input_fn(
|
||||||
|
x={
|
||||||
|
'x': data,
|
||||||
|
'keys': np.arange(len(iris.data)).reshape(150, 1)
|
||||||
|
},
|
||||||
|
y=labels,
|
||||||
|
batch_size=10,
|
||||||
|
num_epochs=1,
|
||||||
|
shuffle=False)
|
||||||
|
|
||||||
|
classifier.train(input_fn=input_fn, steps=100)
|
||||||
|
predictions = list(classifier.predict(input_fn=input_fn))
|
||||||
|
# Check that there is a key column, tree paths and var.
|
||||||
|
for pred in predictions:
|
||||||
|
self.assertTrue('keys' in pred)
|
||||||
|
self.assertTrue('tree_paths' in pred)
|
||||||
|
self.assertTrue('prediction_variance' in pred)
|
||||||
|
|
||||||
|
def _assert_checkpoint(self, model_dir, global_step):
|
||||||
|
reader = checkpoint_utils.load_checkpoint(model_dir)
|
||||||
|
self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
|
||||||
|
|
||||||
|
def testEarlyStopping(self):
|
||||||
|
head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
|
||||||
|
n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
|
||||||
|
|
||||||
|
hparams = tensor_forest.ForestHParams(
|
||||||
|
num_trees=3,
|
||||||
|
max_nodes=1000,
|
||||||
|
num_classes=3,
|
||||||
|
num_features=4,
|
||||||
|
split_after_samples=20,
|
||||||
|
inference_tree_paths=True)
|
||||||
|
|
||||||
|
est = random_forest.CoreTensorForestEstimator(
|
||||||
|
hparams.fill(),
|
||||||
|
head=head_fn,
|
||||||
|
# Set a crazy threshold - 30% loss change.
|
||||||
|
early_stopping_loss_threshold=0.3,
|
||||||
|
early_stopping_rounds=2)
|
||||||
|
|
||||||
|
input_fn, _ = _get_classification_input_fns()
|
||||||
|
est.train(input_fn=input_fn, steps=100)
|
||||||
|
# We stopped early.
|
||||||
|
self._assert_checkpoint(est.model_dir, global_step=5)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user