Refactor some of TensorForest V4 to make the tree model valid during training time, instead of only after FinalizeTreeOp.
PiperOrigin-RevId: 161663317
This commit is contained in:
parent
eb1fe50da4
commit
786bf6cd65
tensorflow/contrib/tensor_forest
@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#include <functional>
|
||||
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
|
||||
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
|
||||
#include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
|
||||
@ -26,6 +27,7 @@
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorforest {
|
||||
@ -46,7 +48,7 @@ class CreateTreeVariableOp : public OpKernel {
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsScalar(tree_config_t->shape()),
|
||||
errors::InvalidArgument("Tree config must be a scalar."));
|
||||
|
||||
auto* result = new DecisionTreeResource();
|
||||
auto* result = new DecisionTreeResource(param_proto_);
|
||||
if (!ParseProtoUnlimited(result->mutable_decision_tree(),
|
||||
tree_config_t->scalar<string>()())) {
|
||||
result->Unref();
|
||||
@ -142,6 +144,16 @@ class TreeSizeOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
void TraverseTree(const DecisionTreeResource* tree_resource,
|
||||
const std::unique_ptr<TensorDataSet>& data, int32 start,
|
||||
int32 end,
|
||||
const std::function<void(int32, int32)>& set_leaf_id) {
|
||||
for (int i = start; i < end; ++i) {
|
||||
const int32 id = tree_resource->TraverseTree(data, i, nullptr);
|
||||
set_leaf_id(i, id);
|
||||
}
|
||||
}
|
||||
|
||||
// Op for tree inference.
|
||||
class TreePredictionsV4Op : public OpKernel {
|
||||
public:
|
||||
@ -176,22 +188,49 @@ class TreePredictionsV4Op : public OpKernel {
|
||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||
core::ScopedUnref unref_me(decision_tree_resource);
|
||||
|
||||
const int num_data = data_set_->NumItems();
|
||||
const int32 num_outputs = param_proto_.num_outputs();
|
||||
|
||||
Tensor* output_predictions = nullptr;
|
||||
TensorShape output_shape;
|
||||
output_shape.AddDim(data_set_->NumItems());
|
||||
output_shape.AddDim(param_proto_.num_outputs());
|
||||
output_shape.AddDim(num_data);
|
||||
output_shape.AddDim(num_outputs);
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
|
||||
&output_predictions));
|
||||
TTypes<float, 2>::Tensor out = output_predictions->tensor<float, 2>();
|
||||
|
||||
auto out = output_predictions->tensor<float, 2>();
|
||||
for (int i = 0; i < data_set_->NumItems(); ++i) {
|
||||
const int32 leaf_id =
|
||||
decision_tree_resource->TraverseTree(data_set_, i, nullptr);
|
||||
const decision_trees::Leaf& leaf =
|
||||
decision_tree_resource->get_leaf(leaf_id);
|
||||
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
||||
int num_threads = worker_threads->num_threads;
|
||||
const int64 costPerTraverse = 500;
|
||||
auto traverse = [this, &out, decision_tree_resource, num_data](int64 start,
|
||||
int64 end) {
|
||||
CHECK(start <= end);
|
||||
CHECK(end <= num_data);
|
||||
TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start),
|
||||
static_cast<int32>(end),
|
||||
std::bind(&TreePredictionsV4Op::set_output_value, this,
|
||||
std::placeholders::_1, std::placeholders::_2,
|
||||
decision_tree_resource, &out));
|
||||
};
|
||||
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
|
||||
traverse);
|
||||
}
|
||||
|
||||
void set_output_value(int32 i, int32 id,
|
||||
DecisionTreeResource* decision_tree_resource,
|
||||
TTypes<float, 2>::Tensor* out) {
|
||||
const decision_trees::Leaf& leaf = decision_tree_resource->get_leaf(id);
|
||||
|
||||
float sum = 0;
|
||||
for (int j = 0; j < param_proto_.num_outputs(); ++j) {
|
||||
const float count = model_op_->GetOutputValue(leaf, j);
|
||||
(*out)(i, j) = count;
|
||||
sum += count;
|
||||
}
|
||||
|
||||
if (!param_proto_.is_regression() && sum > 0 && sum != 1) {
|
||||
for (int j = 0; j < param_proto_.num_outputs(); ++j) {
|
||||
const float count = model_op_->GetOutputValue(leaf, j);
|
||||
out(i, j) = count;
|
||||
(*out)(i, j) /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -203,6 +242,122 @@ class TreePredictionsV4Op : public OpKernel {
|
||||
TensorForestParams param_proto_;
|
||||
};
|
||||
|
||||
// Outputs leaf ids for the given examples.
|
||||
class TraverseTreeV4Op : public OpKernel {
|
||||
public:
|
||||
explicit TraverseTreeV4Op(OpKernelConstruction* context) : OpKernel(context) {
|
||||
string serialized_params;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
|
||||
ParseProtoUnlimited(¶m_proto_, serialized_params);
|
||||
|
||||
string serialized_proto;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto));
|
||||
input_spec_.ParseFromString(serialized_proto);
|
||||
|
||||
data_set_ =
|
||||
std::unique_ptr<TensorDataSet>(new TensorDataSet(input_spec_, 0));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& input_data = context->input(1);
|
||||
const Tensor& sparse_input_indices = context->input(2);
|
||||
const Tensor& sparse_input_values = context->input(3);
|
||||
const Tensor& sparse_input_shape = context->input(4);
|
||||
|
||||
data_set_->set_input_tensors(input_data, sparse_input_indices,
|
||||
sparse_input_values, sparse_input_shape);
|
||||
|
||||
DecisionTreeResource* decision_tree_resource;
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&decision_tree_resource));
|
||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||
core::ScopedUnref unref_me(decision_tree_resource);
|
||||
|
||||
const int num_data = data_set_->NumItems();
|
||||
|
||||
Tensor* output_predictions = nullptr;
|
||||
TensorShape output_shape;
|
||||
output_shape.AddDim(num_data);
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
|
||||
&output_predictions));
|
||||
|
||||
auto leaf_ids = output_predictions->tensor<int32, 1>();
|
||||
|
||||
auto set_leaf_ids = [&leaf_ids](int32 i, int32 id) { leaf_ids(i) = id; };
|
||||
|
||||
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
||||
int num_threads = worker_threads->num_threads;
|
||||
const int64 costPerTraverse = 500;
|
||||
auto traverse = [this, &set_leaf_ids, decision_tree_resource, num_data](
|
||||
int64 start, int64 end) {
|
||||
CHECK(start <= end);
|
||||
CHECK(end <= num_data);
|
||||
TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start),
|
||||
static_cast<int32>(end), set_leaf_ids);
|
||||
};
|
||||
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
|
||||
traverse);
|
||||
}
|
||||
|
||||
private:
|
||||
tensorforest::TensorForestDataSpec input_spec_;
|
||||
std::unique_ptr<TensorDataSet> data_set_;
|
||||
TensorForestParams param_proto_;
|
||||
};
|
||||
|
||||
// Update the given leaf models using the batch of labels.
|
||||
class UpdateModelV4Op : public OpKernel {
|
||||
public:
|
||||
explicit UpdateModelV4Op(OpKernelConstruction* context) : OpKernel(context) {
|
||||
string serialized_params;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
|
||||
ParseProtoUnlimited(¶m_proto_, serialized_params);
|
||||
|
||||
model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_);
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& leaf_ids = context->input(1);
|
||||
const Tensor& input_labels = context->input(2);
|
||||
const Tensor& input_weights = context->input(3);
|
||||
|
||||
DecisionTreeResource* decision_tree_resource;
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&decision_tree_resource));
|
||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||
core::ScopedUnref unref_me(decision_tree_resource);
|
||||
|
||||
const int num_data = input_labels.shape().dim_size(0);
|
||||
const int32 label_dim =
|
||||
input_labels.shape().dims() <= 1
|
||||
? 0
|
||||
: static_cast<int>(input_labels.shape().dim_size(1));
|
||||
const int32 num_targets =
|
||||
param_proto_.is_regression() ? (std::max(1, label_dim)) : 1;
|
||||
|
||||
TensorInputTarget target(input_labels, input_weights, num_targets);
|
||||
|
||||
// TODO(gilberth): Make this thread safe and multi-thread.
|
||||
UpdateModel(leaf_ids, target, 0, num_data, decision_tree_resource);
|
||||
}
|
||||
|
||||
void UpdateModel(const Tensor& leaf_ids, const TensorInputTarget& target,
|
||||
int32 start, int32 end,
|
||||
DecisionTreeResource* decision_tree_resource) {
|
||||
const auto leaves = leaf_ids.unaligned_flat<int32>();
|
||||
for (int i = start; i < end; ++i) {
|
||||
model_op_->UpdateModel(
|
||||
decision_tree_resource->get_mutable_tree_node(leaves(i))
|
||||
->mutable_leaf(),
|
||||
&target, i);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<LeafModelOperator> model_op_;
|
||||
TensorForestParams param_proto_;
|
||||
};
|
||||
|
||||
// Op for getting feature usage counts.
|
||||
class FeatureUsageCountsOp : public OpKernel {
|
||||
public:
|
||||
@ -286,8 +441,14 @@ REGISTER_KERNEL_BUILDER(Name("TreeSize").Device(DEVICE_CPU), TreeSizeOp);
|
||||
// CPU kernel registrations: each statement binds an op name to the kernel
// class that implements it on DEVICE_CPU.
REGISTER_KERNEL_BUILDER(Name("TreePredictionsV4").Device(DEVICE_CPU),
|
||||
TreePredictionsV4Op);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TraverseTreeV4").Device(DEVICE_CPU),
|
||||
TraverseTreeV4Op);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("FeatureUsageCounts").Device(DEVICE_CPU),
|
||||
FeatureUsageCountsOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("UpdateModelV4").Device(DEVICE_CPU),
|
||||
UpdateModelV4Op);
|
||||
|
||||
} // namespace tensorforest
|
||||
} // namespace tensorflow
|
||||
|
@ -64,6 +64,33 @@ TEST(ModelOpsTest, TreePredictionsV4_ShapeFn) {
|
||||
INFER_OK(op, "?;?;?;?;[10,11]", "[?,?]");
|
||||
}
|
||||
|
||||
// Shape inference for TraverseTreeV4: the expected output is a vector whose
// length is dimension 0 of input 1 ("d1_0") when that input's shape is known,
// and an unknown-length vector otherwise.
TEST(ModelOpsTest, TraverseTreeV4_ShapeFn) {
  ShapeInferenceTestOp op("TraverseTreeV4");
  TF_ASSERT_OK(NodeDefBuilder("test", "TraverseTreeV4")
                   .Input("a", 0, DT_RESOURCE)
                   .Input("b", 1, DT_FLOAT)
                   .Input("c", 2, DT_INT64)
                   .Input("d", 3, DT_FLOAT)
                   .Input("e", 5, DT_INT64)
                   .Attr("input_spec", "")
                   .Attr("params", "")
                   .Finalize(&op.node_def));

  // num_points = 2, sparse shape not known
  INFER_OK(op, "?;[2,3];?;?;?", "[d1_0]");

  // num_points = 2, sparse and dense shape rank known and > 1
  INFER_OK(op, "?;[2,3];?;?;[10,11]", "[d1_0]");

  // num_points unknown (dense shape not known), sparse shape rank known
  // and > 1
  INFER_OK(op, "?;?;?;?;[10,11]", "[?]");
}
|
||||
|
||||
// Shape inference for UpdateModelV4 must succeed and yield no outputs.
TEST(ModelOpsTest, UpdateModelV4_ShapeFn) {
  const char kInputShapes[] = "[1];?;?;?";
  const char kNoOutputs[] = "";
  ShapeInferenceTestOp op("UpdateModelV4");
  INFER_OK(op, kInputShapes, kNoOutputs);
}
|
||||
|
||||
TEST(ModelOpsTest, FeatureUsageCounts_ShapeFn) {
|
||||
ShapeInferenceTestOp op("FeatureUsageCounts");
|
||||
INFER_OK(op, "[1]", "[?]");
|
||||
|
@ -141,18 +141,6 @@ class FertileStatsDeserializeOp : public OpKernel {
|
||||
TensorForestParams param_proto_;
|
||||
};
|
||||
|
||||
void TraverseTree(const DecisionTreeResource* tree_resource,
|
||||
const std::unique_ptr<TensorDataSet>& data, int32 start,
|
||||
int32 end, std::vector<int32>* leaf_ids,
|
||||
std::vector<int32>* leaf_depths) {
|
||||
for (int i = start; i < end; ++i) {
|
||||
int32 depth;
|
||||
const int32 leaf_id = tree_resource->TraverseTree(data, i, &depth);
|
||||
(*leaf_ids)[i] = leaf_id;
|
||||
(*leaf_depths)[i] = depth;
|
||||
}
|
||||
}
|
||||
|
||||
// Try to update a leaf's stats by acquiring its lock. If it can't be
|
||||
// acquired, put it in a waiting queue to come back to later and try the next
|
||||
// one. Once all leaf_ids have been visited, cycle through the waiting ids
|
||||
@ -160,28 +148,27 @@ void TraverseTree(const DecisionTreeResource* tree_resource,
|
||||
void UpdateStats(FertileStatsResource* fertile_stats_resource,
|
||||
const std::unique_ptr<TensorDataSet>& data,
|
||||
const TensorInputTarget& target, int num_targets,
|
||||
const std::vector<int32>& leaf_ids,
|
||||
const std::vector<int32>& leaf_depths,
|
||||
const Tensor& leaf_ids_tensor,
|
||||
std::unordered_map<int32, std::unique_ptr<mutex>>* locks,
|
||||
mutex* set_lock, int32 start, int32 end,
|
||||
std::unordered_set<int32>* ready_to_split) {
|
||||
const auto leaf_ids = leaf_ids_tensor.unaligned_flat<int32>();
|
||||
|
||||
// Stores leaf_id, leaf_depth, example_id for examples that are waiting
|
||||
// on another to finish.
|
||||
std::queue<std::tuple<int32, int32, int32>> waiting;
|
||||
std::queue<std::tuple<int32, int32>> waiting;
|
||||
|
||||
int32 i = start;
|
||||
while (i < end || !waiting.empty()) {
|
||||
int32 leaf_id;
|
||||
int32 leaf_depth;
|
||||
int32 example_id;
|
||||
bool was_waiting = false;
|
||||
if (i >= end) {
|
||||
std::tie(leaf_id, leaf_depth, example_id) = waiting.front();
|
||||
std::tie(leaf_id, example_id) = waiting.front();
|
||||
waiting.pop();
|
||||
was_waiting = true;
|
||||
} else {
|
||||
leaf_id = leaf_ids[i];
|
||||
leaf_depth = leaf_depths[i];
|
||||
leaf_id = leaf_ids(i);
|
||||
example_id = i;
|
||||
++i;
|
||||
}
|
||||
@ -190,14 +177,14 @@ void UpdateStats(FertileStatsResource* fertile_stats_resource,
|
||||
leaf_lock->lock();
|
||||
} else {
|
||||
if (!leaf_lock->try_lock()) {
|
||||
waiting.emplace(leaf_id, leaf_depth, example_id);
|
||||
waiting.emplace(leaf_id, example_id);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
bool is_finished;
|
||||
fertile_stats_resource->AddExampleToStatsAndInitialize(
|
||||
data, &target, {example_id}, leaf_id, leaf_depth, &is_finished);
|
||||
data, &target, {example_id}, leaf_id, &is_finished);
|
||||
leaf_lock->unlock();
|
||||
if (is_finished) {
|
||||
set_lock->lock();
|
||||
@ -214,8 +201,8 @@ void UpdateStatsCollated(
|
||||
const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target,
|
||||
int num_targets,
|
||||
const std::unordered_map<int32, std::vector<int>>& leaf_examples,
|
||||
const std::vector<int32>& leaf_depths, mutex* set_lock, int32 start,
|
||||
int32 end, std::unordered_set<int32>* ready_to_split) {
|
||||
mutex* set_lock, int32 start, int32 end,
|
||||
std::unordered_set<int32>* ready_to_split) {
|
||||
auto it = leaf_examples.begin();
|
||||
std::advance(it, start);
|
||||
auto end_it = leaf_examples.begin();
|
||||
@ -224,8 +211,7 @@ void UpdateStatsCollated(
|
||||
int32 leaf_id = it->first;
|
||||
bool is_finished;
|
||||
fertile_stats_resource->AddExampleToStatsAndInitialize(
|
||||
data, &target, it->second, leaf_id, leaf_depths[it->second[0]],
|
||||
&is_finished);
|
||||
data, &target, it->second, leaf_id, &is_finished);
|
||||
if (is_finished) {
|
||||
set_lock->lock();
|
||||
ready_to_split->insert(leaf_id);
|
||||
@ -261,6 +247,7 @@ class ProcessInputOp : public OpKernel {
|
||||
const Tensor& sparse_input_shape = context->input(5);
|
||||
const Tensor& input_labels = context->input(6);
|
||||
const Tensor& input_weights = context->input(7);
|
||||
const Tensor& leaf_ids_tensor = context->input(8);
|
||||
|
||||
data_set_->set_input_tensors(input_data, sparse_input_indices,
|
||||
sparse_input_values, sparse_input_shape);
|
||||
@ -281,22 +268,7 @@ class ProcessInputOp : public OpKernel {
|
||||
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
||||
int num_threads = worker_threads->num_threads;
|
||||
|
||||
// First find the leaf ids for each example.
|
||||
std::vector<int32> leaf_ids(num_data);
|
||||
|
||||
// The depth of the leaf for example i.
|
||||
std::vector<int32> leaf_depths(num_data);
|
||||
|
||||
const int64 costPerTraverse = 500;
|
||||
auto traverse = [this, &leaf_ids, &leaf_depths, tree_resource, num_data](
|
||||
int64 start, int64 end) {
|
||||
CHECK(start <= end);
|
||||
CHECK(end <= num_data);
|
||||
TraverseTree(tree_resource, data_set_, static_cast<int32>(start),
|
||||
static_cast<int32>(end), &leaf_ids, &leaf_depths);
|
||||
};
|
||||
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
|
||||
traverse);
|
||||
const auto leaf_ids = leaf_ids_tensor.unaligned_flat<int32>();
|
||||
|
||||
// Create one mutex per leaf. We need to protect access to leaf pointers,
|
||||
// so instead of grouping examples by leaf, we spread examples out among
|
||||
@ -306,10 +278,11 @@ class ProcessInputOp : public OpKernel {
|
||||
std::unordered_map<int32, std::vector<int>> leaf_examples;
|
||||
if (param_proto_.collate_examples()) {
|
||||
for (int i = 0; i < num_data; ++i) {
|
||||
leaf_examples[leaf_ids[i]].push_back(i);
|
||||
leaf_examples[leaf_ids(i)].push_back(i);
|
||||
}
|
||||
} else {
|
||||
for (const int32 id : leaf_ids) {
|
||||
for (int i = 0; i < num_data; ++i) {
|
||||
const int32 id = leaf_ids(i);
|
||||
if (FindOrNull(locks, id) == nullptr) {
|
||||
// TODO(gilberth): Consider using a memory pool for these.
|
||||
locks[id] = std::unique_ptr<mutex>(new mutex);
|
||||
@ -335,27 +308,26 @@ class ProcessInputOp : public OpKernel {
|
||||
// from a digits run on local desktop. Heuristics might be necessary
|
||||
// if it really matters that much.
|
||||
const int64 costPerUpdate = 1000;
|
||||
auto update = [this, &target, &leaf_ids, &leaf_depths, &num_targets,
|
||||
auto update = [this, &target, &leaf_ids_tensor, &num_targets,
|
||||
fertile_stats_resource, &locks, &set_lock, &ready_to_split,
|
||||
num_data](int64 start, int64 end) {
|
||||
CHECK(start <= end);
|
||||
CHECK(end <= num_data);
|
||||
UpdateStats(fertile_stats_resource, data_set_, target, num_targets,
|
||||
leaf_ids, leaf_depths, &locks, &set_lock,
|
||||
static_cast<int32>(start), static_cast<int32>(end),
|
||||
&ready_to_split);
|
||||
leaf_ids_tensor, &locks, &set_lock, static_cast<int32>(start),
|
||||
static_cast<int32>(end), &ready_to_split);
|
||||
};
|
||||
|
||||
auto update_collated = [this, &target, &leaf_ids, &num_targets,
|
||||
&leaf_depths, fertile_stats_resource, tree_resource,
|
||||
&leaf_examples, &set_lock, &ready_to_split,
|
||||
auto update_collated = [this, &target, &num_targets, fertile_stats_resource,
|
||||
tree_resource, &leaf_examples, &set_lock,
|
||||
&ready_to_split,
|
||||
num_leaves](int64 start, int64 end) {
|
||||
CHECK(start <= end);
|
||||
CHECK(end <= num_leaves);
|
||||
UpdateStatsCollated(fertile_stats_resource, tree_resource, data_set_,
|
||||
target, num_targets, leaf_examples, leaf_depths,
|
||||
&set_lock, static_cast<int32>(start),
|
||||
static_cast<int32>(end), &ready_to_split);
|
||||
target, num_targets, leaf_examples, &set_lock,
|
||||
static_cast<int32>(start), static_cast<int32>(end),
|
||||
&ready_to_split);
|
||||
};
|
||||
|
||||
if (param_proto_.collate_examples()) {
|
||||
@ -411,7 +383,8 @@ class GrowTreeOp : public OpKernel {
|
||||
const int32 num_nodes =
|
||||
static_cast<int32>(finished_nodes.shape().dim_size(0));
|
||||
|
||||
// TODO(gilberth): distribute this work over a number of threads.
|
||||
// This op takes so little of the time for one batch that it isn't worth
|
||||
// threading this.
|
||||
for (int i = 0;
|
||||
i < num_nodes &&
|
||||
tree_resource->decision_tree().decision_tree().nodes_size() <
|
||||
@ -420,16 +393,14 @@ class GrowTreeOp : public OpKernel {
|
||||
const int32 node = finished(i);
|
||||
std::unique_ptr<SplitCandidate> best(new SplitCandidate);
|
||||
int32 parent_depth;
|
||||
// TODO(gilberth): Pushing these to an output would allow the complete
|
||||
// decoupling of tree from resource.
|
||||
bool found =
|
||||
fertile_stats_resource->BestSplit(node, best.get(), &parent_depth);
|
||||
if (found) {
|
||||
std::vector<int32> new_children;
|
||||
tree_resource->SplitNode(node, best.get(), &new_children);
|
||||
fertile_stats_resource->Allocate(parent_depth, new_children);
|
||||
fertile_stats_resource->set_leaf_stat(best->left_stats(),
|
||||
new_children[0]);
|
||||
fertile_stats_resource->set_leaf_stat(best->right_stats(),
|
||||
new_children[1]);
|
||||
// We are done with best, so it is now safe to clear node.
|
||||
fertile_stats_resource->Clear(node);
|
||||
CHECK(tree_resource->get_mutable_tree_node(node)->has_leaf() == false);
|
||||
@ -444,20 +415,17 @@ class GrowTreeOp : public OpKernel {
|
||||
TensorForestParams param_proto_;
|
||||
};
|
||||
|
||||
void FinalizeLeaf(const LeafStat& leaf_stats, bool is_regression,
|
||||
bool drop_final_class,
|
||||
void FinalizeLeaf(bool is_regression, bool drop_final_class,
|
||||
const std::unique_ptr<LeafModelOperator>& leaf_op,
|
||||
decision_trees::Leaf* leaf) {
|
||||
leaf_op->ExportModel(leaf_stats, leaf);
|
||||
|
||||
// TODO(thomaswc): Move the rest of this into ExportModel.
|
||||
|
||||
// regression models are already stored in leaf in normalized form.
|
||||
if (is_regression) {
|
||||
return;
|
||||
}
|
||||
|
||||
float sum = leaf_stats.weight_sum();
|
||||
// TODO(gilberth): Calculate the leaf's sum.
|
||||
float sum = 0;
|
||||
LOG(FATAL) << "FinalizeTreeOp is disabled for now.";
|
||||
if (sum <= 0.0) {
|
||||
LOG(WARNING) << "Leaf with sum " << sum << " has stats "
|
||||
<< leaf->ShortDebugString();
|
||||
@ -517,8 +485,7 @@ class FinalizeTreeOp : public OpKernel {
|
||||
->mutable_decision_tree()
|
||||
->mutable_nodes(i);
|
||||
if (node->has_leaf()) {
|
||||
const auto& leaf_stats = fertile_stats_resource->leaf_stat(i);
|
||||
FinalizeLeaf(leaf_stats, param_proto_.is_regression(),
|
||||
FinalizeLeaf(param_proto_.is_regression(),
|
||||
param_proto_.drop_final_class(), model_op_,
|
||||
node->mutable_leaf());
|
||||
}
|
||||
|
@ -45,7 +45,7 @@ TEST(StatsOpsTest, GrowTreeV4_ShapeFn) {
|
||||
|
||||
TEST(StatsOpsTest, ProcessInputV4_ShapeFn) {
|
||||
ShapeInferenceTestOp op("ProcessInputV4");
|
||||
INFER_OK(op, "[1];[1];?;?;?;?;?;?", "[?]");
|
||||
INFER_OK(op, "[1];[1];?;?;?;?;?;?;?", "[?]");
|
||||
}
|
||||
|
||||
TEST(StatsOpsTest, FinalizeTree_ShapeFn) {
|
||||
|
@ -18,6 +18,7 @@ namespace tensorflow {
|
||||
namespace tensorforest {
|
||||
|
||||
using decision_trees::DecisionTree;
|
||||
using decision_trees::Leaf;
|
||||
using decision_trees::TreeNode;
|
||||
|
||||
int32 DecisionTreeResource::TraverseTree(
|
||||
@ -51,13 +52,15 @@ void DecisionTreeResource::SplitNode(int32 node_id, SplitCandidate* best,
|
||||
new_children->push_back(newid);
|
||||
TreeNode* new_left = tree->add_nodes();
|
||||
new_left->mutable_node_id()->set_value(newid++);
|
||||
new_left->mutable_leaf();
|
||||
Leaf* left_leaf = new_left->mutable_leaf();
|
||||
model_op_->ExportModel(best->left_stats(), left_leaf);
|
||||
|
||||
// right
|
||||
new_children->push_back(newid);
|
||||
TreeNode* new_right = tree->add_nodes();
|
||||
new_right->mutable_node_id()->set_value(newid);
|
||||
new_right->mutable_leaf();
|
||||
Leaf* right_leaf = new_right->mutable_leaf();
|
||||
model_op_->ExportModel(best->right_stats(), right_leaf);
|
||||
|
||||
node->clear_leaf();
|
||||
node->mutable_binary_node()->Swap(best->mutable_split());
|
||||
@ -72,7 +75,7 @@ void DecisionTreeResource::SplitNode(int32 node_id, SplitCandidate* best,
|
||||
void DecisionTreeResource::MaybeInitialize() {
|
||||
DecisionTree* tree = decision_tree_->mutable_decision_tree();
|
||||
if (tree->nodes_size() == 0) {
|
||||
tree->add_nodes()->mutable_leaf();
|
||||
model_op_->InitModel(tree->add_nodes()->mutable_leaf());
|
||||
} else if (node_evaluators_.empty()) { // reconstruct evaluators
|
||||
for (const auto& node : tree->nodes()) {
|
||||
if (node.has_leaf()) {
|
||||
|
@ -31,8 +31,10 @@ namespace tensorforest {
|
||||
class DecisionTreeResource : public ResourceBase {
|
||||
public:
|
||||
// Constructor.
|
||||
explicit DecisionTreeResource()
|
||||
: decision_tree_(new decision_trees::Model()) {}
|
||||
explicit DecisionTreeResource(const TensorForestParams& params)
|
||||
: params_(params), decision_tree_(new decision_trees::Model()) {
|
||||
model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_);
|
||||
}
|
||||
|
||||
string DebugString() override {
|
||||
return strings::StrCat("DecisionTree[size=",
|
||||
@ -79,7 +81,9 @@ class DecisionTreeResource : public ResourceBase {
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
const TensorForestParams params_;
|
||||
std::unique_ptr<decision_trees::Model> decision_tree_;
|
||||
std::shared_ptr<LeafModelOperator> model_op_;
|
||||
std::vector<std::unique_ptr<DecisionNodeEvaluator>> node_evaluators_;
|
||||
};
|
||||
|
||||
|
@ -20,14 +20,8 @@ namespace tensorflow {
|
||||
namespace tensorforest {
|
||||
|
||||
void FertileStatsResource::AddExampleToStatsAndInitialize(
|
||||
const std::unique_ptr<TensorDataSet>& input_data,
|
||||
const InputTarget* target, const std::vector<int>& examples,
|
||||
int32 node_id, int32 node_depth, bool* is_finished) {
|
||||
// Set leaf's counts for calculating probabilities.
|
||||
for (int example : examples) {
|
||||
model_op_->UpdateModel(&leaf_stats_[node_id], target, example);
|
||||
}
|
||||
|
||||
const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
|
||||
const std::vector<int>& examples, int32 node_id, bool* is_finished) {
|
||||
// Update stats or initialize if needed.
|
||||
if (collection_op_->IsInitialized(node_id)) {
|
||||
collection_op_->AddExample(input_data, target, examples, node_id);
|
||||
@ -47,8 +41,6 @@ void FertileStatsResource::AddExampleToStatsAndInitialize(
|
||||
}
|
||||
|
||||
void FertileStatsResource::AllocateNode(int32 node_id, int32 depth) {
|
||||
leaf_stats_[node_id] = LeafStat();
|
||||
model_op_->InitModel(&leaf_stats_[node_id]);
|
||||
collection_op_->InitializeSlot(node_id, depth);
|
||||
}
|
||||
|
||||
@ -62,7 +54,6 @@ void FertileStatsResource::Allocate(int32 parent_depth,
|
||||
|
||||
void FertileStatsResource::Clear(int32 node) {
|
||||
collection_op_->ClearSlot(node);
|
||||
leaf_stats_.erase(node);
|
||||
}
|
||||
|
||||
bool FertileStatsResource::BestSplit(int32 node_id, SplitCandidate* best,
|
||||
@ -71,27 +62,16 @@ bool FertileStatsResource::BestSplit(int32 node_id, SplitCandidate* best,
|
||||
}
|
||||
|
||||
void FertileStatsResource::MaybeInitialize() {
|
||||
if (leaf_stats_.empty()) {
|
||||
AllocateNode(0, 0);
|
||||
}
|
||||
collection_op_->MaybeInitialize();
|
||||
}
|
||||
|
||||
void FertileStatsResource::ExtractFromProto(const FertileStats& stats) {
|
||||
collection_op_ =
|
||||
SplitCollectionOperatorFactory::CreateSplitCollectionOperator(params_);
|
||||
collection_op_->ExtractFromProto(stats);
|
||||
for (int i = 0; i < stats.node_to_slot_size(); ++i) {
|
||||
const auto& slot = stats.node_to_slot(i);
|
||||
leaf_stats_[slot.node_id()] = slot.leaf_stats();
|
||||
}
|
||||
}
|
||||
|
||||
void FertileStatsResource::PackToProto(FertileStats* stats) const {
|
||||
for (const auto& entry : leaf_stats_) {
|
||||
auto* slot = stats->add_node_to_slot();
|
||||
*slot->mutable_leaf_stats() = entry.second;
|
||||
slot->set_node_id(entry.first);
|
||||
}
|
||||
collection_op_->PackToProto(stats);
|
||||
}
|
||||
} // namespace tensorforest
|
||||
|
@ -51,7 +51,6 @@ class FertileStatsResource : public ResourceBase {
|
||||
// Resets the resource and frees the proto.
|
||||
// Caller needs to hold the mutex lock while calling this.
|
||||
void Reset() {
|
||||
leaf_stats_.clear();
|
||||
}
|
||||
|
||||
// Reset the stats for a node, but leave the leaf_stats intact.
|
||||
@ -71,7 +70,7 @@ class FertileStatsResource : public ResourceBase {
|
||||
void AddExampleToStatsAndInitialize(
|
||||
const std::unique_ptr<TensorDataSet>& input_data,
|
||||
const InputTarget* target, const std::vector<int>& examples,
|
||||
int32 node_id, int32 node_depth, bool* is_finished);
|
||||
int32 node_id, bool* is_finished);
|
||||
|
||||
// Allocate a fertile slot for each ready node, then new children up to
|
||||
// max_fertile_nodes_.
|
||||
@ -85,19 +84,11 @@ class FertileStatsResource : public ResourceBase {
|
||||
// was found.
|
||||
bool BestSplit(int32 node_id, SplitCandidate* best, int32* depth);
|
||||
|
||||
const LeafStat& leaf_stat(int32 node_id) {
|
||||
return leaf_stats_[node_id];
|
||||
}
|
||||
|
||||
void set_leaf_stat(const LeafStat& stat, int32 node_id) {
|
||||
leaf_stats_[node_id] = stat;
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
std::shared_ptr<LeafModelOperator> model_op_;
|
||||
std::unique_ptr<SplitCollectionOperator> collection_op_;
|
||||
std::unordered_map<int32, LeafStat> leaf_stats_;
|
||||
const TensorForestParams params_;
|
||||
|
||||
void AllocateNode(int32 node_id, int32 depth);
|
||||
|
@ -17,6 +17,8 @@
|
||||
namespace tensorflow {
|
||||
namespace tensorforest {
|
||||
|
||||
using decision_trees::Leaf;
|
||||
|
||||
std::unique_ptr<LeafModelOperator>
|
||||
LeafModelOperatorFactory::CreateLeafModelOperator(
|
||||
const TensorForestParams& params) {
|
||||
@ -50,24 +52,21 @@ float DenseClassificationLeafModelOperator::GetOutputValue(
|
||||
}
|
||||
|
||||
void DenseClassificationLeafModelOperator::UpdateModel(
|
||||
LeafStat* leaf, const InputTarget* target,
|
||||
int example) const {
|
||||
Leaf* leaf, const InputTarget* target, int example) const {
|
||||
const int32 int_label = target->GetTargetAsClassIndex(example, 0);
|
||||
QCHECK_LT(int_label, params_.num_outputs())
|
||||
<< "Got label greater than indicated number of classes. Is "
|
||||
"params.num_classes set correctly?";
|
||||
QCHECK_GE(int_label, 0);
|
||||
auto* val = leaf->mutable_classification()->mutable_dense_counts()
|
||||
->mutable_value(int_label);
|
||||
auto* val = leaf->mutable_vector()->mutable_value(int_label);
|
||||
|
||||
float weight = target->GetTargetWeight(example);
|
||||
val->set_float_value(val->float_value() + weight);
|
||||
leaf->set_weight_sum(leaf->weight_sum() + weight);
|
||||
}
|
||||
|
||||
void DenseClassificationLeafModelOperator::InitModel(
|
||||
LeafStat* leaf) const {
|
||||
void DenseClassificationLeafModelOperator::InitModel(Leaf* leaf) const {
|
||||
for (int i = 0; i < params_.num_outputs(); ++i) {
|
||||
leaf->mutable_classification()->mutable_dense_counts()->add_value();
|
||||
leaf->mutable_vector()->add_value();
|
||||
}
|
||||
}
|
||||
|
||||
@ -88,17 +87,15 @@ float SparseClassificationLeafModelOperator::GetOutputValue(
|
||||
}
|
||||
|
||||
void SparseClassificationLeafModelOperator::UpdateModel(
|
||||
LeafStat* leaf, const InputTarget* target,
|
||||
int example) const {
|
||||
Leaf* leaf, const InputTarget* target, int example) const {
|
||||
const int32 int_label = target->GetTargetAsClassIndex(example, 0);
|
||||
QCHECK_LT(int_label, params_.num_outputs())
|
||||
<< "Got label greater than indicated number of classes. Is "
|
||||
"params.num_classes set correctly?";
|
||||
QCHECK_GE(int_label, 0);
|
||||
const float weight = target->GetTargetWeight(example);
|
||||
leaf->set_weight_sum(leaf->weight_sum() + weight);
|
||||
auto value_map = leaf->mutable_classification()->mutable_sparse_counts()
|
||||
->mutable_sparse_value();
|
||||
|
||||
auto value_map = leaf->mutable_sparse_vector()->mutable_sparse_value();
|
||||
auto it = value_map->find(int_label);
|
||||
if (it == value_map->end()) {
|
||||
(*value_map)[int_label].set_float_value(weight);
|
||||
@ -123,8 +120,8 @@ float SparseOrDenseClassificationLeafModelOperator::GetOutputValue(
|
||||
}
|
||||
|
||||
void SparseOrDenseClassificationLeafModelOperator::UpdateModel(
|
||||
LeafStat* leaf, const InputTarget* target, int example) const {
|
||||
if (leaf->classification().has_dense_counts()) {
|
||||
Leaf* leaf, const InputTarget* target, int example) const {
|
||||
if (leaf->has_vector()) {
|
||||
return dense_->UpdateModel(leaf, target, example);
|
||||
} else {
|
||||
return sparse_->UpdateModel(leaf, target, example);
|
||||
@ -146,15 +143,15 @@ float RegressionLeafModelOperator::GetOutputValue(
|
||||
return leaf.vector().value(o).float_value();
|
||||
}
|
||||
|
||||
void RegressionLeafModelOperator::InitModel(
|
||||
LeafStat* leaf) const {
|
||||
void RegressionLeafModelOperator::InitModel(Leaf* leaf) const {
|
||||
for (int i = 0; i < params_.num_outputs(); ++i) {
|
||||
leaf->mutable_regression()->mutable_mean_output()->add_value();
|
||||
leaf->mutable_vector()->add_value();
|
||||
}
|
||||
}
|
||||
|
||||
void RegressionLeafModelOperator::ExportModel(
|
||||
const LeafStat& stat, decision_trees::Leaf* leaf) const {
|
||||
leaf->clear_vector();
|
||||
for (int i = 0; i < params_.num_outputs(); ++i) {
|
||||
const float new_val =
|
||||
stat.regression().mean_output().value(i).float_value() /
|
||||
|
@ -42,12 +42,11 @@ class LeafModelOperator {
|
||||
int32 o) const = 0;
|
||||
|
||||
// Update the given Leaf's model with the given example.
|
||||
virtual void UpdateModel(LeafStat* leaf,
|
||||
const InputTarget* target,
|
||||
int example) const = 0;
|
||||
virtual void UpdateModel(decision_trees::Leaf* leaf,
|
||||
const InputTarget* target, int example) const = 0;
|
||||
|
||||
// Initialize an empty Leaf model.
|
||||
virtual void InitModel(LeafStat* leaf) const = 0;
|
||||
virtual void InitModel(decision_trees::Leaf* leaf) const = 0;
|
||||
|
||||
virtual void ExportModel(const LeafStat& stat,
|
||||
decision_trees::Leaf* leaf) const = 0;
|
||||
@ -65,10 +64,10 @@ class DenseClassificationLeafModelOperator : public LeafModelOperator {
|
||||
float GetOutputValue(const decision_trees::Leaf& leaf,
|
||||
int32 o) const override;
|
||||
|
||||
void UpdateModel(LeafStat* leaf, const InputTarget* target,
|
||||
void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
|
||||
int example) const override;
|
||||
|
||||
void InitModel(LeafStat* leaf) const override;
|
||||
void InitModel(decision_trees::Leaf* leaf) const override;
|
||||
|
||||
void ExportModel(const LeafStat& stat,
|
||||
decision_trees::Leaf* leaf) const override;
|
||||
@ -84,10 +83,10 @@ class SparseClassificationLeafModelOperator : public LeafModelOperator {
|
||||
float GetOutputValue(const decision_trees::Leaf& leaf,
|
||||
int32 o) const override;
|
||||
|
||||
void UpdateModel(LeafStat* leaf, const InputTarget* target,
|
||||
void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
|
||||
int example) const override;
|
||||
|
||||
void InitModel(LeafStat* leaf) const override {}
|
||||
void InitModel(decision_trees::Leaf* leaf) const override {}
|
||||
|
||||
void ExportModel(const LeafStat& stat,
|
||||
decision_trees::Leaf* leaf) const override;
|
||||
@ -103,10 +102,10 @@ class SparseOrDenseClassificationLeafModelOperator : public LeafModelOperator {
|
||||
float GetOutputValue(const decision_trees::Leaf& leaf,
|
||||
int32 o) const override;
|
||||
|
||||
void UpdateModel(LeafStat* leaf, const InputTarget* target,
|
||||
void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
|
||||
int example) const override;
|
||||
|
||||
void InitModel(LeafStat* leaf) const override {}
|
||||
void InitModel(decision_trees::Leaf* leaf) const override {}
|
||||
|
||||
void ExportModel(const LeafStat& stat,
|
||||
decision_trees::Leaf* leaf) const override;
|
||||
@ -129,10 +128,10 @@ class RegressionLeafModelOperator : public LeafModelOperator {
|
||||
// updating model and just using the seeded values. Can add this in
|
||||
// with additional_data, though protobuf::Any is slow. Maybe make it
|
||||
// optional. Maybe make any update optional.
|
||||
void UpdateModel(LeafStat* leaf, const InputTarget* target,
|
||||
void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
|
||||
int example) const override {}
|
||||
|
||||
void InitModel(LeafStat* leaf) const override;
|
||||
void InitModel(decision_trees::Leaf* leaf) const override;
|
||||
|
||||
void ExportModel(const LeafStat& stat,
|
||||
decision_trees::Leaf* leaf) const override;
|
||||
|
@ -63,12 +63,8 @@ constexpr char kRegressionStatProto[] =
|
||||
"}";
|
||||
|
||||
void TestClassificationNormalUse(const std::unique_ptr<LeafModelOperator>& op) {
|
||||
std::unique_ptr<LeafStat> leaf(new LeafStat);
|
||||
op->InitModel(leaf.get());
|
||||
|
||||
Leaf l;
|
||||
op->ExportModel(*leaf, &l);
|
||||
|
||||
op->InitModel(&l);
|
||||
// Make sure it was initialized correctly.
|
||||
for (int i = 0; i < kNumClasses; ++i) {
|
||||
EXPECT_EQ(op->GetOutputValue(l, i), 0);
|
||||
@ -80,11 +76,10 @@ void TestClassificationNormalUse(const std::unique_ptr<LeafModelOperator>& op) {
|
||||
new TestableInputTarget(labels, weights, 1));
|
||||
|
||||
// Update and check value.
|
||||
op->UpdateModel(leaf.get(), target.get(), 0);
|
||||
op->UpdateModel(leaf.get(), target.get(), 1);
|
||||
op->UpdateModel(leaf.get(), target.get(), 2);
|
||||
op->UpdateModel(&l, target.get(), 0);
|
||||
op->UpdateModel(&l, target.get(), 1);
|
||||
op->UpdateModel(&l, target.get(), 2);
|
||||
|
||||
op->ExportModel(*leaf, &l);
|
||||
EXPECT_FLOAT_EQ(op->GetOutputValue(l, 1), 3.4);
|
||||
}
|
||||
|
||||
|
@ -71,13 +71,13 @@ void SplitCollectionOperator::ExtractFromProto(
|
||||
}
|
||||
|
||||
void SplitCollectionOperator::PackToProto(FertileStats* stats_proto) const {
|
||||
for (int i = 0; i < stats_proto->node_to_slot_size(); ++i) {
|
||||
auto* new_slot = stats_proto->mutable_node_to_slot(i);
|
||||
const auto& stats = stats_.at(new_slot->node_id());
|
||||
for (const auto& pair : stats_) {
|
||||
auto* new_slot = stats_proto->add_node_to_slot();
|
||||
new_slot->set_node_id(pair.first);
|
||||
if (params_.checkpoint_stats()) {
|
||||
stats->PackToProto(new_slot);
|
||||
pair.second->PackToProto(new_slot);
|
||||
}
|
||||
new_slot->set_depth(stats->depth());
|
||||
new_slot->set_depth(pair.second->depth());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -62,6 +62,14 @@ class SplitCollectionOperator {
|
||||
// Create a new GrowStats for the given node id and initialize it.
|
||||
virtual void InitializeSlot(int32 node_id, int32 depth);
|
||||
|
||||
// Called when the resource is deserialized, possibly needing an
|
||||
// initialization.
|
||||
virtual void MaybeInitialize() {
|
||||
if (stats_.empty()) {
|
||||
InitializeSlot(0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Perform any necessary cleanup for any tracked state for the slot.
|
||||
virtual void ClearSlot(int32 node_id) {
|
||||
stats_.erase(node_id);
|
||||
|
@ -115,6 +115,58 @@ sparse_input_shape: The shape tensor from the SparseTensor input.
|
||||
predictions: `predictions[i][j]` is the probability that input i is class j.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("TraverseTreeV4")
|
||||
.Attr("input_spec: string")
|
||||
.Attr("params: string")
|
||||
.Input("tree_handle: resource")
|
||||
.Input("input_data: float")
|
||||
.Input("sparse_input_indices: int64")
|
||||
.Input("sparse_input_values: float")
|
||||
.Input("sparse_input_shape: int64")
|
||||
.Output("leaf_ids: int32")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
DimensionHandle num_points = c->UnknownDim();
|
||||
|
||||
if (c->RankKnown(c->input(1)) && c->Rank(c->input(1)) > 0 &&
|
||||
c->Value(c->Dim(c->input(1), 0)) > 0) {
|
||||
num_points = c->Dim(c->input(1), 0);
|
||||
}
|
||||
|
||||
c->set_output(0, c->Vector(num_points));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Outputs the leaf ids for the given input data.
|
||||
|
||||
params: A serialized TensorForestParams proto.
|
||||
tree_handle: The handle to the tree.
|
||||
input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
|
||||
gives the j-th feature of the i-th input.
|
||||
sparse_input_indices: The indices tensor from the SparseTensor input.
|
||||
sparse_input_values: The values tensor from the SparseTensor input.
|
||||
sparse_input_shape: The shape tensor from the SparseTensor input.
|
||||
leaf_ids: `leaf_ids[i]` is the leaf id for input i.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("UpdateModelV4")
|
||||
.Attr("params: string")
|
||||
.Input("tree_handle: resource")
|
||||
.Input("leaf_ids: int32")
|
||||
.Input("input_labels: float")
|
||||
.Input("input_weights: float")
|
||||
.SetShapeFn(tensorflow::shape_inference::NoOutputs)
|
||||
.Doc(R"doc(
|
||||
Updates the given leaves for each example with the new labels.
|
||||
|
||||
params: A serialized TensorForestParams proto.
|
||||
tree_handle: The handle to the tree.
|
||||
leaf_ids: `leaf_ids[i]` is the leaf id for input i.
|
||||
input_labels: The training batch's labels as a 1 or 2-d tensor.
|
||||
'input_labels[i][j]' gives the j-th label/target for the i-th input.
|
||||
input_weights: The training batch's example weights as a 1-d tensor.
|
||||
'input_weights[i]' gives the weight for the i-th input.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("FeatureUsageCounts")
|
||||
.Attr("params: string")
|
||||
.Input("tree_handle: resource")
|
||||
|
@ -98,6 +98,7 @@ REGISTER_OP("ProcessInputV4")
|
||||
.Input("sparse_input_shape: int64")
|
||||
.Input("input_labels: float")
|
||||
.Input("input_weights: float")
|
||||
.Input("leaf_ids: int32")
|
||||
.Output("finished_nodes: int32")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->Vector(c->UnknownDim()));
|
||||
@ -122,6 +123,7 @@ input_weights: The training batch's eample weights as a 1-d tensor.
|
||||
'input_weights[i]' gives the weight for the i-th input.
|
||||
finished_nodes: A 1-d tensor of node ids that have finished and are ready to
|
||||
grow.
|
||||
leaf_ids: `leaf_ids[i]` is the leaf id for input i.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("FinalizeTree")
|
||||
|
@ -18,12 +18,13 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.tensor_forest.python.ops import gen_model_ops
|
||||
from tensorflow.contrib.tensor_forest.python.ops import stats_ops
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import feature_usage_counts
|
||||
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import traverse_tree_v4
|
||||
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_predictions_v4
|
||||
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_size
|
||||
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import update_model_v4
|
||||
# pylint: enable=unused-import
|
||||
|
||||
from tensorflow.contrib.util import loader
|
||||
@ -59,13 +60,7 @@ class TreeVariableSavable(saver.BaseSaverBuilder.SaveableObject):
|
||||
name: the name to save the tree variable under.
|
||||
"""
|
||||
self.params = params
|
||||
deps = []
|
||||
if stats_handle is not None:
|
||||
deps.append(stats_ops.finalize_tree(
|
||||
tree_handle, stats_handle,
|
||||
params=params.serialized_params_proto))
|
||||
with ops.control_dependencies(deps):
|
||||
tensor = gen_model_ops.tree_serialize(tree_handle)
|
||||
tensor = gen_model_ops.tree_serialize(tree_handle)
|
||||
# slice_spec is useful for saving a slice from a variable.
|
||||
# It's not meaningful for the tree variable. So we just pass an empty value.
|
||||
slice_spec = ""
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.contrib.tensor_forest.python import tensor_forest
|
||||
from tensorflow.contrib.tensor_forest.python.ops import model_ops
|
||||
from tensorflow.contrib.tensor_forest.python.ops import stats_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
@ -240,6 +241,22 @@ class RandomTreeGraphsV4(tensor_forest.RandomTreeGraphs):
|
||||
if input_data is None:
|
||||
input_data = []
|
||||
|
||||
leaf_ids = model_ops.traverse_tree_v4(
|
||||
self.variables.tree,
|
||||
input_data,
|
||||
sparse_indices,
|
||||
sparse_values,
|
||||
sparse_shape,
|
||||
input_spec=data_spec.SerializeToString(),
|
||||
params=self.params.serialized_params_proto)
|
||||
|
||||
update_model = model_ops.update_model_v4(
|
||||
self.variables.tree,
|
||||
leaf_ids,
|
||||
input_labels,
|
||||
input_weights,
|
||||
params=self.params.serialized_params_proto)
|
||||
|
||||
finished_nodes = stats_ops.process_input_v4(
|
||||
self.variables.tree,
|
||||
self.variables.stats,
|
||||
@ -249,13 +266,17 @@ class RandomTreeGraphsV4(tensor_forest.RandomTreeGraphs):
|
||||
sparse_shape,
|
||||
input_labels,
|
||||
input_weights,
|
||||
leaf_ids,
|
||||
input_spec=data_spec.SerializeToString(),
|
||||
random_seed=random_seed,
|
||||
params=self.params.serialized_params_proto)
|
||||
|
||||
return stats_ops.grow_tree_v4(self.variables.tree, self.variables.stats,
|
||||
finished_nodes,
|
||||
params=self.params.serialized_params_proto)
|
||||
with ops.control_dependencies([update_model]):
|
||||
return stats_ops.grow_tree_v4(
|
||||
self.variables.tree,
|
||||
self.variables.stats,
|
||||
finished_nodes,
|
||||
params=self.params.serialized_params_proto)
|
||||
|
||||
def inference_graph(self, input_data, data_spec, sparse_features=None):
|
||||
sparse_indices = []
|
||||
|
Loading…
Reference in New Issue
Block a user