Update head.py (#12548)

Update contrib lookup to core lookup
This commit is contained in:
Alan Yee 2017-09-05 09:17:37 -07:00 committed by Martin Wicke
parent 5f95081698
commit 3f43cabbb1

View File

@ -24,7 +24,6 @@ import six
from tensorflow.contrib import framework as framework_lib
from tensorflow.contrib import layers as layers_lib
from tensorflow.contrib import lookup as lookup_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
@ -35,6 +34,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import nn
@ -1070,9 +1070,8 @@ class _MultiClassHead(_SingleHead):
labels_tensor = _to_labels_tensor(labels, self._label_name)
_check_no_sparse_tensor(labels_tensor)
if self._label_keys:
table = lookup_lib.string_to_index_table_from_tensor(
mapping=self._label_keys,
name="label_id_lookup")
table = lookup_ops.index_table_from_tensor(self._label_keys,
name="label_id_lookup")
return {
"labels": labels_tensor,
"label_ids": table.lookup(labels_tensor),
@ -1106,9 +1105,8 @@ class _MultiClassHead(_SingleHead):
class_ids = math_ops.argmax(
logits, 1, name=prediction_key.PredictionKey.CLASSES)
if self._label_keys:
table = lookup_lib.index_to_string_table_from_tensor(
mapping=self._label_keys,
name="class_string_lookup")
table = lookup_ops.index_to_string_table_from_tensor(
self._label_keys, name="class_string_lookup")
classes = table.lookup(class_ids)
else:
classes = class_ids