Add TopN class.
Change: 130099876
This commit is contained in:
parent
5c29a24fd6
commit
55b44e625c
@ -125,15 +125,24 @@ tf_custom_op_library(
|
||||
deps = [":tree_utils"],
|
||||
)
|
||||
|
||||
# Custom-op shared object built from the TopN kernel source.
tf_custom_op_library(
    name = "python/ops/_topn_ops.so",
    srcs = ["core/ops/topn_ops.cc"],
)
|
||||
|
||||
py_library(
|
||||
name = "ops_lib",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"python/ops/inference_ops.py",
|
||||
"python/ops/topn_ops.py",
|
||||
"python/ops/training_ops.py",
|
||||
],
|
||||
data = [
|
||||
"python/ops/_inference_ops.so",
|
||||
"python/ops/_topn_ops.so",
|
||||
"python/ops/_training_ops.so",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
@ -286,3 +295,26 @@ cc_test(
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
# Python TopN wrapper; the op library it loads is shipped as data of :ops_lib.
py_library(
    name = "topn_py",
    srcs = ["python/topn.py"],
    srcs_version = "PY2AND3",
    deps = [":constants", ":ops_lib"],
)
|
||||
|
||||
# Unit tests for python/topn.py and the underlying TopN ops.
py_test(
    name = "topn_test",
    size = "small",
    srcs = ["python/topn_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":topn_py",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)
|
||||
|
||||
265
tensorflow/contrib/tensor_forest/core/ops/topn_ops.cc
Normal file
265
tensorflow/contrib/tensor_forest/core/ops/topn_ops.cc
Normal file
@ -0,0 +1,265 @@
|
||||
// Copyright 2016 Google Inc. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
// The three Ops used to implement a TopN structure: Insert, Remove, and
|
||||
// RefreshShortlist.
|
||||
|
||||
#include <algorithm>
#include <numeric>
#include <tuple>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Op definition for TopNInsert.  The three outputs have the same length and
// are meant to be applied to the shortlist variables with scatter updates.
REGISTER_OP("TopNInsert")
    .Input("ids: int64")
    .Input("scores: float32")
    .Input("new_ids: int64")
    .Input("new_scores: float32")
    .Output("shortlist_ids: int64")
    .Output("update_ids: int64")
    .Output("update_scores: float32")

    .Doc(R"doc(
  Outputs update Tensors for adding new_ids and new_scores to the shortlist.

  ids:= A 1-D int64 tensor containing the ids on the shortlist (except for
    ids[0], which is the current size of the shortlist).
  scores:= A 1-D float32 tensor containing the scores on the shortlist (except
    for scores[0], which is the current score cutoff of the shortlist).
  new_ids:= A 1-D int64 tensor containing the new ids to add to the shortlist.
  new_scores:= A 1-D float32 tensor containing the scores of new_ids;
    new_scores[i] is the score of new_ids[i].
  shortlist_ids:= A 1-D int64 tensor containing the indices of the shortlist
    entries to update.  The last element is always 0.  Intended to be used
    with tf.scatter_update(ids, shortlist_ids, update_ids) and
    tf.scatter_update(scores, shortlist_ids, update_scores).
  update_ids:= A 1-D int64 tensor containing the ids to write at
    shortlist_ids.  The last element is the new size of the shortlist.
  update_scores:= A 1-D float32 tensor containing the scores to write at
    shortlist_ids.  The last element is the new score cutoff.
)doc");
|
||||
|
||||
class TopNInsert : public OpKernel {
|
||||
public:
|
||||
explicit TopNInsert(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& ids = context->input(0);
|
||||
const Tensor& scores = context->input(1);
|
||||
const Tensor& new_ids = context->input(2);
|
||||
const Tensor& new_scores = context->input(3);
|
||||
|
||||
OP_REQUIRES(context, ids.shape().dims() == 1,
|
||||
errors::InvalidArgument("ids should be one-dimensional"));
|
||||
OP_REQUIRES(context, scores.shape().dims() == 1,
|
||||
errors::InvalidArgument("scores should be one-dimensional"));
|
||||
OP_REQUIRES(context, new_ids.shape().dims() == 1,
|
||||
errors::InvalidArgument("new_ids should be one-dimensional"));
|
||||
OP_REQUIRES(
|
||||
context, new_scores.shape().dims() == 1,
|
||||
errors::InvalidArgument("new_scores should be one-dimensional"));
|
||||
|
||||
OP_REQUIRES(
|
||||
context, ids.shape().dim_size(0) == scores.shape().dim_size(0),
|
||||
errors::InvalidArgument("ids and scores should be the same length"));
|
||||
OP_REQUIRES(context,
|
||||
new_ids.shape().dim_size(0) == new_scores.shape().dim_size(0),
|
||||
errors::InvalidArgument(
|
||||
"new_ids and new_scores should be the same length"));
|
||||
|
||||
const auto flat_ids = ids.unaligned_flat<int64>();
|
||||
const auto flat_scores = scores.unaligned_flat<float>();
|
||||
const auto flat_new_ids = new_ids.unaligned_flat<int64>();
|
||||
const auto flat_new_scores = new_scores.unaligned_flat<float>();
|
||||
|
||||
const int num_updates = new_ids.shape().dim_size(0);
|
||||
const int shortlist_max_size = ids.shape().dim_size(0) - 1;
|
||||
int shortlist_size = std::max(0, static_cast<int>(flat_ids(0)));
|
||||
int overflow = shortlist_size + num_updates - shortlist_max_size;
|
||||
|
||||
std::vector<std::tuple<int64, int64, float>> updates;
|
||||
float score_cutoff = flat_scores(0);
|
||||
|
||||
if (overflow > 0) {
|
||||
// Sort the *highest* overflow updates
|
||||
std::vector<int> update_indices(num_updates);
|
||||
for (int i = 0; i < num_updates; i++) {
|
||||
update_indices[i] = i;
|
||||
}
|
||||
auto cmp = [&flat_new_scores](int a, int b) {
|
||||
return flat_new_scores(a) > flat_new_scores(b);
|
||||
};
|
||||
std::sort(update_indices.begin(), update_indices.end(), cmp);
|
||||
|
||||
// Sort the *lowest* overflow shortlist entries
|
||||
std::vector<int> shortlist_indices(shortlist_max_size + 1);
|
||||
std::iota(shortlist_indices.begin() + 1, shortlist_indices.end(), 1);
|
||||
auto cmp2 = [&flat_scores](int a, int b) {
|
||||
return flat_scores(a) < flat_scores(b);
|
||||
};
|
||||
std::sort(shortlist_indices.begin() + 1, shortlist_indices.end(), cmp2);
|
||||
|
||||
int i = 0; // Points into update_indices
|
||||
int j = 1; // Points into shortlist_indices
|
||||
while (i < num_updates && j <= shortlist_max_size) {
|
||||
VLOG(2) << "i = " << i;
|
||||
VLOG(2) << "j = " << j;
|
||||
VLOG(2) << "update_indices[i] = " << update_indices[i];
|
||||
VLOG(2) << "shortlist_indices[j] = " << shortlist_indices[j];
|
||||
VLOG(2) << "flat_new_scores(update_indices[i]) = "
|
||||
<< flat_new_scores(update_indices[i]);
|
||||
VLOG(2) << "flat_scores(shortlist_indices[j])) = "
|
||||
<< flat_scores(shortlist_indices[j]);
|
||||
if (flat_new_scores(update_indices[i]) >
|
||||
flat_scores(shortlist_indices[j])) {
|
||||
// Whenever we erase something from the shortlist, we need to
|
||||
// update score_cutoff.
|
||||
score_cutoff =
|
||||
std::max(score_cutoff, flat_scores(shortlist_indices[j]));
|
||||
updates.push_back(std::make_tuple(
|
||||
shortlist_indices[j], flat_new_ids(update_indices[i]),
|
||||
flat_new_scores(update_indices[i])));
|
||||
if (flat_ids(shortlist_indices[j]) == -1) {
|
||||
shortlist_size++;
|
||||
}
|
||||
j++;
|
||||
} else {
|
||||
// Whenever we fail to insert something into the shortlist, we need to
|
||||
// update score_cutoff.
|
||||
score_cutoff =
|
||||
std::max(score_cutoff, flat_new_scores(update_indices[i]));
|
||||
}
|
||||
i++;
|
||||
}
|
||||
} else {
|
||||
// Everything fits, no need to sort.
|
||||
int j = 1;
|
||||
for (int i = 0; i < num_updates; i++) {
|
||||
if (flat_new_scores(i) < score_cutoff) {
|
||||
continue;
|
||||
}
|
||||
while (j <= shortlist_max_size && flat_ids(j) != -1) {
|
||||
j++;
|
||||
}
|
||||
if (j > shortlist_max_size) {
|
||||
LOG(FATAL) << "Bug";
|
||||
}
|
||||
updates.push_back(
|
||||
std::make_tuple(j, flat_new_ids(i), flat_new_scores(i)));
|
||||
j++;
|
||||
shortlist_size++;
|
||||
}
|
||||
}
|
||||
|
||||
updates.push_back(std::make_tuple(0, shortlist_size, score_cutoff));
|
||||
|
||||
Tensor* output_shortlist_ids = nullptr;
|
||||
TensorShape shortlist_ids_shape;
|
||||
shortlist_ids_shape.AddDim(updates.size());
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, shortlist_ids_shape,
|
||||
&output_shortlist_ids));
|
||||
auto shortlist_ids_flat = output_shortlist_ids->tensor<int64, 1>();
|
||||
|
||||
Tensor* output_ids = nullptr;
|
||||
TensorShape ids_shape;
|
||||
ids_shape.AddDim(updates.size());
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(1, ids_shape, &output_ids));
|
||||
auto output_ids_flat = output_ids->tensor<int64, 1>();
|
||||
|
||||
Tensor* output_scores = nullptr;
|
||||
TensorShape scores_shape;
|
||||
scores_shape.AddDim(updates.size());
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(2, scores_shape, &output_scores));
|
||||
auto output_scores_flat = output_scores->tensor<float, 1>();
|
||||
|
||||
int i = 0;
|
||||
for (const auto& update : updates) {
|
||||
shortlist_ids_flat(i) = std::get<0>(update);
|
||||
output_ids_flat(i) = std::get<1>(update);
|
||||
output_scores_flat(i) = std::get<2>(update);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Op definition for TopNRemove.
REGISTER_OP("TopNRemove")
    .Input("ids: int64")
    .Input("remove_ids: int64")
    .Output("shortlist_ids: int64")
    .Output("new_length: int64")

    .Doc(R"doc(
  Remove ids from a shortlist.

  ids:= A 1-D int64 tensor containing the ids on the shortlist (except for
    ids[0], which is the current size of the shortlist).
  remove_ids:= A 1-D int64 tensor containing the ids to remove.
  shortlist_ids:= A 1-D int64 tensor containing the indices of the shortlist
    entries that need to be removed.
  new_length:= A length 1 1-D int64 tensor containing the new length of the
    shortlist.
)doc");
|
||||
|
||||
class TopNRemove : public OpKernel {
|
||||
public:
|
||||
explicit TopNRemove(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& ids = context->input(0);
|
||||
const Tensor& remove_ids = context->input(1);
|
||||
|
||||
OP_REQUIRES(context, ids.shape().dims() == 1,
|
||||
errors::InvalidArgument("ids should be one-dimensional"));
|
||||
OP_REQUIRES(
|
||||
context, remove_ids.shape().dims() == 1,
|
||||
errors::InvalidArgument("remove_ids should be one-dimensional"));
|
||||
|
||||
const auto flat_ids = ids.unaligned_flat<int64>();
|
||||
const auto flat_remove_ids = remove_ids.unaligned_flat<int64>();
|
||||
|
||||
const int num_to_remove = remove_ids.shape().dim_size(0);
|
||||
const int shortlist_max_size = ids.shape().dim_size(0);
|
||||
|
||||
// First, turn remove_ids into a set for easy membership checking.
|
||||
std::unordered_set<int> ids_to_remove(
|
||||
flat_remove_ids.data(), flat_remove_ids.data() + num_to_remove);
|
||||
|
||||
std::vector<int64> updates;
|
||||
int shortlist_size = std::max(0, static_cast<int>(flat_ids(0)));
|
||||
for (int j = 1; j < shortlist_max_size; j++) {
|
||||
if (ids_to_remove.find(flat_ids(j)) != ids_to_remove.end()) {
|
||||
shortlist_size--;
|
||||
updates.push_back(j);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor* output_shortlist_ids = nullptr;
|
||||
TensorShape shortlist_ids_shape;
|
||||
shortlist_ids_shape.AddDim(updates.size());
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, shortlist_ids_shape,
|
||||
&output_shortlist_ids));
|
||||
auto shortlist_ids_flat = output_shortlist_ids->tensor<int64, 1>();
|
||||
|
||||
std::copy(updates.begin(), updates.end(), shortlist_ids_flat.data());
|
||||
|
||||
Tensor* new_length = nullptr;
|
||||
TensorShape new_length_shape;
|
||||
new_length_shape.AddDim(1);
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(1, new_length_shape, &new_length));
|
||||
new_length->tensor<int64, 1>()(0) = shortlist_size;
|
||||
}
|
||||
};
|
||||
|
||||
// Bind the TopNInsert kernel class to the "TopNInsert" op on CPU.
REGISTER_KERNEL_BUILDER(Name("TopNInsert").Device(DEVICE_CPU), TopNInsert);
|
||||
// Bind the TopNRemove kernel class to the "TopNRemove" op on CPU.
REGISTER_KERNEL_BUILDER(Name("TopNRemove").Device(DEVICE_CPU), TopNRemove);
|
||||
} // namespace tensorflow
|
||||
@ -21,4 +21,5 @@ from __future__ import print_function
|
||||
from tensorflow.contrib.tensor_forest.python import constants
|
||||
from tensorflow.contrib.tensor_forest.python import tensor_forest
|
||||
from tensorflow.contrib.tensor_forest.python.ops import inference_ops
|
||||
from tensorflow.contrib.tensor_forest.python.ops import topn_ops
|
||||
from tensorflow.contrib.tensor_forest.python.ops import training_ops
|
||||
|
||||
62
tensorflow/contrib/tensor_forest/python/ops/topn_ops.py
Normal file
62
tensorflow/contrib/tensor_forest/python/ops/topn_ops.py
Normal file
@ -0,0 +1,62 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Ops for TopN class."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
TOPN_OPS_FILE = '_topn_ops.so'
|
||||
|
||||
_topn_ops = None
|
||||
_ops_lock = threading.Lock()
|
||||
|
||||
ops.NoGradient('TopNInsert')
|
||||
ops.NoGradient('TopNRemove')
|
||||
|
||||
|
||||
@ops.RegisterShape('TopNInsert')
def Insert(unused_op):
  """Shape function for TopNInsert: three 1-D outputs of unknown length."""
  unknown_vector = [None]
  return [list(unknown_vector) for _ in range(3)]
|
||||
|
||||
|
||||
@ops.RegisterShape('TopNRemove')
def Remove(unused_op):
  """Shape function for TopNRemove: two 1-D outputs of unknown length."""
  unknown_vector = [None]
  return [list(unknown_vector) for _ in range(2)]
|
||||
|
||||
|
||||
# Workaround for the fact that importing tensorflow imports contrib
|
||||
# (even if a user isn't using this or any other contrib op), but
|
||||
# there's not yet any guarantee that the shared object exists.
|
||||
# In which case, "import tensorflow" will always crash, even for users that
|
||||
# never use contrib.
|
||||
def Load():
  """Loads the TopN op library on first use and returns the loaded module.

  The result is cached in the module-level `_topn_ops`; the load is guarded by
  `_ops_lock`, so repeated or concurrent calls load the library only once.

  Returns:
    The module object returned by `tf.load_op_library`.
  """
  global _topn_ops
  with _ops_lock:
    if not _topn_ops:
      library_path = tf.resource_loader.get_path_to_datafile(TOPN_OPS_FILE)
      tf.logging.info('data path: %s', library_path)
      _topn_ops = tf.load_op_library(library_path)

      assert _topn_ops, 'Could not load topn_ops.so'
  return _topn_ops
|
||||
159
tensorflow/contrib/tensor_forest/python/topn.py
Normal file
159
tensorflow/contrib/tensor_forest/python/topn.py
Normal file
@ -0,0 +1,159 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A collection that allows repeated access to its top-scoring items."""
|
||||
# pylint: disable=g-bad-name
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.tensor_forest.python.ops import topn_ops
|
||||
|
||||
|
||||
class TopN(object):
  """A collection that allows repeated access to its top-scoring items.

  A TopN supports the following three operations:
  1) insert(ids, scores).  ids is a 1-d int64 Tensor and scores is a 1-d
     float Tensor.  scores[i] is the score associated with ids[i].  It is
     totally fine to re-insert ids that have already been inserted into the
     collection.

  2) remove(ids)

  3) ids, scores = get_best(n).  scores will contain the n highest (most
     positive) scores currently in the TopN, and ids their corresponding ids.
     n is a 1-d int32 Tensor with shape (1).

  TopN is implemented using a short-list of the top scoring items.  At
  construction time, the size of the short-list must be specified, and it
  is an error to call get_best(n) with an n greater than that size.

  Every mutating call records the ops it created in `self.last_ops`; the next
  call builds its ops under a control dependency on them, so operations take
  effect in the order in which they were added to the graph.
  """

  def __init__(self, max_id, shortlist_size=100, name_prefix=''):
    """Creates a new TopN.

    Args:
      max_id: Length of the id_to_score variable.
      shortlist_size: Number of entries the shortlist can hold.
      name_prefix: String prepended to the names of the created variables.
    """
    self.ops = topn_ops.Load()
    self.shortlist_size = shortlist_size
    # id_to_score contains all the scores we are tracking.
    self.id_to_score = tf.get_variable(
        name=name_prefix + 'id_to_score',
        dtype=tf.float32,
        shape=[max_id],
        initializer=tf.constant_initializer(tf.float32.min))
    # sl_ids and sl_scores together satisfy four invariants:
    # 1) If sl_ids[i] != -1, then
    #    id_to_score[sl_ids[i]] = sl_scores[i] >= sl_scores[0]
    # 2) sl_ids[0] is the number of i > 0 for which sl_ids[i] != -1.
    # 3) If id_to_score[i] > sl_scores[0], then
    #    sl_ids[j] = i for some j.
    # 4) If sl_ids[i] == -1, then sl_scores[i] = tf.float32.min.
    self.sl_ids = tf.get_variable(
        name=name_prefix + 'shortlist_ids',
        dtype=tf.int64,
        shape=[shortlist_size + 1],
        initializer=tf.constant_initializer(-1))
    # Ideally, we would set self.sl_ids[0] = 0 here.  But then it is hard
    # to pass that control dependency to the other other Ops.  Instead, we
    # have insert, remove and get_best all deal with the fact that
    # self.sl_ids[0] == -1 actually means the shortlist size is 0.
    self.sl_scores = tf.get_variable(
        name=name_prefix + 'shortlist_scores',
        dtype=tf.float32,
        shape=[shortlist_size + 1],
        initializer=tf.constant_initializer(tf.float32.min))
    # TopN keeps track of its internal data dependencies, so the user
    # doesn't have to.
    self.last_ops = []

  def insert(self, ids, scores):
    """Insert the ids and scores into the TopN."""
    with tf.control_dependencies(self.last_ops):
      scatter_op = tf.scatter_update(self.id_to_score, ids, scores)
      larger_scores = tf.greater(scores, self.sl_scores[0])

      def shortlist_insert():
        """Writes the candidates above the cutoff into the shortlist."""
        larger_ids = tf.boolean_mask(tf.to_int64(ids), larger_scores)
        larger_score_values = tf.boolean_mask(scores, larger_scores)
        shortlist_ids, new_ids, new_scores = self.ops.top_n_insert(
            self.sl_ids, self.sl_scores, larger_ids, larger_score_values)
        u1 = tf.scatter_update(self.sl_ids, shortlist_ids, new_ids)
        u2 = tf.scatter_update(self.sl_scores, shortlist_ids, new_scores)
        return tf.group(u1, u2)

      # We only need to insert into the shortlist if there are any
      # scores larger than the threshold.
      cond_op = tf.cond(
          tf.reduce_any(larger_scores), shortlist_insert, tf.no_op)
      self.last_ops = [scatter_op, cond_op]

  def remove(self, ids):
    """Remove the ids (and their associated scores) from the TopN.

    Args:
      ids: The ids to remove.

    Returns:
      An op that groups all updates performed by this removal.
    """
    with tf.control_dependencies(self.last_ops):
      scatter_op = tf.scatter_update(
          self.id_to_score,
          ids,
          tf.ones_like(
              ids, dtype=tf.float32) * tf.float32.min)
      # We assume that removed ids are almost always in the shortlist,
      # so it makes no sense to hide the Op behind a tf.cond
      shortlist_ids_to_remove, new_length = self.ops.top_n_remove(self.sl_ids,
                                                                  ids)
      # Slot 0 receives the new length; every removed slot is reset to -1.
      u1 = tf.scatter_update(
          self.sl_ids, tf.concat(0, [[0], shortlist_ids_to_remove]),
          tf.concat(0, [new_length,
                        tf.ones_like(shortlist_ids_to_remove) * -1]))
      u2 = tf.scatter_update(
          self.sl_scores,
          shortlist_ids_to_remove,
          tf.float32.min * tf.ones_like(
              shortlist_ids_to_remove, dtype=tf.float32))
      self.last_ops = [scatter_op, u1, u2]
      return tf.group(scatter_op, u1, u2)

  def get_best(self, n):
    """Return the indices and values of the n highest scores in the TopN."""

    def refresh_shortlist():
      """Update the shortlist with the highest scores in id_to_score."""
      new_scores, new_ids = tf.nn.top_k(self.id_to_score, self.shortlist_size)
      smallest_new_score = tf.reduce_min(new_scores)
      new_length = tf.reduce_sum(
          tf.to_int32(tf.greater(new_scores, tf.float32.min)))
      u1 = self.sl_ids.assign(
          tf.to_int64(tf.concat(0, [[new_length], new_ids])))
      u2 = self.sl_scores.assign(
          tf.concat(0, [[smallest_new_score], new_scores]))
      # Do not record u1/u2 in self.last_ops: they live inside the cond
      # branch and only run when the branch is taken, so ops outside the cond
      # must not take control dependencies on them.
      return tf.group(u1, u2)

    # We only need to refresh the shortlist if n is greater than the
    # current shortlist size (which is stored in sl_ids[0]).
    with tf.control_dependencies(self.last_ops):
      cond_op = tf.cond(n > self.sl_ids[0], refresh_shortlist, tf.no_op)
      # Later operations wait for the cond as a whole.
      self.last_ops = [cond_op]
      with tf.control_dependencies([cond_op]):
        topk_values, topk_indices = tf.nn.top_k(
            self.sl_scores, tf.minimum(n, tf.to_int32(self.sl_ids[0])))
        # topk_indices are the indices into the shortlist, we want to return
        # the indices into id_to_score
        gathered_indices = tf.gather(self.sl_ids, topk_indices)
        return gathered_indices, topk_values

  def get_and_remove_best(self, n):
    """Return the ids of the n highest scores and remove them from the TopN.

    Args:
      n: Number of top-scoring ids to fetch and remove.

    Returns:
      A Tensor with the removed ids; evaluating it also runs the removal.
    """
    # TODO(thomaswc): Replace this with a version of get_best where
    # refresh_shortlist grabs the top n + shortlist_size.
    top_ids, unused_top_vals = self.get_best(n)
    remove_op = self.remove(top_ids)
    # tf.identity has no control-input argument; attach the removal through a
    # control_dependencies scope instead.
    with tf.control_dependencies([remove_op]):
      return tf.identity(top_ids)
|
||||
183
tensorflow/contrib/tensor_forest/python/topn_test.py
Normal file
183
tensorflow/contrib/tensor_forest/python/topn_test.py
Normal file
@ -0,0 +1,183 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for topn.py."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.tensor_forest.python import topn
|
||||
from tensorflow.contrib.tensor_forest.python.ops import topn_ops
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class TopNOpsTest(test_util.TensorFlowTestCase):
  """Tests that call the TopNInsert / TopNRemove ops directly."""

  def setUp(self):
    self.ops = topn_ops.Load()

  def testInsertOpIntoEmptyShortlist(self):
    sl_ids = [0, -1, -1, -1, -1, -1]
    sl_scores = [-999, -999, -999, -999, -999, -999]
    with self.test_session():
      slots, ids_out, scores_out = self.ops.top_n_insert(
          sl_ids, sl_scores,
          [5], [33.0])  # new id and score
      self.assertAllEqual([1, 0], slots.eval())
      self.assertAllEqual([5, 1], ids_out.eval())
      self.assertAllEqual([33.0, -999], scores_out.eval())

  def testInsertOpIntoAlmostFullShortlist(self):
    sl_ids = [4, 13, -1, 27, 99, 15]
    sl_scores = [60.0, 87.0, -999, 65.0, 1000.0, 256.0]
    with self.test_session():
      slots, ids_out, scores_out = self.ops.top_n_insert(
          sl_ids, sl_scores,
          [5], [93.0])  # new id and score
      self.assertAllEqual([2, 0], slots.eval())
      self.assertAllEqual([5, 5], ids_out.eval())
      # The cutoff is unchanged: all known scores > 60.0 are still present.
      self.assertAllEqual([93.0, 60.0], scores_out.eval())

  def testInsertOpIntoFullShortlist(self):
    sl_ids = [5, 13, 44, 27, 99, 15]
    sl_scores = [60.0, 87.0, 111.0, 65.0, 1000.0, 256.0]
    with self.test_session():
      slots, ids_out, scores_out = self.ops.top_n_insert(
          sl_ids, sl_scores,
          [5], [93.0])  # new id and score
      self.assertAllEqual([3, 0], slots.eval())
      self.assertAllEqual([5, 5], ids_out.eval())
      # The 65.0 entry was overwritten, so the shortlist can now only claim
      # to hold all scores > 65.0.
      self.assertAllEqual([93.0, 65.0], scores_out.eval())

  def testInsertOpHard(self):
    sl_ids = [4, 13, -1, 27, 99, 15]
    sl_scores = [60.0, 87.0, -999, 65.0, 1000.0, 256.0]
    with self.test_session():
      slots, ids_out, scores_out = self.ops.top_n_insert(
          sl_ids, sl_scores,
          [5, 6, 7, 8, 9],
          [61.0, 66.0, 90.0, 100.0, 2000.0])  # new ids and scores
      # Top 5 scores are: 2000.0, 1000.0, 256.0, 100.0, 90.0
      self.assertAllEqual([2, 3, 1, 0], slots.eval())
      self.assertAllEqual([9, 8, 7, 5], ids_out.eval())
      # 87.0 is the highest score we overwrote or didn't insert.
      self.assertAllEqual([2000.0, 100.0, 90.0, 87.0], scores_out.eval())

  def testRemoveSimple(self):
    with self.test_session():
      slots, length = self.ops.top_n_remove(
          [5, 100, 200, 300, 400, 500], [200, 400, 600])
      self.assertAllEqual([2, 4], slots.eval())
      self.assertAllEqual([3], length.eval())

  def testRemoveAllMissing(self):
    with self.test_session():
      slots, length = self.ops.top_n_remove(
          [5, 100, 200, 300, 400, 500], [1200, 1400, 600])
      self.assertAllEqual([], slots.eval())
      self.assertAllEqual([5], length.eval())

  def testRemoveAll(self):
    with self.test_session():
      slots, length = self.ops.top_n_remove(
          [5, 100, 200, 300, 400, 500],
          [100, 200, 300, 400, 500])
      self.assertAllEqual([1, 2, 3, 4, 5], slots.eval())
      self.assertAllEqual([0], length.eval())
|
||||
|
||||
|
||||
class TopNTest(test_util.TensorFlowTestCase):
  """End-to-end tests of the TopN class."""

  def _evaluate(self, ids, vals):
    """Initializes all variables in a fresh session and fetches ids, vals."""
    with session.Session() as sess:
      sess.run(tf.initialize_all_variables())
      return sess.run([ids, vals])

  def testSimple(self):
    top = topn.TopN(1000, shortlist_size=10)
    top.insert([1, 2, 3, 4, 5], [1.0, 2.0, 3.0, 4.0, 5.0])
    top.remove([4, 5])
    ids_v, vals_v = self._evaluate(*top.get_best(2))
    self.assertItemsEqual([2, 3], list(ids_v))
    self.assertItemsEqual([2.0, 3.0], list(vals_v))

  def testSimpler(self):
    top = topn.TopN(1000, shortlist_size=10)
    top.insert([1], [33.0])
    ids_v, vals_v = self._evaluate(*top.get_best(1))
    self.assertListEqual([1], list(ids_v))
    self.assertListEqual([33.0], list(vals_v))

  def testLotsOfInsertsAscending(self):
    top = topn.TopN(1000, shortlist_size=10)
    for item in range(100):
      top.insert([item], [float(item)])
    ids_v, vals_v = self._evaluate(*top.get_best(5))
    self.assertItemsEqual([95, 96, 97, 98, 99], list(ids_v))
    self.assertItemsEqual([95.0, 96.0, 97.0, 98.0, 99.0], list(vals_v))

  def testLotsOfInsertsDescending(self):
    top = topn.TopN(1000, shortlist_size=10)
    for item in range(99, 1, -1):
      top.insert([item], [float(item)])
    ids_v, vals_v = self._evaluate(*top.get_best(5))
    self.assertItemsEqual([95, 96, 97, 98, 99], list(ids_v))
    self.assertItemsEqual([95.0, 96.0, 97.0, 98.0, 99.0], list(vals_v))

  def testRemoveNotInShortlist(self):
    top = topn.TopN(1000, shortlist_size=10)
    for item in range(20):
      top.insert([item], [float(item)])
    top.remove([4, 5])
    ids_v, vals_v = self._evaluate(*top.get_best(2))
    self.assertItemsEqual([18.0, 19.0], list(vals_v))
    self.assertItemsEqual([18, 19], list(ids_v))

  def testNeedToRefreshShortlistInGetBest(self):
    top = topn.TopN(1000, shortlist_size=10)
    for item in range(20):
      top.insert([item], [float(item)])
    # Shortlist now has 10 .. 19
    top.remove([11, 12, 13, 14, 15, 16, 17, 18, 19])
    ids_v, vals_v = self._evaluate(*top.get_best(2))
    self.assertItemsEqual([9, 10], list(ids_v))
    self.assertItemsEqual([9.0, 10.0], list(vals_v))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
Loading…
x
Reference in New Issue
Block a user