Prepare variance to be exported for serving with the servo library.
PiperOrigin-RevId: 183851026
This commit is contained in:
parent
7149a2e2e2
commit
8f0e720777
@ -18,7 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib import layers
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
|
||||
@ -43,8 +43,8 @@ from tensorflow.python.training import training_util
|
||||
KEYS_NAME = 'keys'
|
||||
LOSS_NAME = 'rf_training_loss'
|
||||
TREE_PATHS_PREDICTION_KEY = 'tree_paths'
|
||||
VARIANCE_PREDICTION_KEY = 'regression_variance'
|
||||
|
||||
VARIANCE_PREDICTION_KEY = 'prediction_variance'
|
||||
ALL_SERVING_KEY = 'tensorforest_all'
|
||||
EPSILON = 0.000001
|
||||
|
||||
|
||||
@ -134,7 +134,8 @@ def get_model_fn(params,
|
||||
trainer_id=0,
|
||||
report_feature_importances=False,
|
||||
local_eval=False,
|
||||
head_scope=None):
|
||||
head_scope=None,
|
||||
include_all_in_serving=False):
|
||||
"""Return a model function given a way to construct a graph builder."""
|
||||
if model_head is None:
|
||||
model_head = get_default_head(params, weights_name)
|
||||
@ -238,7 +239,13 @@ def get_model_fn(params,
|
||||
model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
|
||||
|
||||
model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
|
||||
|
||||
if include_all_in_serving:
|
||||
# In order to serve the variance we need to add the prediction dict
|
||||
# to output_alternatives dict.
|
||||
if not model_ops.output_alternatives:
|
||||
model_ops.output_alternatives = {}
|
||||
model_ops.output_alternatives[ALL_SERVING_KEY] = (
|
||||
constants.ProblemType.UNSPECIFIED, model_ops.predictions)
|
||||
return model_ops
|
||||
|
||||
return _model_fn
|
||||
@ -293,7 +300,8 @@ class TensorForestEstimator(estimator.Estimator):
|
||||
report_feature_importances=False,
|
||||
local_eval=False,
|
||||
version=None,
|
||||
head=None):
|
||||
head=None,
|
||||
include_all_in_serving=False):
|
||||
"""Initializes a TensorForestEstimator instance.
|
||||
|
||||
Args:
|
||||
@ -339,6 +347,23 @@ class TensorForestEstimator(estimator.Estimator):
|
||||
version: Unused.
|
||||
head: A heads_lib.Head object that calculates losses and such. If None,
|
||||
one will be automatically created based on params.
|
||||
include_all_in_serving: if True, allow preparation of the complete
|
||||
prediction dict including the variance to be exported for serving with
|
||||
the Servo lib; and it also requires calling export_savedmodel with
|
||||
default_output_alternative_key=ALL_SERVING_KEY, i.e.
|
||||
estimator.export_savedmodel(export_dir_base=your_export_dir,
|
||||
serving_input_fn=your_export_input_fn,
|
||||
default_output_alternative_key=ALL_SERVING_KEY)
|
||||
if False, resort to default behavior, i.e. export scores and
|
||||
probabilities but no variances. In this case
|
||||
default_output_alternative_key should be None while calling
|
||||
export_savedmodel().
|
||||
Note, that due to backward compatibility we cannot always set
|
||||
include_all_in_serving to True because in this case calling
|
||||
export_saved_model() without
|
||||
default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) the
|
||||
saved_model_export_utils.get_output_alternatives() would raise
|
||||
ValueError.
|
||||
|
||||
Returns:
|
||||
A `TensorForestEstimator` instance.
|
||||
@ -357,7 +382,9 @@ class TensorForestEstimator(estimator.Estimator):
|
||||
num_trainers=num_trainers,
|
||||
trainer_id=trainer_id,
|
||||
report_feature_importances=report_feature_importances,
|
||||
local_eval=local_eval),
|
||||
local_eval=local_eval,
|
||||
include_all_in_serving=include_all_in_serving,
|
||||
),
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
feature_engineering_fn=feature_engineering_fn)
|
||||
|
Loading…
Reference in New Issue
Block a user