Fix SavedModel export when predictions is a single tensor and output_alternatives not given
Change: 144470271
This commit is contained in:
parent
8f893368fc
commit
1c5120141b
@ -25,3 +25,4 @@ class PredictionKey(object):
|
|||||||
LOGISTIC = "logistic"
|
LOGISTIC = "logistic"
|
||||||
SCORES = "scores"
|
SCORES = "scores"
|
||||||
TOP_K = "top_k"
|
TOP_K = "top_k"
|
||||||
|
GENERIC = "output"
|
||||||
|
@ -167,6 +167,8 @@ def get_output_alternatives(
|
|||||||
# interpret the model as single-headed of unknown type.
|
# interpret the model as single-headed of unknown type.
|
||||||
default_problem_type = constants.ProblemType.UNSPECIFIED
|
default_problem_type = constants.ProblemType.UNSPECIFIED
|
||||||
default_outputs = model_fn_ops.predictions
|
default_outputs = model_fn_ops.predictions
|
||||||
|
if not isinstance(default_outputs, dict):
|
||||||
|
default_outputs = {prediction_key.PredictionKey.GENERIC: default_outputs}
|
||||||
actual_default_output_alternative_key = DEFAULT_OUTPUT_ALTERNATIVE_KEY
|
actual_default_output_alternative_key = DEFAULT_OUTPUT_ALTERNATIVE_KEY
|
||||||
output_alternatives = {actual_default_output_alternative_key:
|
output_alternatives = {actual_default_output_alternative_key:
|
||||||
(default_problem_type, default_outputs)}
|
(default_problem_type, default_outputs)}
|
||||||
|
@ -124,6 +124,21 @@ class SavedModelExportUtilsTest(test.TestCase):
|
|||||||
})
|
})
|
||||||
}, output_alternatives)
|
}, output_alternatives)
|
||||||
|
|
||||||
|
def test_get_output_alternatives_implicit_single(self):
|
||||||
|
prediction_tensor = constant_op.constant(["bogus"])
|
||||||
|
model_fn_ops = model_fn.ModelFnOps(
|
||||||
|
model_fn.ModeKeys.INFER,
|
||||||
|
predictions=prediction_tensor,
|
||||||
|
output_alternatives=None)
|
||||||
|
|
||||||
|
output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
|
||||||
|
model_fn_ops)
|
||||||
|
self.assertEqual({
|
||||||
|
"default_output_alternative": (constants.ProblemType.UNSPECIFIED, {
|
||||||
|
"output": prediction_tensor
|
||||||
|
})
|
||||||
|
}, output_alternatives)
|
||||||
|
|
||||||
def test_build_all_signature_defs(self):
|
def test_build_all_signature_defs(self):
|
||||||
input_features = constant_op.constant(["10"])
|
input_features = constant_op.constant(["10"])
|
||||||
input_example = constant_op.constant(["11"])
|
input_example = constant_op.constant(["11"])
|
||||||
|
Loading…
Reference in New Issue
Block a user