parent
5f95081698
commit
3f43cabbb1
@ -24,7 +24,6 @@ import six
|
|||||||
|
|
||||||
from tensorflow.contrib import framework as framework_lib
|
from tensorflow.contrib import framework as framework_lib
|
||||||
from tensorflow.contrib import layers as layers_lib
|
from tensorflow.contrib import layers as layers_lib
|
||||||
from tensorflow.contrib import lookup as lookup_lib
|
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import constants
|
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import model_fn
|
from tensorflow.contrib.learn.python.learn.estimators import model_fn
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||||
@ -35,6 +34,7 @@ from tensorflow.python.framework import sparse_tensor
|
|||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import logging_ops
|
from tensorflow.python.ops import logging_ops
|
||||||
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import metrics as metrics_lib
|
from tensorflow.python.ops import metrics as metrics_lib
|
||||||
from tensorflow.python.ops import nn
|
from tensorflow.python.ops import nn
|
||||||
@ -1070,8 +1070,7 @@ class _MultiClassHead(_SingleHead):
|
|||||||
labels_tensor = _to_labels_tensor(labels, self._label_name)
|
labels_tensor = _to_labels_tensor(labels, self._label_name)
|
||||||
_check_no_sparse_tensor(labels_tensor)
|
_check_no_sparse_tensor(labels_tensor)
|
||||||
if self._label_keys:
|
if self._label_keys:
|
||||||
table = lookup_lib.string_to_index_table_from_tensor(
|
table = lookup_ops.index_table_from_tensor(self._label_keys,
|
||||||
mapping=self._label_keys,
|
|
||||||
name="label_id_lookup")
|
name="label_id_lookup")
|
||||||
return {
|
return {
|
||||||
"labels": labels_tensor,
|
"labels": labels_tensor,
|
||||||
@ -1106,9 +1105,8 @@ class _MultiClassHead(_SingleHead):
|
|||||||
class_ids = math_ops.argmax(
|
class_ids = math_ops.argmax(
|
||||||
logits, 1, name=prediction_key.PredictionKey.CLASSES)
|
logits, 1, name=prediction_key.PredictionKey.CLASSES)
|
||||||
if self._label_keys:
|
if self._label_keys:
|
||||||
table = lookup_lib.index_to_string_table_from_tensor(
|
table = lookup_ops.index_to_string_table_from_tensor(
|
||||||
mapping=self._label_keys,
|
self._label_keys, name="class_string_lookup")
|
||||||
name="class_string_lookup")
|
|
||||||
classes = table.lookup(class_ids)
|
classes = table.lookup(class_ids)
|
||||||
else:
|
else:
|
||||||
classes = class_ids
|
classes = class_ids
|
||||||
|
Loading…
Reference in New Issue
Block a user