Adds a function in ModelFnOps to create an equivalent EstimatorSpec. This will allow existing custom estimators to easily switch to core estimator until we provide better support (like moving heads to core).

Change: 152133991
This commit is contained in:
Zakaria Haque 2017-04-04 07:25:26 -08:00 committed by TensorFlower Gardener
parent af21dee78f
commit e2271157eb
3 changed files with 381 additions and 1 deletions

View File

@ -836,6 +836,19 @@ py_test(
],
)
py_test(
name = "model_fn_test",
size = "small",
srcs = ["python/learn/estimators/model_fn_test.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//third_party/py/numpy",
],
)
py_test(
name = "multioutput_test",
size = "small",

View File

@ -25,10 +25,16 @@ import six
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.framework import get_graph_from_inputs
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.python.estimator import model_fn as core_model_fn_lib
from tensorflow.python.estimator.export import export_output as core_export_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import session_run_hook
@ -177,3 +183,85 @@ class ModelFnOps(
training_chief_hooks=training_chief_hooks,
training_hooks=training_hooks,
scaffold=scaffold)
def estimator_spec(self, mode, default_serving_output_alternative_key=None):
"""Creates an equivalent `EstimatorSpec`.
Args:
mode: One of `ModeKeys`. Specifies if this training, evaluation or
prediction.
default_serving_output_alternative_key: Required for multiple heads. If
you have multiple entries in `output_alternatives` dict (comparable to
multiple heads), `EstimatorSpec` requires a default head that will be
used if a Servo request does not explicitly mention which head to infer
on. Pass the key of the output alternative here that you want to
designate as default. A separate ExportOutpout for this default head
wil be added to the export_outputs dict with the special key
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, unless there is
already an enry in output_alternatives with this special key.
Returns:
Instance of `EstimatorSpec` that is equivalent to this `ModelFnOps`
Raises:
ValueError: If problem type is unknown.
"""
def _scores(output_tensors):
scores = output_tensors.get(prediction_key.PredictionKey.SCORES)
if scores is None:
scores = output_tensors.get(prediction_key.PredictionKey.PROBABILITIES)
return scores
def _classes(output_tensors): # pylint: disable=missing-docstring
classes = output_tensors.get(prediction_key.PredictionKey.CLASSES)
if classes is None:
logging.warning(
'classes is None, Servo inference will not have class ids.')
return None
elif classes.dtype != dtypes.string:
# Servo classification can only serve string classes
logging.warning(
'classes is not string, Servo inference will not have class ids.')
return None
return classes
def _export_output(problem_type, predictions): # pylint: disable=missing-docstring
if problem_type == constants.ProblemType.LINEAR_REGRESSION:
return core_export_lib.RegressionOutput(_scores(predictions))
if (problem_type == constants.ProblemType.CLASSIFICATION or
problem_type == constants.ProblemType.LOGISTIC_REGRESSION):
return core_export_lib.ClassificationOutput(
scores=_scores(predictions), classes=_classes(predictions))
if problem_type == constants.ProblemType.UNSPECIFIED:
return core_export_lib.PredictOutput(predictions)
raise ValueError('Unknown problem_type=%s' % problem_type)
# Converts output_alternatives
export_outputs_dict = None
if self.output_alternatives:
output_alternatives = self.output_alternatives
# Adds default output_alternative if needed.
if (len(output_alternatives) > 1 and
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in
output_alternatives):
output_alternatives = output_alternatives.copy()
output_alternatives[
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
output_alternatives[default_serving_output_alternative_key])
export_outputs_dict = {key: _export_output(*val) for key, val in
output_alternatives.items()}
return core_model_fn_lib.EstimatorSpec(
mode=mode,
predictions=self.predictions,
loss=self.loss,
train_op=self.train_op,
eval_metric_ops=self.eval_metric_ops,
export_outputs=export_outputs_dict,
training_chief_hooks=self.training_chief_hooks,
training_hooks=self.training_hooks,
scaffold=self.scaffold)

View File

@ -0,0 +1,279 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModelFnOps tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.python.client import session
from tensorflow.python.estimator.export import export_output as core_export_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
class ModelFnopsTest(test.TestCase):
"""Multi-output tests."""
def create_predictions(self):
probabilities = constant_op.constant([1., 1., 1.])
scores = constant_op.constant([1., 2., 3.])
classes = constant_op.constant([b"0", b"1", b"2"])
return {
"probabilities": probabilities,
"scores": scores,
"classes": classes}
def create_model_fn_ops(self, predictions, output_alternatives,
mode=model_fn.ModeKeys.INFER):
return model_fn.ModelFnOps(
model_fn.ModeKeys.INFER,
predictions=predictions,
loss=constant_op.constant([1]),
train_op=control_flow_ops.no_op(),
eval_metric_ops={"metric_key": (control_flow_ops.no_op(),
control_flow_ops.no_op())},
# zzz
training_chief_hooks=[basic_session_run_hooks.StepCounterHook()],
training_hooks=[basic_session_run_hooks.StepCounterHook()],
output_alternatives=output_alternatives,
scaffold=monitored_session.Scaffold())
def assertEquals_except_export(self, model_fn_ops, estimator_spec):
self.assertEqual(model_fn_ops.predictions, estimator_spec.predictions)
self.assertEqual(model_fn_ops.loss, estimator_spec.loss)
self.assertEqual(model_fn_ops.train_op, estimator_spec.train_op)
self.assertEqual(model_fn_ops.eval_metric_ops,
estimator_spec.eval_metric_ops)
self.assertEqual(model_fn_ops.training_chief_hooks,
estimator_spec.training_chief_hooks)
self.assertEqual(model_fn_ops.training_hooks, estimator_spec.training_hooks)
self.assertEqual(model_fn_ops.scaffold, estimator_spec.scaffold)
def testEstimatorSpec_except_export(self):
predictions = self.create_predictions()
model_fn_ops = self.create_model_fn_ops(predictions, None)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
def testEstimatorSpec_export_regression_with_scores(self):
predictions = self.create_predictions()
output_alternatives = {"regression_head": (
constants.ProblemType.LINEAR_REGRESSION, predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
regression_output = estimator_spec.export_outputs["regression_head"]
self.assertTrue(isinstance(
regression_output, core_export_lib.RegressionOutput))
self.assertAllEqual(predictions["scores"].eval(),
regression_output.value.eval())
def testEstimatorSpec_export_regression_with_probabilities(self):
predictions = self.create_predictions()
output_alternatives_predictions = predictions.copy()
del output_alternatives_predictions["scores"]
output_alternatives = {"regression_head": (
constants.ProblemType.LINEAR_REGRESSION,
output_alternatives_predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
regression_output = estimator_spec.export_outputs["regression_head"]
self.assertTrue(isinstance(
regression_output, core_export_lib.RegressionOutput))
self.assertAllEqual(predictions["probabilities"].eval(),
regression_output.value.eval())
def testEstimatorSpec_export_classsification(self):
predictions = self.create_predictions()
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
classification_output = estimator_spec.export_outputs[
"classification_head"]
self.assertTrue(isinstance(classification_output,
core_export_lib.ClassificationOutput))
self.assertAllEqual(predictions["scores"].eval(),
classification_output.scores.eval())
self.assertAllEqual(predictions["classes"].eval(),
classification_output.classes.eval())
def testEstimatorSpec_export_classsification_with_missing_scores(self):
predictions = self.create_predictions()
output_alternatives_predictions = predictions.copy()
del output_alternatives_predictions["scores"]
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
classification_output = estimator_spec.export_outputs[
"classification_head"]
self.assertTrue(isinstance(classification_output,
core_export_lib.ClassificationOutput))
self.assertAllEqual(predictions["probabilities"].eval(),
classification_output.scores.eval())
self.assertAllEqual(predictions["classes"].eval(),
classification_output.classes.eval())
def testEstimatorSpec_export_classsification_with_missing_scores_proba(self):
predictions = self.create_predictions()
output_alternatives_predictions = predictions.copy()
del output_alternatives_predictions["scores"]
del output_alternatives_predictions["probabilities"]
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
classification_output = estimator_spec.export_outputs[
"classification_head"]
self.assertTrue(isinstance(classification_output,
core_export_lib.ClassificationOutput))
self.assertIsNone(classification_output.scores)
self.assertAllEqual(predictions["classes"].eval(),
classification_output.classes.eval())
def testEstimatorSpec_export_classsification_with_missing_classes(self):
predictions = self.create_predictions()
output_alternatives_predictions = predictions.copy()
del output_alternatives_predictions["classes"]
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
classification_output = estimator_spec.export_outputs[
"classification_head"]
self.assertTrue(isinstance(classification_output,
core_export_lib.ClassificationOutput))
self.assertAllEqual(predictions["scores"].eval(),
classification_output.scores.eval())
self.assertIsNone(classification_output.classes)
def testEstimatorSpec_export_classsification_with_nonstring_classes(self):
predictions = self.create_predictions()
output_alternatives_predictions = predictions.copy()
output_alternatives_predictions["classes"] = constant_op.constant(
[1, 2, 3])
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
classification_output = estimator_spec.export_outputs[
"classification_head"]
self.assertTrue(isinstance(classification_output,
core_export_lib.ClassificationOutput))
self.assertAllEqual(predictions["scores"].eval(),
classification_output.scores.eval())
self.assertIsNone(classification_output.classes)
def testEstimatorSpec_export_logistic(self):
predictions = self.create_predictions()
output_alternatives = {"logistic_head": (
constants.ProblemType.LOGISTIC_REGRESSION, predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
logistic_output = estimator_spec.export_outputs["logistic_head"]
self.assertTrue(isinstance(logistic_output,
core_export_lib.ClassificationOutput))
self.assertAllEqual(predictions["scores"].eval(),
logistic_output.scores.eval())
self.assertAllEqual(predictions["classes"].eval(),
logistic_output.classes.eval())
def testEstimatorSpec_export_unspecified(self):
predictions = self.create_predictions()
output_alternatives = {"unspecified_head": (
constants.ProblemType.UNSPECIFIED, predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
unspecified_output = estimator_spec.export_outputs["unspecified_head"]
self.assertTrue(isinstance(unspecified_output,
core_export_lib.PredictOutput))
self.assertEqual(predictions, unspecified_output.outputs)
def testEstimatorSpec_export_multihead(self):
predictions = self.create_predictions()
output_alternatives = {
"regression_head": (
constants.ProblemType.LINEAR_REGRESSION, predictions),
"classification_head": (
constants.ProblemType.CLASSIFICATION, predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER,
"regression_head")
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
regression_output = estimator_spec.export_outputs["regression_head"]
self.assertTrue(isinstance(
regression_output, core_export_lib.RegressionOutput))
self.assertAllEqual(predictions["scores"].eval(),
regression_output.value.eval())
default_output = estimator_spec.export_outputs[
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
self.assertTrue(isinstance(default_output,
core_export_lib.RegressionOutput))
self.assertAllEqual(predictions["scores"].eval(),
default_output.value.eval())
if __name__ == "__main__":
test.main()