Add TopN class.

Change: 130099876
A. Unique TensorFlower 2016-08-12 06:59:14 -08:00 committed by TensorFlower Gardener
parent 5c29a24fd6
commit 55b44e625c
6 changed files with 702 additions and 0 deletions


@@ -125,15 +125,24 @@ tf_custom_op_library(
deps = [":tree_utils"],
)
tf_custom_op_library(
name = "python/ops/_topn_ops.so",
srcs = [
"core/ops/topn_ops.cc",
],
)
py_library(
name = "ops_lib",
srcs = [
"__init__.py",
"python/ops/inference_ops.py",
"python/ops/topn_ops.py",
"python/ops/training_ops.py",
],
data = [
"python/ops/_inference_ops.so",
"python/ops/_topn_ops.so",
"python/ops/_training_ops.so",
],
srcs_version = "PY2AND3",
@@ -286,3 +295,26 @@ cc_test(
"//third_party/eigen3",
],
)
py_library(
name = "topn_py",
srcs = ["python/topn.py"],
srcs_version = "PY2AND3",
deps = [
":constants",
":ops_lib",
],
)
py_test(
name = "topn_test",
size = "small",
srcs = ["python/topn_test.py"],
srcs_version = "PY2AND3",
deps = [
":topn_py",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)


@@ -0,0 +1,265 @@
// Copyright 2016 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// The three Ops used to implement a TopN structure: Insert, Remove, and
// RefreshShortlist.
#include <algorithm>
#include <numeric>
#include <tuple>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
REGISTER_OP("TopNInsert")
.Input("ids: int64")
.Input("scores: float32")
.Input("new_ids: int64")
.Input("new_scores: float32")
.Output("shortlist_ids: int64")
.Output("update_ids: int64")
.Output("update_scores: float32")
.Doc(R"doc(
Outputs update Tensors for adding new_ids and new_scores to the shortlist.
ids:= A 1-D int64 tensor containing the ids on the shortlist (except for
ids[0], which is the current size of the shortlist).
scores:= A 1-D float32 tensor containing the scores on the shortlist.
new_ids:= A 1-D int64 tensor containing the new ids to add to the shortlist.
new_scores:= A 1-D float32 tensor containing the scores of the new ids.
shortlist_ids:= A 1-D int64 tensor containing the positions of the shortlist
entries to update. Intended to be used with
tf.scatter_update(shortlist_scores, shortlist_ids, new_scores).
update_ids:= A 1-D int64 tensor containing the new ids to write to the
positions in shortlist_ids; the entry for position 0 is the new shortlist size.
update_scores:= A 1-D float32 tensor containing the corresponding new scores;
the entry for position 0 is the new score cutoff.
)doc");
class TopNInsert : public OpKernel {
public:
explicit TopNInsert(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor& ids = context->input(0);
const Tensor& scores = context->input(1);
const Tensor& new_ids = context->input(2);
const Tensor& new_scores = context->input(3);
OP_REQUIRES(context, ids.shape().dims() == 1,
errors::InvalidArgument("ids should be one-dimensional"));
OP_REQUIRES(context, scores.shape().dims() == 1,
errors::InvalidArgument("scores should be one-dimensional"));
OP_REQUIRES(context, new_ids.shape().dims() == 1,
errors::InvalidArgument("new_ids should be one-dimensional"));
OP_REQUIRES(
context, new_scores.shape().dims() == 1,
errors::InvalidArgument("new_scores should be one-dimensional"));
OP_REQUIRES(
context, ids.shape().dim_size(0) == scores.shape().dim_size(0),
errors::InvalidArgument("ids and scores should be the same length"));
OP_REQUIRES(context,
new_ids.shape().dim_size(0) == new_scores.shape().dim_size(0),
errors::InvalidArgument(
"new_ids and new_scores should be the same length"));
const auto flat_ids = ids.unaligned_flat<int64>();
const auto flat_scores = scores.unaligned_flat<float>();
const auto flat_new_ids = new_ids.unaligned_flat<int64>();
const auto flat_new_scores = new_scores.unaligned_flat<float>();
const int num_updates = new_ids.shape().dim_size(0);
const int shortlist_max_size = ids.shape().dim_size(0) - 1;
int shortlist_size = std::max(0, static_cast<int>(flat_ids(0)));
int overflow = shortlist_size + num_updates - shortlist_max_size;
std::vector<std::tuple<int64, int64, float>> updates;
float score_cutoff = flat_scores(0);
if (overflow > 0) {
// Sort the *highest* overflow updates
std::vector<int> update_indices(num_updates);
for (int i = 0; i < num_updates; i++) {
update_indices[i] = i;
}
auto cmp = [&flat_new_scores](int a, int b) {
return flat_new_scores(a) > flat_new_scores(b);
};
std::sort(update_indices.begin(), update_indices.end(), cmp);
// Sort the *lowest* overflow shortlist entries
std::vector<int> shortlist_indices(shortlist_max_size + 1);
std::iota(shortlist_indices.begin() + 1, shortlist_indices.end(), 1);
auto cmp2 = [&flat_scores](int a, int b) {
return flat_scores(a) < flat_scores(b);
};
std::sort(shortlist_indices.begin() + 1, shortlist_indices.end(), cmp2);
int i = 0; // Points into update_indices
int j = 1; // Points into shortlist_indices
while (i < num_updates && j <= shortlist_max_size) {
VLOG(2) << "i = " << i;
VLOG(2) << "j = " << j;
VLOG(2) << "update_indices[i] = " << update_indices[i];
VLOG(2) << "shortlist_indices[j] = " << shortlist_indices[j];
VLOG(2) << "flat_new_scores(update_indices[i]) = "
<< flat_new_scores(update_indices[i]);
VLOG(2) << "flat_scores(shortlist_indices[j])) = "
<< flat_scores(shortlist_indices[j]);
if (flat_new_scores(update_indices[i]) >
flat_scores(shortlist_indices[j])) {
// Whenever we erase something from the shortlist, we need to
// update score_cutoff.
score_cutoff =
std::max(score_cutoff, flat_scores(shortlist_indices[j]));
updates.push_back(std::make_tuple(
shortlist_indices[j], flat_new_ids(update_indices[i]),
flat_new_scores(update_indices[i])));
if (flat_ids(shortlist_indices[j]) == -1) {
shortlist_size++;
}
j++;
} else {
// Whenever we fail to insert something into the shortlist, we need to
// update score_cutoff.
score_cutoff =
std::max(score_cutoff, flat_new_scores(update_indices[i]));
}
i++;
}
} else {
// Everything fits, no need to sort.
int j = 1;
for (int i = 0; i < num_updates; i++) {
if (flat_new_scores(i) < score_cutoff) {
continue;
}
while (j <= shortlist_max_size && flat_ids(j) != -1) {
j++;
}
if (j > shortlist_max_size) {
LOG(FATAL) << "Bug";
}
updates.push_back(
std::make_tuple(j, flat_new_ids(i), flat_new_scores(i)));
j++;
shortlist_size++;
}
}
updates.push_back(std::make_tuple(0, shortlist_size, score_cutoff));
Tensor* output_shortlist_ids = nullptr;
TensorShape shortlist_ids_shape;
shortlist_ids_shape.AddDim(updates.size());
OP_REQUIRES_OK(context, context->allocate_output(0, shortlist_ids_shape,
&output_shortlist_ids));
auto shortlist_ids_flat = output_shortlist_ids->tensor<int64, 1>();
Tensor* output_ids = nullptr;
TensorShape ids_shape;
ids_shape.AddDim(updates.size());
OP_REQUIRES_OK(context,
context->allocate_output(1, ids_shape, &output_ids));
auto output_ids_flat = output_ids->tensor<int64, 1>();
Tensor* output_scores = nullptr;
TensorShape scores_shape;
scores_shape.AddDim(updates.size());
OP_REQUIRES_OK(context,
context->allocate_output(2, scores_shape, &output_scores));
auto output_scores_flat = output_scores->tensor<float, 1>();
int i = 0;
for (const auto& update : updates) {
shortlist_ids_flat(i) = std::get<0>(update);
output_ids_flat(i) = std::get<1>(update);
output_scores_flat(i) = std::get<2>(update);
i++;
}
}
};
REGISTER_OP("TopNRemove")
.Input("ids: int64")
.Input("remove_ids: int64")
.Output("shortlist_ids: int64")
.Output("new_length: int64")
.Doc(R"doc(
Remove ids from a shortlist.
ids:= A 1-D int64 tensor containing the ids on the shortlist (except for
ids[0], which is the current size of the shortlist).
remove_ids:= A 1-D int64 tensor containing the ids to remove.
shortlist_ids:= A 1-D int64 tensor containing the positions of the shortlist
entries that need to be removed.
new_length:= A 1-D int64 tensor of length 1 containing the new length of the
shortlist.
)doc");
class TopNRemove : public OpKernel {
public:
explicit TopNRemove(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor& ids = context->input(0);
const Tensor& remove_ids = context->input(1);
OP_REQUIRES(context, ids.shape().dims() == 1,
errors::InvalidArgument("ids should be one-dimensional"));
OP_REQUIRES(
context, remove_ids.shape().dims() == 1,
errors::InvalidArgument("remove_ids should be one-dimensional"));
const auto flat_ids = ids.unaligned_flat<int64>();
const auto flat_remove_ids = remove_ids.unaligned_flat<int64>();
const int num_to_remove = remove_ids.shape().dim_size(0);
const int shortlist_max_size = ids.shape().dim_size(0);
// First, turn remove_ids into a set for easy membership checking.
std::unordered_set<int64> ids_to_remove(
flat_remove_ids.data(), flat_remove_ids.data() + num_to_remove);
std::vector<int64> updates;
int shortlist_size = std::max(0, static_cast<int>(flat_ids(0)));
for (int j = 1; j < shortlist_max_size; j++) {
if (ids_to_remove.find(flat_ids(j)) != ids_to_remove.end()) {
shortlist_size--;
updates.push_back(j);
}
}
Tensor* output_shortlist_ids = nullptr;
TensorShape shortlist_ids_shape;
shortlist_ids_shape.AddDim(updates.size());
OP_REQUIRES_OK(context, context->allocate_output(0, shortlist_ids_shape,
&output_shortlist_ids));
auto shortlist_ids_flat = output_shortlist_ids->tensor<int64, 1>();
std::copy(updates.begin(), updates.end(), shortlist_ids_flat.data());
Tensor* new_length = nullptr;
TensorShape new_length_shape;
new_length_shape.AddDim(1);
OP_REQUIRES_OK(context,
context->allocate_output(1, new_length_shape, &new_length));
new_length->tensor<int64, 1>()(0) = shortlist_size;
}
};
REGISTER_KERNEL_BUILDER(Name("TopNInsert").Device(DEVICE_CPU), TopNInsert);
REGISTER_KERNEL_BUILDER(Name("TopNRemove").Device(DEVICE_CPU), TopNRemove);
} // namespace tensorflow


@@ -21,4 +21,5 @@ from __future__ import print_function
from tensorflow.contrib.tensor_forest.python import constants
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.contrib.tensor_forest.python.ops import inference_ops
from tensorflow.contrib.tensor_forest.python.ops import topn_ops
from tensorflow.contrib.tensor_forest.python.ops import training_ops


@@ -0,0 +1,62 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for TopN class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import tensorflow as tf
from tensorflow.python.framework import ops
TOPN_OPS_FILE = '_topn_ops.so'
_topn_ops = None
_ops_lock = threading.Lock()
ops.NoGradient('TopNInsert')
ops.NoGradient('TopNRemove')
@ops.RegisterShape('TopNInsert')
def Insert(unused_op):
"""Shape function for Insert Op."""
return [[None], [None], [None]]
@ops.RegisterShape('TopNRemove')
def Remove(unused_op):
"""Shape function for Remove Op."""
return [[None], [None]]
# Workaround for the fact that importing tensorflow imports contrib
# (even if a user isn't using this or any other contrib op), but
# there's not yet any guarantee that the shared object exists.
# Loading it lazily here keeps "import tensorflow" from crashing for
# users that never use contrib.
def Load():
"""Load the TopN ops library and return the loaded module."""
with _ops_lock:
global _topn_ops
if not _topn_ops:
ops_path = tf.resource_loader.get_path_to_datafile(TOPN_OPS_FILE)
tf.logging.info('data path: %s', ops_path)
_topn_ops = tf.load_op_library(ops_path)
assert _topn_ops, 'Could not load topn_ops.so'
return _topn_ops
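# A typical (illustrative) call site; topn.TopN in this change is the real
# consumer of these ops, and the argument names here are placeholders:
#   ops = topn_ops.Load()
#   shortlist_ids, new_ids, new_scores = ops.top_n_insert(
#       sl_ids, sl_scores, ids_to_add, scores_to_add)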


@@ -0,0 +1,159 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A collection that allows repeated access to its top-scoring items."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.tensor_forest.python.ops import topn_ops
class TopN(object):
"""A collection that allows repeated access to its top-scoring items.
A TopN supports the following three operations:
1) insert(ids, scores). ids is a 1-d int64 Tensor and scores is a 1-d
float Tensor. scores[i] is the score associated with ids[i]. It is
totally fine to re-insert ids that have already been inserted into the
collection.
2) remove(ids)
3) ids, scores = get_best(n). scores will contain the n highest (most
positive) scores currently in the TopN, and ids their corresponding ids.
n is a 1-d int32 Tensor with shape (1).
TopN is implemented using a short-list of the top scoring items. At
construction time, the size of the short-list must be specified, and it
is an error to call get_best(n) with an n greater than that size.
"""
def __init__(self, max_id, shortlist_size=100, name_prefix=''):
"""Creates a new TopN."""
self.ops = topn_ops.Load()
self.shortlist_size = shortlist_size
# id_to_score contains all the scores we are tracking.
self.id_to_score = tf.get_variable(
name=name_prefix + 'id_to_score',
dtype=tf.float32,
shape=[max_id],
initializer=tf.constant_initializer(tf.float32.min))
# sl_ids and sl_scores together satisfy four invariants:
# 1) If sl_ids[i] != -1, then
# id_to_score[sl_ids[i]] = sl_scores[i] >= sl_scores[0]
# 2) sl_ids[0] is the number of i > 0 for which sl_ids[i] != -1.
# 3) If id_to_score[i] > sl_scores[0], then
# sl_ids[j] = i for some j.
# 4) If sl_ids[i] == -1, then sl_scores[i] = tf.float32.min.
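# An illustrative state satisfying these invariants (not taken from the
# tests), for shortlist_size = 3, writing tf.float32.min as MIN:
#   sl_ids    = [2, 7, -1, 4]
#   sl_scores = [55.0, 80.0, MIN, 60.0]
# i.e. two live entries (ids 7 and 4), and sl_scores[0] = 55.0 acts as the
# score cutoff: any id whose tracked score exceeds 55.0 must appear
# somewhere on the shortlist.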
self.sl_ids = tf.get_variable(
name=name_prefix + 'shortlist_ids',
dtype=tf.int64,
shape=[shortlist_size + 1],
initializer=tf.constant_initializer(-1))
# Ideally, we would set self.sl_ids[0] = 0 here. But then it is hard
# to pass that control dependency to the other Ops. Instead, we
# have insert, remove and get_best all deal with the fact that
# self.sl_ids[0] == -1 actually means the shortlist size is 0.
self.sl_scores = tf.get_variable(
name=name_prefix + 'shortlist_scores',
dtype=tf.float32,
shape=[shortlist_size + 1],
initializer=tf.constant_initializer(tf.float32.min))
# TopN keeps track of its internal data dependencies, so the user
# doesn't have to.
self.last_ops = []
def insert(self, ids, scores):
"""Insert the ids and scores into the TopN."""
with tf.control_dependencies(self.last_ops):
scatter_op = tf.scatter_update(self.id_to_score, ids, scores)
larger_scores = tf.greater(scores, self.sl_scores[0])
def shortlist_insert():
larger_ids = tf.boolean_mask(tf.to_int64(ids), larger_scores)
larger_score_values = tf.boolean_mask(scores, larger_scores)
shortlist_ids, new_ids, new_scores = self.ops.top_n_insert(
self.sl_ids, self.sl_scores, larger_ids, larger_score_values)
u1 = tf.scatter_update(self.sl_ids, shortlist_ids, new_ids)
u2 = tf.scatter_update(self.sl_scores, shortlist_ids, new_scores)
return tf.group(u1, u2)
# We only need to insert into the shortlist if there are any
# scores larger than the threshold.
cond_op = tf.cond(
tf.reduce_any(larger_scores), shortlist_insert, tf.no_op)
with tf.control_dependencies([cond_op]):
self.last_ops = [scatter_op, cond_op]
def remove(self, ids):
"""Remove the ids (and their associated scores) from the TopN."""
with tf.control_dependencies(self.last_ops):
scatter_op = tf.scatter_update(
self.id_to_score,
ids,
tf.ones_like(
ids, dtype=tf.float32) * tf.float32.min)
# We assume that removed ids are almost always in the shortlist,
# so it makes no sense to hide the Op behind a tf.cond.
shortlist_ids_to_remove, new_length = self.ops.top_n_remove(self.sl_ids,
ids)
u1 = tf.scatter_update(
self.sl_ids, tf.concat(0, [[0], shortlist_ids_to_remove]),
tf.concat(0, [new_length,
tf.ones_like(shortlist_ids_to_remove) * -1]))
u2 = tf.scatter_update(
self.sl_scores,
shortlist_ids_to_remove,
tf.float32.min * tf.ones_like(
shortlist_ids_to_remove, dtype=tf.float32))
self.last_ops = [scatter_op, u1, u2]
def get_best(self, n):
"""Return the indices and values of the n highest scores in the TopN."""
def refresh_shortlist():
"""Update the shortlist with the highest scores in id_to_score."""
new_scores, new_ids = tf.nn.top_k(self.id_to_score, self.shortlist_size)
smallest_new_score = tf.reduce_min(new_scores)
new_length = tf.reduce_sum(
tf.to_int32(tf.greater(new_scores, tf.float32.min)))
u1 = self.sl_ids.assign(
tf.to_int64(tf.concat(0, [[new_length], new_ids])))
u2 = self.sl_scores.assign(
tf.concat(0, [[smallest_new_score], new_scores]))
self.last_ops = [u1, u2]
return tf.group(u1, u2)
# We only need to refresh the shortlist if n is greater than the
# current shortlist size (which is stored in sl_ids[0]).
with tf.control_dependencies(self.last_ops):
cond_op = tf.cond(n > self.sl_ids[0], refresh_shortlist, tf.no_op)
with tf.control_dependencies([cond_op]):
topk_values, topk_indices = tf.nn.top_k(
self.sl_scores, tf.minimum(n, tf.to_int32(self.sl_ids[0])))
# topk_indices are indices into the shortlist; we want to return
# the corresponding indices into id_to_score.
gathered_indices = tf.gather(self.sl_ids, topk_indices)
return gathered_indices, topk_values
def get_and_remove_best(self, n):
"""Like get_best, but the returned ids are also removed from the TopN."""
# TODO(thomaswc): Replace this with a version of get_best where
# refresh_shortlist grabs the top n + shortlist_size.
top_ids, unused_top_vals = self.get_best(n)
self.remove(top_ids)
with tf.control_dependencies(self.last_ops):
return tf.identity(top_ids)
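# A minimal usage sketch (mirroring TopNTest.testSimple in topn_test.py; the
# session and initializer calls reflect the TF API of this era):
#   t = TopN(max_id=1000, shortlist_size=10)
#   t.insert([1, 2, 3, 4, 5], [1.0, 2.0, 3.0, 4.0, 5.0])
#   t.remove([4, 5])
#   ids, vals = t.get_best(2)
#   with tf.Session() as sess:
#     sess.run(tf.initialize_all_variables())
#     ids_v, vals_v = sess.run([ids, vals])  # ids 3 and 2, scores 3.0 and 2.0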


@@ -0,0 +1,183 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for topn.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.tensor_forest.python import topn
from tensorflow.contrib.tensor_forest.python.ops import topn_ops
from tensorflow.python.client import session
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
class TopNOpsTest(test_util.TensorFlowTestCase):
def setUp(self):
self.ops = topn_ops.Load()
def testInsertOpIntoEmptyShortlist(self):
with self.test_session():
shortlist_ids, new_ids, new_scores = self.ops.top_n_insert(
[0, -1, -1, -1, -1, -1], # sl_ids
[-999, -999, -999, -999, -999, -999], # sl_scores
[5],
[33.0] # new id and score
)
self.assertAllEqual([1, 0], shortlist_ids.eval())
self.assertAllEqual([5, 1], new_ids.eval())
self.assertAllEqual([33.0, -999], new_scores.eval())
def testInsertOpIntoAlmostFullShortlist(self):
with self.test_session():
shortlist_ids, new_ids, new_scores = self.ops.top_n_insert(
[4, 13, -1, 27, 99, 15], # sl_ids
[60.0, 87.0, -999, 65.0, 1000.0, 256.0], # sl_scores
[5],
[93.0] # new id and score
)
self.assertAllEqual([2, 0], shortlist_ids.eval())
self.assertAllEqual([5, 5], new_ids.eval())
# Shortlist still contains all known scores > 60.0
self.assertAllEqual([93.0, 60.0], new_scores.eval())
def testInsertOpIntoFullShortlist(self):
with self.test_session():
shortlist_ids, new_ids, new_scores = self.ops.top_n_insert(
[5, 13, 44, 27, 99, 15], # sl_ids
[60.0, 87.0, 111.0, 65.0, 1000.0, 256.0], # sl_scores
[5],
[93.0] # new id and score
)
self.assertAllEqual([3, 0], shortlist_ids.eval())
self.assertAllEqual([5, 5], new_ids.eval())
# We removed a 65.0 from the list, so now we can only claim that
# it holds all scores > 65.0.
self.assertAllEqual([93.0, 65.0], new_scores.eval())
def testInsertOpHard(self):
with self.test_session():
shortlist_ids, new_ids, new_scores = self.ops.top_n_insert(
[4, 13, -1, 27, 99, 15], # sl_ids
[60.0, 87.0, -999, 65.0, 1000.0, 256.0], # sl_scores
[5, 6, 7, 8, 9],
[61.0, 66.0, 90.0, 100.0, 2000.0] # new id and score
)
# Top 5 scores are: 2000.0, 1000.0, 256.0, 100.0, 90.0
self.assertAllEqual([2, 3, 1, 0], shortlist_ids.eval())
self.assertAllEqual([9, 8, 7, 5], new_ids.eval())
# 87.0 is the highest score we overwrote or didn't insert.
self.assertAllEqual([2000.0, 100.0, 90.0, 87.0], new_scores.eval())
def testRemoveSimple(self):
with self.test_session():
shortlist_ids, new_length = self.ops.top_n_remove(
[5, 100, 200, 300, 400, 500], [200, 400, 600])
self.assertAllEqual([2, 4], shortlist_ids.eval())
self.assertAllEqual([3], new_length.eval())
def testRemoveAllMissing(self):
with self.test_session():
shortlist_ids, new_length = self.ops.top_n_remove(
[5, 100, 200, 300, 400, 500], [1200, 1400, 600])
self.assertAllEqual([], shortlist_ids.eval())
self.assertAllEqual([5], new_length.eval())
def testRemoveAll(self):
with self.test_session():
shortlist_ids, new_length = self.ops.top_n_remove(
[5, 100, 200, 300, 400, 500],
[100, 200, 300, 400, 500],)
self.assertAllEqual([1, 2, 3, 4, 5], shortlist_ids.eval())
self.assertAllEqual([0], new_length.eval())
class TopNTest(test_util.TensorFlowTestCase):
def testSimple(self):
t = topn.TopN(1000, shortlist_size=10)
t.insert([1, 2, 3, 4, 5], [1.0, 2.0, 3.0, 4.0, 5.0])
t.remove([4, 5])
ids, vals = t.get_best(2)
with session.Session() as sess:
sess.run(tf.initialize_all_variables())
ids_v, vals_v = sess.run([ids, vals])
self.assertItemsEqual([2, 3], list(ids_v))
self.assertItemsEqual([2.0, 3.0], list(vals_v))
def testSimpler(self):
t = topn.TopN(1000, shortlist_size=10)
t.insert([1], [33.0])
ids, vals = t.get_best(1)
with session.Session() as sess:
sess.run(tf.initialize_all_variables())
ids_v, vals_v = sess.run([ids, vals])
self.assertListEqual([1], list(ids_v))
self.assertListEqual([33.0], list(vals_v))
def testLotsOfInsertsAscending(self):
t = topn.TopN(1000, shortlist_size=10)
for i in range(100):
t.insert([i], [float(i)])
ids, vals = t.get_best(5)
with session.Session() as sess:
sess.run(tf.initialize_all_variables())
ids_v, vals_v = sess.run([ids, vals])
self.assertItemsEqual([95, 96, 97, 98, 99], list(ids_v))
self.assertItemsEqual([95.0, 96.0, 97.0, 98.0, 99.0], list(vals_v))
def testLotsOfInsertsDescending(self):
t = topn.TopN(1000, shortlist_size=10)
for i in range(99, 1, -1):
t.insert([i], [float(i)])
ids, vals = t.get_best(5)
with session.Session() as sess:
sess.run(tf.initialize_all_variables())
ids_v, vals_v = sess.run([ids, vals])
self.assertItemsEqual([95, 96, 97, 98, 99], list(ids_v))
self.assertItemsEqual([95.0, 96.0, 97.0, 98.0, 99.0], list(vals_v))
def testRemoveNotInShortlist(self):
t = topn.TopN(1000, shortlist_size=10)
for i in range(20):
t.insert([i], [float(i)])
t.remove([4, 5])
ids, vals = t.get_best(2)
with session.Session() as sess:
sess.run(tf.initialize_all_variables())
ids_v, vals_v = sess.run([ids, vals])
self.assertItemsEqual([18.0, 19.0], list(vals_v))
self.assertItemsEqual([18, 19], list(ids_v))
def testNeedToRefreshShortlistInGetBest(self):
t = topn.TopN(1000, shortlist_size=10)
for i in range(20):
t.insert([i], [float(i)])
# Shortlist now has 10 .. 19
t.remove([11, 12, 13, 14, 15, 16, 17, 18, 19])
ids, vals = t.get_best(2)
with session.Session() as sess:
sess.run(tf.initialize_all_variables())
ids_v, vals_v = sess.run([ids, vals])
self.assertItemsEqual([9, 10], list(ids_v))
self.assertItemsEqual([9.0, 10.0], list(vals_v))
if __name__ == '__main__':
googletest.main()