diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index a998ac1e111..4abcc20ed33 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -18,7 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import layers - +from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib @@ -43,8 +43,8 @@ from tensorflow.python.training import training_util KEYS_NAME = 'keys' LOSS_NAME = 'rf_training_loss' TREE_PATHS_PREDICTION_KEY = 'tree_paths' -VARIANCE_PREDICTION_KEY = 'regression_variance' - +VARIANCE_PREDICTION_KEY = 'prediction_variance' +ALL_SERVING_KEY = 'tensorforest_all' EPSILON = 0.000001 @@ -134,7 +134,8 @@ def get_model_fn(params, trainer_id=0, report_feature_importances=False, local_eval=False, - head_scope=None): + head_scope=None, + include_all_in_serving=False): """Return a model function given a way to construct a graph builder.""" if model_head is None: model_head = get_default_head(params, weights_name) @@ -238,7 +239,13 @@ def get_model_fn(params, model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance - + if include_all_in_serving: + # In order to serve the variance we need to add the prediction dict + # to output_alternatives dict. + if not model_ops.output_alternatives: + model_ops.output_alternatives = {} + model_ops.output_alternatives[ALL_SERVING_KEY] = ( + constants.ProblemType.UNSPECIFIED, model_ops.predictions) return model_ops return _model_fn @@ -293,7 +300,8 @@ class TensorForestEstimator(estimator.Estimator): report_feature_importances=False, local_eval=False, version=None, - head=None): + head=None, + include_all_in_serving=False): """Initializes a TensorForestEstimator instance. Args: @@ -339,6 +347,23 @@ class TensorForestEstimator(estimator.Estimator): version: Unused. head: A heads_lib.Head object that calculates losses and such. If None, one will be automatically created based on params. + include_all_in_serving: if True, allow preparation of the complete + prediction dict including the variance to be exported for serving with + the Servo lib; and it also requires calling export_savedmodel with + default_output_alternative_key=ALL_SERVING_KEY, i.e. + estimator.export_savedmodel(export_dir_base=your_export_dir, + serving_input_fn=your_export_input_fn, + default_output_alternative_key=ALL_SERVING_KEY) + if False, resort to default behavior, i.e. export scores and + probabilities but no variances. In this case + default_output_alternative_key should be None while calling + export_savedmodel(). + Note, that due to backward compatibility we cannot always set + include_all_in_serving to True because in this case calling + export_saved_model() without + default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) the + saved_model_export_utils.get_output_alternatives() would raise + ValueError. Returns: A `TensorForestEstimator` instance. @@ -357,7 +382,9 @@ class TensorForestEstimator(estimator.Estimator): num_trainers=num_trainers, trainer_id=trainer_id, report_feature_importances=report_feature_importances, - local_eval=local_eval), + local_eval=local_eval, + include_all_in_serving=include_all_in_serving, + ), model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn)