Prepare variance to be exported for serving with the servo library.
PiperOrigin-RevId: 183851026
This commit is contained in:
parent
7149a2e2e2
commit
8f0e720777
@ -18,7 +18,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.contrib import layers
|
from tensorflow.contrib import layers
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
|
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
|
||||||
@ -43,8 +43,8 @@ from tensorflow.python.training import training_util
|
|||||||
KEYS_NAME = 'keys'
|
KEYS_NAME = 'keys'
|
||||||
LOSS_NAME = 'rf_training_loss'
|
LOSS_NAME = 'rf_training_loss'
|
||||||
TREE_PATHS_PREDICTION_KEY = 'tree_paths'
|
TREE_PATHS_PREDICTION_KEY = 'tree_paths'
|
||||||
VARIANCE_PREDICTION_KEY = 'regression_variance'
|
VARIANCE_PREDICTION_KEY = 'prediction_variance'
|
||||||
|
ALL_SERVING_KEY = 'tensorforest_all'
|
||||||
EPSILON = 0.000001
|
EPSILON = 0.000001
|
||||||
|
|
||||||
|
|
||||||
@ -134,7 +134,8 @@ def get_model_fn(params,
|
|||||||
trainer_id=0,
|
trainer_id=0,
|
||||||
report_feature_importances=False,
|
report_feature_importances=False,
|
||||||
local_eval=False,
|
local_eval=False,
|
||||||
head_scope=None):
|
head_scope=None,
|
||||||
|
include_all_in_serving=False):
|
||||||
"""Return a model function given a way to construct a graph builder."""
|
"""Return a model function given a way to construct a graph builder."""
|
||||||
if model_head is None:
|
if model_head is None:
|
||||||
model_head = get_default_head(params, weights_name)
|
model_head = get_default_head(params, weights_name)
|
||||||
@ -238,7 +239,13 @@ def get_model_fn(params,
|
|||||||
model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
|
model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
|
||||||
|
|
||||||
model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
|
model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
|
||||||
|
if include_all_in_serving:
|
||||||
|
# In order to serve the variance we need to add the prediction dict
|
||||||
|
# to output_alternatives dict.
|
||||||
|
if not model_ops.output_alternatives:
|
||||||
|
model_ops.output_alternatives = {}
|
||||||
|
model_ops.output_alternatives[ALL_SERVING_KEY] = (
|
||||||
|
constants.ProblemType.UNSPECIFIED, model_ops.predictions)
|
||||||
return model_ops
|
return model_ops
|
||||||
|
|
||||||
return _model_fn
|
return _model_fn
|
||||||
@ -293,7 +300,8 @@ class TensorForestEstimator(estimator.Estimator):
|
|||||||
report_feature_importances=False,
|
report_feature_importances=False,
|
||||||
local_eval=False,
|
local_eval=False,
|
||||||
version=None,
|
version=None,
|
||||||
head=None):
|
head=None,
|
||||||
|
include_all_in_serving=False):
|
||||||
"""Initializes a TensorForestEstimator instance.
|
"""Initializes a TensorForestEstimator instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -339,6 +347,23 @@ class TensorForestEstimator(estimator.Estimator):
|
|||||||
version: Unused.
|
version: Unused.
|
||||||
head: A heads_lib.Head object that calculates losses and such. If None,
|
head: A heads_lib.Head object that calculates losses and such. If None,
|
||||||
one will be automatically created based on params.
|
one will be automatically created based on params.
|
||||||
|
include_all_in_serving: if True, allow preparation of the complete
|
||||||
|
prediction dict including the variance to be exported for serving with
|
||||||
|
the Servo lib; and it also requires calling export_savedmodel with
|
||||||
|
default_output_alternative_key=ALL_SERVING_KEY, i.e.
|
||||||
|
estimator.export_savedmodel(export_dir_base=your_export_dir,
|
||||||
|
serving_input_fn=your_export_input_fn,
|
||||||
|
default_output_alternative_key=ALL_SERVING_KEY)
|
||||||
|
if False, resort to default behavior, i.e. export scores and
|
||||||
|
probabilities but no variances. In this case
|
||||||
|
default_output_alternative_key should be None while calling
|
||||||
|
export_savedmodel().
|
||||||
|
Note, that due to backward compatibility we cannot always set
|
||||||
|
include_all_in_serving to True because in this case calling
|
||||||
|
export_saved_model() without
|
||||||
|
default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) the
|
||||||
|
saved_model_export_utils.get_output_alternatives() would raise
|
||||||
|
ValueError.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `TensorForestEstimator` instance.
|
A `TensorForestEstimator` instance.
|
||||||
@ -357,7 +382,9 @@ class TensorForestEstimator(estimator.Estimator):
|
|||||||
num_trainers=num_trainers,
|
num_trainers=num_trainers,
|
||||||
trainer_id=trainer_id,
|
trainer_id=trainer_id,
|
||||||
report_feature_importances=report_feature_importances,
|
report_feature_importances=report_feature_importances,
|
||||||
local_eval=local_eval),
|
local_eval=local_eval,
|
||||||
|
include_all_in_serving=include_all_in_serving,
|
||||||
|
),
|
||||||
model_dir=model_dir,
|
model_dir=model_dir,
|
||||||
config=config,
|
config=config,
|
||||||
feature_engineering_fn=feature_engineering_fn)
|
feature_engineering_fn=feature_engineering_fn)
|
||||||
|
Loading…
Reference in New Issue
Block a user