Refactor some of TensorForest V4 to make the tree model valid during training time, instead of only after FinalizeTreeOp.

PiperOrigin-RevId: 161663317
This commit is contained in:
A. Unique TensorFlower 2017-07-12 07:37:31 -07:00 committed by TensorFlower Gardener
parent eb1fe50da4
commit 786bf6cd65
17 changed files with 374 additions and 172 deletions

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <functional>
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
@ -26,6 +27,7 @@
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
namespace tensorforest {
@ -46,7 +48,7 @@ class CreateTreeVariableOp : public OpKernel {
OP_REQUIRES(context, TensorShapeUtils::IsScalar(tree_config_t->shape()),
errors::InvalidArgument("Tree config must be a scalar."));
auto* result = new DecisionTreeResource();
auto* result = new DecisionTreeResource(param_proto_);
if (!ParseProtoUnlimited(result->mutable_decision_tree(),
tree_config_t->scalar<string>()())) {
result->Unref();
@ -142,6 +144,16 @@ class TreeSizeOp : public OpKernel {
}
};
void TraverseTree(const DecisionTreeResource* tree_resource,
const std::unique_ptr<TensorDataSet>& data, int32 start,
int32 end,
const std::function<void(int32, int32)>& set_leaf_id) {
for (int i = start; i < end; ++i) {
const int32 id = tree_resource->TraverseTree(data, i, nullptr);
set_leaf_id(i, id);
}
}
// Op for tree inference.
class TreePredictionsV4Op : public OpKernel {
public:
@ -176,22 +188,49 @@ class TreePredictionsV4Op : public OpKernel {
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const int num_data = data_set_->NumItems();
const int32 num_outputs = param_proto_.num_outputs();
Tensor* output_predictions = nullptr;
TensorShape output_shape;
output_shape.AddDim(data_set_->NumItems());
output_shape.AddDim(param_proto_.num_outputs());
output_shape.AddDim(num_data);
output_shape.AddDim(num_outputs);
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
&output_predictions));
TTypes<float, 2>::Tensor out = output_predictions->tensor<float, 2>();
auto out = output_predictions->tensor<float, 2>();
for (int i = 0; i < data_set_->NumItems(); ++i) {
const int32 leaf_id =
decision_tree_resource->TraverseTree(data_set_, i, nullptr);
const decision_trees::Leaf& leaf =
decision_tree_resource->get_leaf(leaf_id);
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
const int64 costPerTraverse = 500;
auto traverse = [this, &out, decision_tree_resource, num_data](int64 start,
int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start),
static_cast<int32>(end),
std::bind(&TreePredictionsV4Op::set_output_value, this,
std::placeholders::_1, std::placeholders::_2,
decision_tree_resource, &out));
};
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
traverse);
}
void set_output_value(int32 i, int32 id,
DecisionTreeResource* decision_tree_resource,
TTypes<float, 2>::Tensor* out) {
const decision_trees::Leaf& leaf = decision_tree_resource->get_leaf(id);
float sum = 0;
for (int j = 0; j < param_proto_.num_outputs(); ++j) {
const float count = model_op_->GetOutputValue(leaf, j);
(*out)(i, j) = count;
sum += count;
}
if (!param_proto_.is_regression() && sum > 0 && sum != 1) {
for (int j = 0; j < param_proto_.num_outputs(); ++j) {
const float count = model_op_->GetOutputValue(leaf, j);
out(i, j) = count;
(*out)(i, j) /= sum;
}
}
}
@ -203,6 +242,122 @@ class TreePredictionsV4Op : public OpKernel {
TensorForestParams param_proto_;
};
// Outputs leaf ids for the given examples.
class TraverseTreeV4Op : public OpKernel {
public:
explicit TraverseTreeV4Op(OpKernelConstruction* context) : OpKernel(context) {
string serialized_params;
OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
ParseProtoUnlimited(&param_proto_, serialized_params);
string serialized_proto;
OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto));
input_spec_.ParseFromString(serialized_proto);
data_set_ =
std::unique_ptr<TensorDataSet>(new TensorDataSet(input_spec_, 0));
}
void Compute(OpKernelContext* context) override {
const Tensor& input_data = context->input(1);
const Tensor& sparse_input_indices = context->input(2);
const Tensor& sparse_input_values = context->input(3);
const Tensor& sparse_input_shape = context->input(4);
data_set_->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape);
DecisionTreeResource* decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const int num_data = data_set_->NumItems();
Tensor* output_predictions = nullptr;
TensorShape output_shape;
output_shape.AddDim(num_data);
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
&output_predictions));
auto leaf_ids = output_predictions->tensor<int32, 1>();
auto set_leaf_ids = [&leaf_ids](int32 i, int32 id) { leaf_ids(i) = id; };
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
const int64 costPerTraverse = 500;
auto traverse = [this, &set_leaf_ids, decision_tree_resource, num_data](
int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start),
static_cast<int32>(end), set_leaf_ids);
};
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
traverse);
}
private:
tensorforest::TensorForestDataSpec input_spec_;
std::unique_ptr<TensorDataSet> data_set_;
TensorForestParams param_proto_;
};
// Update the given leaf models using the batch of labels.
class UpdateModelV4Op : public OpKernel {
public:
explicit UpdateModelV4Op(OpKernelConstruction* context) : OpKernel(context) {
string serialized_params;
OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
ParseProtoUnlimited(&param_proto_, serialized_params);
model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_);
}
void Compute(OpKernelContext* context) override {
const Tensor& leaf_ids = context->input(1);
const Tensor& input_labels = context->input(2);
const Tensor& input_weights = context->input(3);
DecisionTreeResource* decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const int num_data = input_labels.shape().dim_size(0);
const int32 label_dim =
input_labels.shape().dims() <= 1
? 0
: static_cast<int>(input_labels.shape().dim_size(1));
const int32 num_targets =
param_proto_.is_regression() ? (std::max(1, label_dim)) : 1;
TensorInputTarget target(input_labels, input_weights, num_targets);
// TODO(gilberth): Make this thread safe and multi-thread.
UpdateModel(leaf_ids, target, 0, num_data, decision_tree_resource);
}
void UpdateModel(const Tensor& leaf_ids, const TensorInputTarget& target,
int32 start, int32 end,
DecisionTreeResource* decision_tree_resource) {
const auto leaves = leaf_ids.unaligned_flat<int32>();
for (int i = start; i < end; ++i) {
model_op_->UpdateModel(
decision_tree_resource->get_mutable_tree_node(leaves(i))
->mutable_leaf(),
&target, i);
}
}
private:
std::unique_ptr<LeafModelOperator> model_op_;
TensorForestParams param_proto_;
};
// Op for getting feature usage counts.
class FeatureUsageCountsOp : public OpKernel {
public:
@ -286,8 +441,14 @@ REGISTER_KERNEL_BUILDER(Name("TreeSize").Device(DEVICE_CPU), TreeSizeOp);
REGISTER_KERNEL_BUILDER(Name("TreePredictionsV4").Device(DEVICE_CPU),
TreePredictionsV4Op);
REGISTER_KERNEL_BUILDER(Name("TraverseTreeV4").Device(DEVICE_CPU),
TraverseTreeV4Op);
REGISTER_KERNEL_BUILDER(Name("FeatureUsageCounts").Device(DEVICE_CPU),
FeatureUsageCountsOp);
REGISTER_KERNEL_BUILDER(Name("UpdateModelV4").Device(DEVICE_CPU),
UpdateModelV4Op);
} // namespace tensorforest
} // namespace tensorflow

