Adds a function in ModelFnOps to create an equivalent EstimatorSpec. This will allow existing custom estimators to easily switch to core estimator until we provide better support (like moving heads to core).
Change: 152133991
This commit is contained in:
parent
af21dee78f
commit
e2271157eb
@ -836,6 +836,19 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "model_fn_test",
|
||||
size = "small",
|
||||
srcs = ["python/learn/estimators/model_fn_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":learn",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "multioutput_test",
|
||||
size = "small",
|
||||
|
@ -25,10 +25,16 @@ import six
|
||||
|
||||
from tensorflow.contrib import framework as contrib_framework
|
||||
from tensorflow.contrib.framework import get_graph_from_inputs
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||
from tensorflow.python.estimator import model_fn as core_model_fn_lib
|
||||
from tensorflow.python.estimator.export import export_output as core_export_lib
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.training import session_run_hook
|
||||
|
||||
|
||||
@ -177,3 +183,85 @@ class ModelFnOps(
|
||||
training_chief_hooks=training_chief_hooks,
|
||||
training_hooks=training_hooks,
|
||||
scaffold=scaffold)
|
||||
|
||||
def estimator_spec(self, mode, default_serving_output_alternative_key=None):
|
||||
"""Creates an equivalent `EstimatorSpec`.
|
||||
|
||||
Args:
|
||||
mode: One of `ModeKeys`. Specifies if this training, evaluation or
|
||||
prediction.
|
||||
default_serving_output_alternative_key: Required for multiple heads. If
|
||||
you have multiple entries in `output_alternatives` dict (comparable to
|
||||
multiple heads), `EstimatorSpec` requires a default head that will be
|
||||
used if a Servo request does not explicitly mention which head to infer
|
||||
on. Pass the key of the output alternative here that you want to
|
||||
designate as default. A separate ExportOutpout for this default head
|
||||
wil be added to the export_outputs dict with the special key
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, unless there is
|
||||
already an enry in output_alternatives with this special key.
|
||||
|
||||
Returns:
|
||||
Instance of `EstimatorSpec` that is equivalent to this `ModelFnOps`
|
||||
|
||||
Raises:
|
||||
ValueError: If problem type is unknown.
|
||||
"""
|
||||
def _scores(output_tensors):
|
||||
scores = output_tensors.get(prediction_key.PredictionKey.SCORES)
|
||||
if scores is None:
|
||||
scores = output_tensors.get(prediction_key.PredictionKey.PROBABILITIES)
|
||||
return scores
|
||||
|
||||
def _classes(output_tensors): # pylint: disable=missing-docstring
|
||||
classes = output_tensors.get(prediction_key.PredictionKey.CLASSES)
|
||||
if classes is None:
|
||||
logging.warning(
|
||||
'classes is None, Servo inference will not have class ids.')
|
||||
return None
|
||||
elif classes.dtype != dtypes.string:
|
||||
# Servo classification can only serve string classes
|
||||
logging.warning(
|
||||
'classes is not string, Servo inference will not have class ids.')
|
||||
return None
|
||||
|
||||
return classes
|
||||
|
||||
def _export_output(problem_type, predictions): # pylint: disable=missing-docstring
|
||||
if problem_type == constants.ProblemType.LINEAR_REGRESSION:
|
||||
return core_export_lib.RegressionOutput(_scores(predictions))
|
||||
|
||||
if (problem_type == constants.ProblemType.CLASSIFICATION or
|
||||
problem_type == constants.ProblemType.LOGISTIC_REGRESSION):
|
||||
return core_export_lib.ClassificationOutput(
|
||||
scores=_scores(predictions), classes=_classes(predictions))
|
||||
|
||||
if problem_type == constants.ProblemType.UNSPECIFIED:
|
||||
return core_export_lib.PredictOutput(predictions)
|
||||
|
||||
raise ValueError('Unknown problem_type=%s' % problem_type)
|
||||
|
||||
# Converts output_alternatives
|
||||
export_outputs_dict = None
|
||||
if self.output_alternatives:
|
||||
output_alternatives = self.output_alternatives
|
||||
# Adds default output_alternative if needed.
|
||||
if (len(output_alternatives) > 1 and
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in
|
||||
output_alternatives):
|
||||
output_alternatives = output_alternatives.copy()
|
||||
output_alternatives[
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
|
||||
output_alternatives[default_serving_output_alternative_key])
|
||||
export_outputs_dict = {key: _export_output(*val) for key, val in
|
||||
output_alternatives.items()}
|
||||
|
||||
return core_model_fn_lib.EstimatorSpec(
|
||||
mode=mode,
|
||||
predictions=self.predictions,
|
||||
loss=self.loss,
|
||||
train_op=self.train_op,
|
||||
eval_metric_ops=self.eval_metric_ops,
|
||||
export_outputs=export_outputs_dict,
|
||||
training_chief_hooks=self.training_chief_hooks,
|
||||
training_hooks=self.training_hooks,
|
||||
scaffold=self.scaffold)
|
||||
|
@ -0,0 +1,279 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""ModelFnOps tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||
from tensorflow.contrib.learn.python.learn.estimators import model_fn
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.estimator.export import export_output as core_export_lib
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
from tensorflow.python.training import monitored_session
|
||||
|
||||
|
||||
class ModelFnopsTest(test.TestCase):
|
||||
"""Multi-output tests."""
|
||||
|
||||
def create_predictions(self):
|
||||
probabilities = constant_op.constant([1., 1., 1.])
|
||||
scores = constant_op.constant([1., 2., 3.])
|
||||
classes = constant_op.constant([b"0", b"1", b"2"])
|
||||
return {
|
||||
"probabilities": probabilities,
|
||||
"scores": scores,
|
||||
"classes": classes}
|
||||
|
||||
def create_model_fn_ops(self, predictions, output_alternatives,
|
||||
mode=model_fn.ModeKeys.INFER):
|
||||
|
||||
return model_fn.ModelFnOps(
|
||||
model_fn.ModeKeys.INFER,
|
||||
predictions=predictions,
|
||||
loss=constant_op.constant([1]),
|
||||
train_op=control_flow_ops.no_op(),
|
||||
eval_metric_ops={"metric_key": (control_flow_ops.no_op(),
|
||||
control_flow_ops.no_op())},
|
||||
# zzz
|
||||
training_chief_hooks=[basic_session_run_hooks.StepCounterHook()],
|
||||
training_hooks=[basic_session_run_hooks.StepCounterHook()],
|
||||
output_alternatives=output_alternatives,
|
||||
scaffold=monitored_session.Scaffold())
|
||||
|
||||
def assertEquals_except_export(self, model_fn_ops, estimator_spec):
|
||||
self.assertEqual(model_fn_ops.predictions, estimator_spec.predictions)
|
||||
self.assertEqual(model_fn_ops.loss, estimator_spec.loss)
|
||||
self.assertEqual(model_fn_ops.train_op, estimator_spec.train_op)
|
||||
self.assertEqual(model_fn_ops.eval_metric_ops,
|
||||
estimator_spec.eval_metric_ops)
|
||||
self.assertEqual(model_fn_ops.training_chief_hooks,
|
||||
estimator_spec.training_chief_hooks)
|
||||
self.assertEqual(model_fn_ops.training_hooks, estimator_spec.training_hooks)
|
||||
self.assertEqual(model_fn_ops.scaffold, estimator_spec.scaffold)
|
||||
|
||||
def testEstimatorSpec_except_export(self):
|
||||
predictions = self.create_predictions()
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, None)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
def testEstimatorSpec_export_regression_with_scores(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives = {"regression_head": (
|
||||
constants.ProblemType.LINEAR_REGRESSION, predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
regression_output = estimator_spec.export_outputs["regression_head"]
|
||||
self.assertTrue(isinstance(
|
||||
regression_output, core_export_lib.RegressionOutput))
|
||||
self.assertAllEqual(predictions["scores"].eval(),
|
||||
regression_output.value.eval())
|
||||
|
||||
def testEstimatorSpec_export_regression_with_probabilities(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives_predictions = predictions.copy()
|
||||
del output_alternatives_predictions["scores"]
|
||||
output_alternatives = {"regression_head": (
|
||||
constants.ProblemType.LINEAR_REGRESSION,
|
||||
output_alternatives_predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
regression_output = estimator_spec.export_outputs["regression_head"]
|
||||
self.assertTrue(isinstance(
|
||||
regression_output, core_export_lib.RegressionOutput))
|
||||
self.assertAllEqual(predictions["probabilities"].eval(),
|
||||
regression_output.value.eval())
|
||||
|
||||
def testEstimatorSpec_export_classsification(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives = {"classification_head": (
|
||||
constants.ProblemType.CLASSIFICATION, predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
classification_output = estimator_spec.export_outputs[
|
||||
"classification_head"]
|
||||
self.assertTrue(isinstance(classification_output,
|
||||
core_export_lib.ClassificationOutput))
|
||||
self.assertAllEqual(predictions["scores"].eval(),
|
||||
classification_output.scores.eval())
|
||||
self.assertAllEqual(predictions["classes"].eval(),
|
||||
classification_output.classes.eval())
|
||||
|
||||
def testEstimatorSpec_export_classsification_with_missing_scores(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives_predictions = predictions.copy()
|
||||
del output_alternatives_predictions["scores"]
|
||||
output_alternatives = {"classification_head": (
|
||||
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
classification_output = estimator_spec.export_outputs[
|
||||
"classification_head"]
|
||||
self.assertTrue(isinstance(classification_output,
|
||||
core_export_lib.ClassificationOutput))
|
||||
self.assertAllEqual(predictions["probabilities"].eval(),
|
||||
classification_output.scores.eval())
|
||||
self.assertAllEqual(predictions["classes"].eval(),
|
||||
classification_output.classes.eval())
|
||||
|
||||
def testEstimatorSpec_export_classsification_with_missing_scores_proba(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives_predictions = predictions.copy()
|
||||
del output_alternatives_predictions["scores"]
|
||||
del output_alternatives_predictions["probabilities"]
|
||||
output_alternatives = {"classification_head": (
|
||||
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
classification_output = estimator_spec.export_outputs[
|
||||
"classification_head"]
|
||||
self.assertTrue(isinstance(classification_output,
|
||||
core_export_lib.ClassificationOutput))
|
||||
self.assertIsNone(classification_output.scores)
|
||||
self.assertAllEqual(predictions["classes"].eval(),
|
||||
classification_output.classes.eval())
|
||||
|
||||
def testEstimatorSpec_export_classsification_with_missing_classes(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives_predictions = predictions.copy()
|
||||
del output_alternatives_predictions["classes"]
|
||||
output_alternatives = {"classification_head": (
|
||||
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
classification_output = estimator_spec.export_outputs[
|
||||
"classification_head"]
|
||||
self.assertTrue(isinstance(classification_output,
|
||||
core_export_lib.ClassificationOutput))
|
||||
self.assertAllEqual(predictions["scores"].eval(),
|
||||
classification_output.scores.eval())
|
||||
self.assertIsNone(classification_output.classes)
|
||||
|
||||
def testEstimatorSpec_export_classsification_with_nonstring_classes(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives_predictions = predictions.copy()
|
||||
output_alternatives_predictions["classes"] = constant_op.constant(
|
||||
[1, 2, 3])
|
||||
output_alternatives = {"classification_head": (
|
||||
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
classification_output = estimator_spec.export_outputs[
|
||||
"classification_head"]
|
||||
self.assertTrue(isinstance(classification_output,
|
||||
core_export_lib.ClassificationOutput))
|
||||
self.assertAllEqual(predictions["scores"].eval(),
|
||||
classification_output.scores.eval())
|
||||
self.assertIsNone(classification_output.classes)
|
||||
|
||||
def testEstimatorSpec_export_logistic(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives = {"logistic_head": (
|
||||
constants.ProblemType.LOGISTIC_REGRESSION, predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
logistic_output = estimator_spec.export_outputs["logistic_head"]
|
||||
self.assertTrue(isinstance(logistic_output,
|
||||
core_export_lib.ClassificationOutput))
|
||||
self.assertAllEqual(predictions["scores"].eval(),
|
||||
logistic_output.scores.eval())
|
||||
self.assertAllEqual(predictions["classes"].eval(),
|
||||
logistic_output.classes.eval())
|
||||
|
||||
def testEstimatorSpec_export_unspecified(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives = {"unspecified_head": (
|
||||
constants.ProblemType.UNSPECIFIED, predictions)}
|
||||
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
unspecified_output = estimator_spec.export_outputs["unspecified_head"]
|
||||
self.assertTrue(isinstance(unspecified_output,
|
||||
core_export_lib.PredictOutput))
|
||||
self.assertEqual(predictions, unspecified_output.outputs)
|
||||
|
||||
def testEstimatorSpec_export_multihead(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives = {
|
||||
"regression_head": (
|
||||
constants.ProblemType.LINEAR_REGRESSION, predictions),
|
||||
"classification_head": (
|
||||
constants.ProblemType.CLASSIFICATION, predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER,
|
||||
"regression_head")
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
regression_output = estimator_spec.export_outputs["regression_head"]
|
||||
self.assertTrue(isinstance(
|
||||
regression_output, core_export_lib.RegressionOutput))
|
||||
self.assertAllEqual(predictions["scores"].eval(),
|
||||
regression_output.value.eval())
|
||||
|
||||
default_output = estimator_spec.export_outputs[
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
||||
self.assertTrue(isinstance(default_output,
|
||||
core_export_lib.RegressionOutput))
|
||||
self.assertAllEqual(predictions["scores"].eval(),
|
||||
default_output.value.eval())
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
Loading…
x
Reference in New Issue
Block a user