From 3f43cabbb1b0db790a5cd23d24ccd858e8ea631e Mon Sep 17 00:00:00 2001 From: Alan Yee Date: Tue, 5 Sep 2017 09:17:37 -0700 Subject: [PATCH] Update head.py (#12548) Update contrib lookup to core lookup --- .../contrib/learn/python/learn/estimators/head.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index c31d5d2d47d..225d8796785 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -24,7 +24,6 @@ import six from tensorflow.contrib import framework as framework_lib from tensorflow.contrib import layers as layers_lib -from tensorflow.contrib import lookup as lookup_lib from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.estimators import prediction_key @@ -35,6 +34,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import nn @@ -1070,9 +1070,8 @@ class _MultiClassHead(_SingleHead): labels_tensor = _to_labels_tensor(labels, self._label_name) _check_no_sparse_tensor(labels_tensor) if self._label_keys: - table = lookup_lib.string_to_index_table_from_tensor( - mapping=self._label_keys, - name="label_id_lookup") + table = lookup_ops.index_table_from_tensor(self._label_keys, + name="label_id_lookup") return { "labels": labels_tensor, "label_ids": table.lookup(labels_tensor), @@ -1106,9 +1105,8 @@ class _MultiClassHead(_SingleHead): class_ids = math_ops.argmax( logits, 1, name=prediction_key.PredictionKey.CLASSES) if self._label_keys: - table = lookup_lib.index_to_string_table_from_tensor( - mapping=self._label_keys, - name="class_string_lookup") + table = lookup_ops.index_to_string_table_from_tensor( + self._label_keys, name="class_string_lookup") classes = table.lookup(class_ids) else: classes = class_ids