View File

@ -64,6 +64,33 @@ TEST(ModelOpsTest, TreePredictionsV4_ShapeFn) {
INFER_OK(op, "?;?;?;?;[10,11]", "[?,?]");
}
TEST(ModelOpsTest, TraverseTreeV4_ShapeFn) {
ShapeInferenceTestOp op("TraverseTreeV4");
TF_ASSERT_OK(NodeDefBuilder("test", "TraverseTreeV4")
.Input("a", 0, DT_RESOURCE)
.Input("b", 1, DT_FLOAT)
.Input("c", 2, DT_INT64)
.Input("d", 3, DT_FLOAT)
.Input("e", 5, DT_INT64)
.Attr("input_spec", "")
.Attr("params", "")
.Finalize(&op.node_def));
// num_points = 2, sparse shape not known
INFER_OK(op, "?;[2,3];?;?;?", "[d1_0]");
// num_points = 2, sparse and dense shape rank known and > 1
INFER_OK(op, "?;[2,3];?;?;[10,11]", "[d1_0]");
// num_points = 2, sparse shape rank known and > 1
INFER_OK(op, "?;?;?;?;[10,11]", "[?]");
}
TEST(ModelOpsTest, UpdateModelV4_ShapeFn) {
ShapeInferenceTestOp op("UpdateModelV4");
INFER_OK(op, "[1];?;?;?", "");
}
TEST(ModelOpsTest, FeatureUsageCounts_ShapeFn) {
ShapeInferenceTestOp op("FeatureUsageCounts");
INFER_OK(op, "[1]", "[?]");

View File

@ -141,18 +141,6 @@ class FertileStatsDeserializeOp : public OpKernel {
TensorForestParams param_proto_;
};
void TraverseTree(const DecisionTreeResource* tree_resource,
const std::unique_ptr<TensorDataSet>& data, int32 start,
int32 end, std::vector<int32>* leaf_ids,
std::vector<int32>* leaf_depths) {
for (int i = start; i < end; ++i) {
int32 depth;
const int32 leaf_id = tree_resource->TraverseTree(data, i, &depth);
(*leaf_ids)[i] = leaf_id;
(*leaf_depths)[i] = depth;
}
}
// Try to update a leaf's stats by acquiring its lock. If it can't be
// acquired, put it in a waiting queue to come back to later and try the next
// one. Once all leaf_ids have been visited, cycle through the waiting ids
@ -160,28 +148,27 @@ void TraverseTree(const DecisionTreeResource* tree_resource,
void UpdateStats(FertileStatsResource* fertile_stats_resource,
const std::unique_ptr<TensorDataSet>& data,
const TensorInputTarget& target, int num_targets,
const std::vector<int32>& leaf_ids,
const std::vector<int32>& leaf_depths,
const Tensor& leaf_ids_tensor,
std::unordered_map<int32, std::unique_ptr<mutex>>* locks,
mutex* set_lock, int32 start, int32 end,
std::unordered_set<int32>* ready_to_split) {
const auto leaf_ids = leaf_ids_tensor.unaligned_flat<int32>();
// Stores leaf_id, leaf_depth, example_id for examples that are waiting
// on another to finish.
std::queue<std::tuple<int32, int32, int32>> waiting;
std::queue<std::tuple<int32, int32>> waiting;
int32 i = start;
while (i < end || !waiting.empty()) {
int32 leaf_id;
int32 leaf_depth;
int32 example_id;
bool was_waiting = false;
if (i >= end) {
std::tie(leaf_id, leaf_depth, example_id) = waiting.front();
std::tie(leaf_id, example_id) = waiting.front();
waiting.pop();
was_waiting = true;
} else {
leaf_id = leaf_ids[i];
leaf_depth = leaf_depths[i];
leaf_id = leaf_ids(i);
example_id = i;
++i;
}
@ -190,14 +177,14 @@ void UpdateStats(FertileStatsResource* fertile_stats_resource,
leaf_lock->lock();
} else {
if (!leaf_lock->try_lock()) {
waiting.emplace(leaf_id, leaf_depth, example_id);
waiting.emplace(leaf_id, example_id);
continue;
}
}
bool is_finished;
fertile_stats_resource->AddExampleToStatsAndInitialize(
data, &target, {example_id}, leaf_id, leaf_depth, &is_finished);
data, &target, {example_id}, leaf_id, &is_finished);
leaf_lock->unlock();
if (is_finished) {
set_lock->lock();
@ -214,8 +201,8 @@ void UpdateStatsCollated(
const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target,
int num_targets,
const std::unordered_map<int32, std::vector<int>>& leaf_examples,
const std::vector<int32>& leaf_depths, mutex* set_lock, int32 start,
int32 end, std::unordered_set<int32>* ready_to_split) {
mutex* set_lock, int32 start, int32 end,
std::unordered_set<int32>* ready_to_split) {
auto it = leaf_examples.begin();
std::advance(it, start);
auto end_it = leaf_examples.begin();
@ -224,8 +211,7 @@ void UpdateStatsCollated(
int32 leaf_id = it->first;
bool is_finished;
fertile_stats_resource->AddExampleToStatsAndInitialize(
data, &target, it->second, leaf_id, leaf_depths[it->second[0]],
&is_finished);
data, &target, it->second, leaf_id, &is_finished);
if (is_finished) {
set_lock->lock();
ready_to_split->insert(leaf_id);
@ -261,6 +247,7 @@ class ProcessInputOp : public OpKernel {
const Tensor& sparse_input_shape = context->input(5);
const Tensor& input_labels = context->input(6);
const Tensor& input_weights = context->input(7);
const Tensor& leaf_ids_tensor = context->input(8);
data_set_->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape);
@ -281,22 +268,7 @@ class ProcessInputOp : public OpKernel {
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
// First find the leaf ids for each example.
std::vector<int32> leaf_ids(num_data);
// The depth of the leaf for example i.
std::vector<int32> leaf_depths(num_data);
const int64 costPerTraverse = 500;
auto traverse = [this, &leaf_ids, &leaf_depths, tree_resource, num_data](
int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
TraverseTree(tree_resource, data_set_, static_cast<int32>(start),
static_cast<int32>(end), &leaf_ids, &leaf_depths);
};
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
traverse);
const auto leaf_ids = leaf_ids_tensor.unaligned_flat<int32>();
// Create one mutex per leaf. We need to protect access to leaf pointers,
// so instead of grouping examples by leaf, we spread examples out among
@ -306,10 +278,11 @@ class ProcessInputOp : public OpKernel {
std::unordered_map<int32, std::vector<int>> leaf_examples;
if (param_proto_.collate_examples()) {
for (int i = 0; i < num_data; ++i) {
leaf_examples[leaf_ids[i]].push_back(i);
leaf_examples[leaf_ids(i)].push_back(i);
}
} else {
for (const int32 id : leaf_ids) {
for (int i = 0; i < num_data; ++i) {
const int32 id = leaf_ids(i);
if (FindOrNull(locks, id) == nullptr) {
// TODO(gilberth): Consider using a memory pool for these.
locks[id] = std::unique_ptr<mutex>(new mutex);
@ -335,27 +308,26 @@ class ProcessInputOp : public OpKernel {
// from a digits run on local desktop. Heuristics might be necessary
// if it really matters that much.
const int64 costPerUpdate = 1000;
auto update = [this, &target, &leaf_ids, &leaf_depths, &num_targets,
auto update = [this, &target, &leaf_ids_tensor, &num_targets,
fertile_stats_resource, &locks, &set_lock, &ready_to_split,
num_data](int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
UpdateStats(fertile_stats_resource, data_set_, target, num_targets,
leaf_ids, leaf_depths, &locks, &set_lock,
static_cast<int32>(start), static_cast<int32>(end),
&ready_to_split);
leaf_ids_tensor, &locks, &set_lock, static_cast<int32>(start),
static_cast<int32>(end), &ready_to_split);
};
auto update_collated = [this, &target, &leaf_ids, &num_targets,
&leaf_depths, fertile_stats_resource, tree_resource,
&leaf_examples, &set_lock, &ready_to_split,
auto update_collated = [this, &target, &num_targets, fertile_stats_resource,
tree_resource, &leaf_examples, &set_lock,
&ready_to_split,
num_leaves](int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_leaves);
UpdateStatsCollated(fertile_stats_resource, tree_resource, data_set_,
target, num_targets, leaf_examples, leaf_depths,
&set_lock, static_cast<int32>(start),
static_cast<int32>(end), &ready_to_split);
target, num_targets, leaf_examples, &set_lock,
static_cast<int32>(start), static_cast<int32>(end),
&ready_to_split);
};
if (param_proto_.collate_examples()) {
@ -411,7 +383,8 @@ class GrowTreeOp : public OpKernel {
const int32 num_nodes =
static_cast<int32>(finished_nodes.shape().dim_size(0));
// TODO(gilberth): distribute this work over a number of threads.
// This op takes so little of the time for one batch that it isn't worth
// threading this.
for (int i = 0;
i < num_nodes &&
tree_resource->decision_tree().decision_tree().nodes_size() <
@ -420,16 +393,14 @@ class GrowTreeOp : public OpKernel {
const int32 node = finished(i);
std::unique_ptr<SplitCandidate> best(new SplitCandidate);
int32 parent_depth;
// TODO(gilberth): Pushing these to an output would allow the complete
// decoupling of tree from resource.
bool found =
fertile_stats_resource->BestSplit(node, best.get(), &parent_depth);
if (found) {
std::vector<int32> new_children;
tree_resource->SplitNode(node, best.get(), &new_children);
fertile_stats_resource->Allocate(parent_depth, new_children);
fertile_stats_resource->set_leaf_stat(best->left_stats(),
new_children[0]);
fertile_stats_resource->set_leaf_stat(best->right_stats(),
new_children[1]);
// We are done with best, so it is now safe to clear node.
fertile_stats_resource->Clear(node);
CHECK(tree_resource->get_mutable_tree_node(node)->has_leaf() == false);
@ -444,20 +415,17 @@ class GrowTreeOp : public OpKernel {
TensorForestParams param_proto_;
};
void FinalizeLeaf(const LeafStat& leaf_stats, bool is_regression,
bool drop_final_class,
void FinalizeLeaf(bool is_regression, bool drop_final_class,
const std::unique_ptr<LeafModelOperator>& leaf_op,
decision_trees::Leaf* leaf) {
leaf_op->ExportModel(leaf_stats, leaf);
// TODO(thomaswc): Move the rest of this into ExportModel.
// regression models are already stored in leaf in normalized form.
if (is_regression) {
return;
}
float sum = leaf_stats.weight_sum();
// TODO(gilberth): Calculate the leaf's sum.
float sum = 0;
LOG(FATAL) << "FinalizeTreeOp is disabled for now.";
if (sum <= 0.0) {
LOG(WARNING) << "Leaf with sum " << sum << " has stats "
<< leaf->ShortDebugString();
@ -517,8 +485,7 @@ class FinalizeTreeOp : public OpKernel {
->mutable_decision_tree()
->mutable_nodes(i);
if (node->has_leaf()) {
const auto& leaf_stats = fertile_stats_resource->leaf_stat(i);
FinalizeLeaf(leaf_stats, param_proto_.is_regression(),
FinalizeLeaf(param_proto_.is_regression(),
param_proto_.drop_final_class(), model_op_,
node->mutable_leaf());
}

View File

@ -45,7 +45,7 @@ TEST(StatsOpsTest, GrowTreeV4_ShapeFn) {
TEST(StatsOpsTest, ProcessInputV4_ShapeFn) {
ShapeInferenceTestOp op("ProcessInputV4");
INFER_OK(op, "[1];[1];?;?;?;?;?;?", "[?]");
INFER_OK(op, "[1];[1];?;?;?;?;?;?;?", "[?]");
}
TEST(StatsOpsTest, FinalizeTree_ShapeFn) {

View File

@ -18,6 +18,7 @@ namespace tensorflow {
namespace tensorforest {
using decision_trees::DecisionTree;
using decision_trees::Leaf;
using decision_trees::TreeNode;
int32 DecisionTreeResource::TraverseTree(
@ -51,13 +52,15 @@ void DecisionTreeResource::SplitNode(int32 node_id, SplitCandidate* best,
new_children->push_back(newid);
TreeNode* new_left = tree->add_nodes();
new_left->mutable_node_id()->set_value(newid++);
new_left->mutable_leaf();
Leaf* left_leaf = new_left->mutable_leaf();
model_op_->ExportModel(best->left_stats(), left_leaf);
// right
new_children->push_back(newid);
TreeNode* new_right = tree->add_nodes();
new_right->mutable_node_id()->set_value(newid);
new_right->mutable_leaf();
Leaf* right_leaf = new_right->mutable_leaf();
model_op_->ExportModel(best->right_stats(), right_leaf);
node->clear_leaf();
node->mutable_binary_node()->Swap(best->mutable_split());
@ -72,7 +75,7 @@ void DecisionTreeResource::SplitNode(int32 node_id, SplitCandidate* best,
void DecisionTreeResource::MaybeInitialize() {
DecisionTree* tree = decision_tree_->mutable_decision_tree();
if (tree->nodes_size() == 0) {
tree->add_nodes()->mutable_leaf();
model_op_->InitModel(tree->add_nodes()->mutable_leaf());
} else if (node_evaluators_.empty()) { // reconstruct evaluators
for (const auto& node : tree->nodes()) {
if (node.has_leaf()) {

View File

@ -31,8 +31,10 @@ namespace tensorforest {
class DecisionTreeResource : public ResourceBase {
public:
// Constructor.
explicit DecisionTreeResource()
: decision_tree_(new decision_trees::Model()) {}
explicit DecisionTreeResource(const TensorForestParams& params)
: params_(params), decision_tree_(new decision_trees::Model()) {
model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_);
}
string DebugString() override {
return strings::StrCat("DecisionTree[size=",
@ -79,7 +81,9 @@ class DecisionTreeResource : public ResourceBase {
private:
mutex mu_;
const TensorForestParams params_;
std::unique_ptr<decision_trees::Model> decision_tree_;
std::shared_ptr<LeafModelOperator> model_op_;
std::vector<std::unique_ptr<DecisionNodeEvaluator>> node_evaluators_;
};

View File

@ -20,14 +20,8 @@ namespace tensorflow {
namespace tensorforest {
void FertileStatsResource::AddExampleToStatsAndInitialize(
const std::unique_ptr<TensorDataSet>& input_data,
const InputTarget* target, const std::vector<int>& examples,
int32 node_id, int32 node_depth, bool* is_finished) {
// Set leaf's counts for calculating probabilities.
for (int example : examples) {
model_op_->UpdateModel(&leaf_stats_[node_id], target, example);
}
const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
const std::vector<int>& examples, int32 node_id, bool* is_finished) {
// Update stats or initialize if needed.
if (collection_op_->IsInitialized(node_id)) {
collection_op_->AddExample(input_data, target, examples, node_id);
@ -47,8 +41,6 @@ void FertileStatsResource::AddExampleToStatsAndInitialize(
}
void FertileStatsResource::AllocateNode(int32 node_id, int32 depth) {
leaf_stats_[node_id] = LeafStat();
model_op_->InitModel(&leaf_stats_[node_id]);
collection_op_->InitializeSlot(node_id, depth);
}
@ -62,7 +54,6 @@ void FertileStatsResource::Allocate(int32 parent_depth,
void FertileStatsResource::Clear(int32 node) {
collection_op_->ClearSlot(node);
leaf_stats_.erase(node);
}
bool FertileStatsResource::BestSplit(int32 node_id, SplitCandidate* best,
@ -71,27 +62,16 @@ bool FertileStatsResource::BestSplit(int32 node_id, SplitCandidate* best,
}
void FertileStatsResource::MaybeInitialize() {
if (leaf_stats_.empty()) {
AllocateNode(0, 0);
}
collection_op_->MaybeInitialize();
}
void FertileStatsResource::ExtractFromProto(const FertileStats& stats) {
collection_op_ =
SplitCollectionOperatorFactory::CreateSplitCollectionOperator(params_);
collection_op_->ExtractFromProto(stats);
for (int i = 0; i < stats.node_to_slot_size(); ++i) {
const auto& slot = stats.node_to_slot(i);
leaf_stats_[slot.node_id()] = slot.leaf_stats();
}
}
void FertileStatsResource::PackToProto(FertileStats* stats) const {
for (const auto& entry : leaf_stats_) {
auto* slot = stats->add_node_to_slot();
*slot->mutable_leaf_stats() = entry.second;
slot->set_node_id(entry.first);
}
collection_op_->PackToProto(stats);
}
} // namespace tensorforest

View File

@ -51,7 +51,6 @@ class FertileStatsResource : public ResourceBase {
// Resets the resource and frees the proto.
// Caller needs to hold the mutex lock while calling this.
void Reset() {
leaf_stats_.clear();
}
// Reset the stats for a node, but leave the leaf_stats intact.
@ -71,7 +70,7 @@ class FertileStatsResource : public ResourceBase {
void AddExampleToStatsAndInitialize(
const std::unique_ptr<TensorDataSet>& input_data,
const InputTarget* target, const std::vector<int>& examples,
int32 node_id, int32 node_depth, bool* is_finished);
int32 node_id, bool* is_finished);
// Allocate a fertile slot for each ready node, then new children up to
// max_fertile_nodes_.
@ -85,19 +84,11 @@ class FertileStatsResource : public ResourceBase {
// was found.
bool BestSplit(int32 node_id, SplitCandidate* best, int32* depth);
const LeafStat& leaf_stat(int32 node_id) {
return leaf_stats_[node_id];
}
void set_leaf_stat(const LeafStat& stat, int32 node_id) {
leaf_stats_[node_id] = stat;
}
private:
mutex mu_;
std::shared_ptr<LeafModelOperator> model_op_;
std::unique_ptr<SplitCollectionOperator> collection_op_;
std::unordered_map<int32, LeafStat> leaf_stats_;
const TensorForestParams params_;
void AllocateNode(int32 node_id, int32 depth);

View File

@ -17,6 +17,8 @@
namespace tensorflow {
namespace tensorforest {
using decision_trees::Leaf;
std::unique_ptr<LeafModelOperator>
LeafModelOperatorFactory::CreateLeafModelOperator(
const TensorForestParams& params) {
@ -50,24 +52,21 @@ float DenseClassificationLeafModelOperator::GetOutputValue(
}
void DenseClassificationLeafModelOperator::UpdateModel(
LeafStat* leaf, const InputTarget* target,
int example) const {
Leaf* leaf, const InputTarget* target, int example) const {
const int32 int_label = target->GetTargetAsClassIndex(example, 0);
QCHECK_LT(int_label, params_.num_outputs())
<< "Got label greater than indicated number of classes. Is "
"params.num_classes set correctly?";
QCHECK_GE(int_label, 0);
auto* val = leaf->mutable_classification()->mutable_dense_counts()
->mutable_value(int_label);
auto* val = leaf->mutable_vector()->mutable_value(int_label);
float weight = target->GetTargetWeight(example);
val->set_float_value(val->float_value() + weight);
leaf->set_weight_sum(leaf->weight_sum() + weight);
}
void DenseClassificationLeafModelOperator::InitModel(
LeafStat* leaf) const {
void DenseClassificationLeafModelOperator::InitModel(Leaf* leaf) const {
for (int i = 0; i < params_.num_outputs(); ++i) {
leaf->mutable_classification()->mutable_dense_counts()->add_value();
leaf->mutable_vector()->add_value();
}
}
@ -88,17 +87,15 @@ float SparseClassificationLeafModelOperator::GetOutputValue(
}
void SparseClassificationLeafModelOperator::UpdateModel(
LeafStat* leaf, const InputTarget* target,
int example) const {
Leaf* leaf, const InputTarget* target, int example) const {
const int32 int_label = target->GetTargetAsClassIndex(example, 0);
QCHECK_LT(int_label, params_.num_outputs())
<< "Got label greater than indicated number of classes. Is "
"params.num_classes set correctly?";
QCHECK_GE(int_label, 0);
const float weight = target->GetTargetWeight(example);
leaf->set_weight_sum(leaf->weight_sum() + weight);
auto value_map = leaf->mutable_classification()->mutable_sparse_counts()
->mutable_sparse_value();
auto value_map = leaf->mutable_sparse_vector()->mutable_sparse_value();
auto it = value_map->find(int_label);
if (it == value_map->end()) {
(*value_map)[int_label].set_float_value(weight);
@ -123,8 +120,8 @@ float SparseOrDenseClassificationLeafModelOperator::GetOutputValue(
}
void SparseOrDenseClassificationLeafModelOperator::UpdateModel(
LeafStat* leaf, const InputTarget* target, int example) const {
if (leaf->classification().has_dense_counts()) {
Leaf* leaf, const InputTarget* target, int example) const {
if (leaf->has_vector()) {
return dense_->UpdateModel(leaf, target, example);
} else {
return sparse_->UpdateModel(leaf, target, example);
@ -146,15 +143,15 @@ float RegressionLeafModelOperator::GetOutputValue(
return leaf.vector().value(o).float_value();
}
void RegressionLeafModelOperator::InitModel(
LeafStat* leaf) const {
void RegressionLeafModelOperator::InitModel(Leaf* leaf) const {
for (int i = 0; i < params_.num_outputs(); ++i) {
leaf->mutable_regression()->mutable_mean_output()->add_value();
leaf->mutable_vector()->add_value();
}
}
void RegressionLeafModelOperator::ExportModel(
const LeafStat& stat, decision_trees::Leaf* leaf) const {
leaf->clear_vector();
for (int i = 0; i < params_.num_outputs(); ++i) {
const float new_val =
stat.regression().mean_output().value(i).float_value() /

View File

@ -42,12 +42,11 @@ class LeafModelOperator {
int32 o) const = 0;
// Update the given Leaf's model with the given example.
virtual void UpdateModel(LeafStat* leaf,
const InputTarget* target,
int example) const = 0;
virtual void UpdateModel(decision_trees::Leaf* leaf,
const InputTarget* target, int example) const = 0;
// Initialize an empty Leaf model.
virtual void InitModel(LeafStat* leaf) const = 0;
virtual void InitModel(decision_trees::Leaf* leaf) const = 0;
virtual void ExportModel(const LeafStat& stat,
decision_trees::Leaf* leaf) const = 0;
@ -65,10 +64,10 @@ class DenseClassificationLeafModelOperator : public LeafModelOperator {
float GetOutputValue(const decision_trees::Leaf& leaf,
int32 o) const override;
void UpdateModel(LeafStat* leaf, const InputTarget* target,
void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
int example) const override;
void InitModel(LeafStat* leaf) const override;
void InitModel(decision_trees::Leaf* leaf) const override;
void ExportModel(const LeafStat& stat,
decision_trees::Leaf* leaf) const override;
@ -84,10 +83,10 @@ class SparseClassificationLeafModelOperator : public LeafModelOperator {
float GetOutputValue(const decision_trees::Leaf& leaf,
int32 o) const override;
void UpdateModel(LeafStat* leaf, const InputTarget* target,
void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
int example) const override;
void InitModel(LeafStat* leaf) const override {}
void InitModel(decision_trees::Leaf* leaf) const override {}
void ExportModel(const LeafStat& stat,
decision_trees::Leaf* leaf) const override;
@ -103,10 +102,10 @@ class SparseOrDenseClassificationLeafModelOperator : public LeafModelOperator {
float GetOutputValue(const decision_trees::Leaf& leaf,
int32 o) const override;
void UpdateModel(LeafStat* leaf, const InputTarget* target,
void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
int example) const override;
void InitModel(LeafStat* leaf) const override {}
void InitModel(decision_trees::Leaf* leaf) const override {}
void ExportModel(const LeafStat& stat,
decision_trees::Leaf* leaf) const override;
@ -129,10 +128,10 @@ class RegressionLeafModelOperator : public LeafModelOperator {
// updating model and just using the seeded values. Can add this in
// with additional_data, though protobuf::Any is slow. Maybe make it
// optional. Maybe make any update optional.
void UpdateModel(LeafStat* leaf, const InputTarget* target,
void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
int example) const override {}
void InitModel(LeafStat* leaf) const override;
void InitModel(decision_trees::Leaf* leaf) const override;
void ExportModel(const LeafStat& stat,
decision_trees::Leaf* leaf) const override;

View File

@ -63,12 +63,8 @@ constexpr char kRegressionStatProto[] =
"}";
void TestClassificationNormalUse(const std::unique_ptr<LeafModelOperator>& op) {
std::unique_ptr<LeafStat> leaf(new LeafStat);
op->InitModel(leaf.get());
Leaf l;
op->ExportModel(*leaf, &l);
op->InitModel(&l);
// Make sure it was initialized correctly.
for (int i = 0; i < kNumClasses; ++i) {
EXPECT_EQ(op->GetOutputValue(l, i), 0);
@ -80,11 +76,10 @@ void TestClassificationNormalUse(const std::unique_ptr<LeafModelOperator>& op) {
new TestableInputTarget(labels, weights, 1));
// Update and check value.
op->UpdateModel(leaf.get(), target.get(), 0);
op->UpdateModel(leaf.get(), target.get(), 1);
op->UpdateModel(leaf.get(), target.get(), 2);
op->UpdateModel(&l, target.get(), 0);
op->UpdateModel(&l, target.get(), 1);
op->UpdateModel(&l, target.get(), 2);
op->ExportModel(*leaf, &l);
EXPECT_FLOAT_EQ(op->GetOutputValue(l, 1), 3.4);
}

View File

@ -71,13 +71,13 @@ void SplitCollectionOperator::ExtractFromProto(
}
void SplitCollectionOperator::PackToProto(FertileStats* stats_proto) const {
for (int i = 0; i < stats_proto->node_to_slot_size(); ++i) {
auto* new_slot = stats_proto->mutable_node_to_slot(i);
const auto& stats = stats_.at(new_slot->node_id());
for (const auto& pair : stats_) {
auto* new_slot = stats_proto->add_node_to_slot();
new_slot->set_node_id(pair.first);
if (params_.checkpoint_stats()) {
stats->PackToProto(new_slot);
pair.second->PackToProto(new_slot);
}
new_slot->set_depth(stats->depth());
new_slot->set_depth(pair.second->depth());
}
}

View File

@ -62,6 +62,14 @@ class SplitCollectionOperator {
// Create a new GrowStats for the given node id and initialize it.
virtual void InitializeSlot(int32 node_id, int32 depth);
// Called when the resource is deserialized, possibly needing an
// initialization.
virtual void MaybeInitialize() {
if (stats_.empty()) {
InitializeSlot(0, 0);
}
}
// Perform any necessary cleanup for any tracked state for the slot.
virtual void ClearSlot(int32 node_id) {
stats_.erase(node_id);

View File

@ -115,6 +115,58 @@ sparse_input_shape: The shape tensor from the SparseTensor input.
predictions: `predictions[i][j]` is the probability that input i is class j.
)doc");
REGISTER_OP("TraverseTreeV4")
.Attr("input_spec: string")
.Attr("params: string")
.Input("tree_handle: resource")
.Input("input_data: float")
.Input("sparse_input_indices: int64")
.Input("sparse_input_values: float")
.Input("sparse_input_shape: int64")
.Output("leaf_ids: int32")
.SetShapeFn([](InferenceContext* c) {
DimensionHandle num_points = c->UnknownDim();
if (c->RankKnown(c->input(1)) && c->Rank(c->input(1)) > 0 &&
c->Value(c->Dim(c->input(1), 0)) > 0) {
num_points = c->Dim(c->input(1), 0);
}
c->set_output(0, c->Vector(num_points));
return Status::OK();
})
.Doc(R"doc(
Outputs the leaf ids for the given input data.
params: A serialized TensorForestParams proto.
tree_handle: The handle to the tree.
input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
gives the j-th feature of the i-th input.
sparse_input_indices: The indices tensor from the SparseTensor input.
sparse_input_values: The values tensor from the SparseTensor input.
sparse_input_shape: The shape tensor from the SparseTensor input.
leaf_ids: `leaf_ids[i]` is the leaf id for input i.
)doc");
REGISTER_OP("UpdateModelV4")
.Attr("params: string")
.Input("tree_handle: resource")
.Input("leaf_ids: int32")
.Input("input_labels: float")
.Input("input_weights: float")
.SetShapeFn(tensorflow::shape_inference::NoOutputs)
.Doc(R"doc(
Updates the given leaves for each example with the new labels.
params: A serialized TensorForestParams proto.
tree_handle: The handle to the tree.
leaf_ids: `leaf_ids[i]` is the leaf id for input i.
input_labels: The training batch's labels as a 1 or 2-d tensor.
'input_labels[i][j]' gives the j-th label/target for the i-th input.
input_weights: The training batch's eample weights as a 1-d tensor.
'input_weights[i]' gives the weight for the i-th input.
)doc");
REGISTER_OP("FeatureUsageCounts")
.Attr("params: string")
.Input("tree_handle: resource")

View File

@ -98,6 +98,7 @@ REGISTER_OP("ProcessInputV4")
.Input("sparse_input_shape: int64")
.Input("input_labels: float")
.Input("input_weights: float")
.Input("leaf_ids: int32")
.Output("finished_nodes: int32")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(c->UnknownDim()));
@ -122,6 +123,7 @@ input_weights: The training batch's eample weights as a 1-d tensor.
'input_weights[i]' gives the weight for the i-th input.
finished_nodes: A 1-d tensor of node ids that have finished and are ready to
grow.
leaf_ids: `leaf_ids[i]` is the leaf id for input i.
)doc");
REGISTER_OP("FinalizeTree")

View File

@ -18,12 +18,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tensor_forest.python.ops import gen_model_ops
from tensorflow.contrib.tensor_forest.python.ops import stats_ops
# pylint: disable=unused-import
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import feature_usage_counts
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import traverse_tree_v4
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_predictions_v4
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_size
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import update_model_v4
# pylint: enable=unused-import
from tensorflow.contrib.util import loader
@ -59,13 +60,7 @@ class TreeVariableSavable(saver.BaseSaverBuilder.SaveableObject):
name: the name to save the tree variable under.
"""
self.params = params
deps = []
if stats_handle is not None:
deps.append(stats_ops.finalize_tree(
tree_handle, stats_handle,
params=params.serialized_params_proto))
with ops.control_dependencies(deps):
tensor = gen_model_ops.tree_serialize(tree_handle)
tensor = gen_model_ops.tree_serialize(tree_handle)
# slice_spec is useful for saving a slice from a variable.
# It's not meaningful the tree variable. So we just pass an empty value.
slice_spec = ""

View File

@ -27,6 +27,7 @@ from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.contrib.tensor_forest.python.ops import model_ops
from tensorflow.contrib.tensor_forest.python.ops import stats_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import tf_logging as logging
@ -240,6 +241,22 @@ class RandomTreeGraphsV4(tensor_forest.RandomTreeGraphs):
if input_data is None:
input_data = []
leaf_ids = model_ops.traverse_tree_v4(
self.variables.tree,
input_data,
sparse_indices,
sparse_values,
sparse_shape,
input_spec=data_spec.SerializeToString(),
params=self.params.serialized_params_proto)
update_model = model_ops.update_model_v4(
self.variables.tree,
leaf_ids,
input_labels,
input_weights,
params=self.params.serialized_params_proto)
finished_nodes = stats_ops.process_input_v4(
self.variables.tree,
self.variables.stats,
@ -249,13 +266,17 @@ class RandomTreeGraphsV4(tensor_forest.RandomTreeGraphs):
sparse_shape,
input_labels,
input_weights,
leaf_ids,
input_spec=data_spec.SerializeToString(),
random_seed=random_seed,
params=self.params.serialized_params_proto)
return stats_ops.grow_tree_v4(self.variables.tree, self.variables.stats,
finished_nodes,
params=self.params.serialized_params_proto)
with ops.control_dependencies([update_model]):
return stats_ops.grow_tree_v4(
self.variables.tree,
self.variables.stats,
finished_nodes,
params=self.params.serialized_params_proto)
def inference_graph(self, input_data, data_spec, sparse_features=None):
sparse_indices = []