Merge pull request #21803 from yupbank:add-tensor-forest-classification-inference
PiperOrigin-RevId: 222128700
This commit is contained in:
commit
68900419ba
tensorflow
core
BUILD
api_def/base_api
api_def_TensorForestCreateTreeVariable.pbtxtapi_def_TensorForestTreeDeserialize.pbtxtapi_def_TensorForestTreeIsInitializedOp.pbtxtapi_def_TensorForestTreePredict.pbtxtapi_def_TensorForestTreeResourceHandleOp.pbtxtapi_def_TensorForestTreeSerialize.pbtxtapi_def_TensorForestTreeSize.pbtxt
kernels
ops
python
@ -1052,6 +1052,7 @@ tf_gen_op_libs(
|
||||
"batch_ops",
|
||||
"bitwise_ops",
|
||||
"boosted_trees_ops",
|
||||
"tensor_forest_ops",
|
||||
"candidate_sampling_ops",
|
||||
"checkpoint_ops",
|
||||
"collective_ops",
|
||||
@ -1201,6 +1202,7 @@ cc_library(
|
||||
":batch_ops_op_lib",
|
||||
":bitwise_ops_op_lib",
|
||||
":boosted_trees_ops_op_lib",
|
||||
":tensor_forest_ops_op_lib",
|
||||
":candidate_sampling_ops_op_lib",
|
||||
":checkpoint_ops_op_lib",
|
||||
":collective_ops_op_lib",
|
||||
@ -1354,6 +1356,7 @@ cc_library(
|
||||
"//tensorflow/core/kernels:batch_kernels",
|
||||
"//tensorflow/core/kernels:bincount_op",
|
||||
"//tensorflow/core/kernels:boosted_trees_ops",
|
||||
"//tensorflow/core/kernels:tensor_forest_ops",
|
||||
"//tensorflow/core/kernels:candidate_sampler_ops",
|
||||
"//tensorflow/core/kernels:checkpoint_ops",
|
||||
"//tensorflow/core/kernels:collective_ops",
|
||||
|
@ -0,0 +1,17 @@
|
||||
op {
|
||||
graph_op_name: "TensorForestCreateTreeVariable"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "tree_handle"
|
||||
description: <<END
|
||||
Handle to the tree resource to be created.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "tree_config"
|
||||
description: <<END
|
||||
Serialized proto string of the boosted_trees.Tree.
|
||||
END
|
||||
}
|
||||
summary: "Creates a tree resource and returns a handle to it."
|
||||
}
|
@ -0,0 +1,17 @@
|
||||
op {
|
||||
graph_op_name: "TensorForestTreeDeserialize"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "tree_handle"
|
||||
description: <<END
|
||||
Handle to the tree resource to be restored.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "tree_config"
|
||||
description: <<END
|
||||
Serialied proto string of the boosted_trees.Tree proto.
|
||||
END
|
||||
}
|
||||
summary: "Deserializes a proto into the tree handle"
|
||||
}
|
@ -0,0 +1,17 @@
|
||||
op {
|
||||
graph_op_name: "TensorForestTreeIsInitializedOp"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "tree_handle"
|
||||
description: <<END
|
||||
Handle to the tree.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "is_initialized"
|
||||
description: <<END
|
||||
Whether the tree is initialized.
|
||||
END
|
||||
}
|
||||
summary: "Checks whether a tree has been initialized."
|
||||
}
|
@ -0,0 +1,29 @@
|
||||
op {
|
||||
graph_op_name: "TensorForestTreePredict"
|
||||
visibility: HIDDEN
|
||||
attr {
|
||||
name: "logits_dimension"
|
||||
description: <<END
|
||||
Scalar, dimension of the logits.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "tree_handle"
|
||||
description: <<END
|
||||
Handle to the tree resource.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "dense_features"
|
||||
description: <<END
|
||||
Rank 2 dense features tensor.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "logits"
|
||||
description: <<END
|
||||
The logits predictions from the tree for each instance in the batch.
|
||||
END
|
||||
}
|
||||
summary: "Output the logits for the given input data"
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
op {
|
||||
graph_op_name: "TensorForestTreeResourceHandleOp"
|
||||
visibility: HIDDEN
|
||||
summary: "Creates a handle to a TensorForestTreeResource"
|
||||
}
|
@ -0,0 +1,17 @@
|
||||
op {
|
||||
graph_op_name: "TensorForestTreeSerialize"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "tree_handle"
|
||||
description: <<END
|
||||
Handle to the tree resource to be serialized.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "tree_config"
|
||||
description: <<END
|
||||
Serialied proto string of the tree resource.
|
||||
END
|
||||
}
|
||||
summary: "Serializes the tree handle to a proto"
|
||||
}
|
@ -0,0 +1,17 @@
|
||||
op {
|
||||
graph_op_name: "TensorForestTreeSize"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "tree_handle"
|
||||
description: <<END
|
||||
Handle to the tree resource.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "tree_size"
|
||||
description: <<END
|
||||
The size of the tree.
|
||||
END
|
||||
}
|
||||
summary: "Get the number of nodes in a tree"
|
||||
}
|
@ -6735,6 +6735,13 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "tensor_forest_ops",
|
||||
deps = [
|
||||
"//tensorflow/core/kernels/tensor_forest:tensor_forest_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "dataset_ops",
|
||||
deps = [
|
||||
|
@ -2,7 +2,10 @@
|
||||
# OpKernels for boosted trees ops.
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
default_visibility = [
|
||||
"//tensorflow:__subpackages__",
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
@ -12,6 +12,7 @@ message Node {
|
||||
Leaf leaf = 1;
|
||||
BucketizedSplit bucketized_split = 2;
|
||||
CategoricalSplit categorical_split = 3;
|
||||
DenseSplit dense_split = 4;
|
||||
}
|
||||
NodeMetadata metadata = 777;
|
||||
}
|
||||
@ -70,6 +71,19 @@ message CategoricalSplit {
|
||||
int32 right_id = 4;
|
||||
}
|
||||
|
||||
// TODO(nponomareva): move out of boosted_trees and rename to trees.proto
|
||||
message DenseSplit {
|
||||
// Float feature column and split threshold describing
|
||||
// the rule feature <= threshold.
|
||||
int32 feature_id = 1;
|
||||
float threshold = 2;
|
||||
|
||||
// Node children indexing into a contiguous
|
||||
// vector of nodes starting from the root.
|
||||
int32 left_id = 3;
|
||||
int32 right_id = 4;
|
||||
}
|
||||
|
||||
// Tree describes a list of connected nodes.
|
||||
// Node 0 must be the root and can carry any payload including a leaf
|
||||
// in the case of representing the bias.
|
||||
|
53
tensorflow/core/kernels/tensor_forest/BUILD
Normal file
53
tensorflow/core/kernels/tensor_forest/BUILD
Normal file
@ -0,0 +1,53 @@
|
||||
# Description:
|
||||
# OpKernels for tensor forest ops.
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
|
||||
|
||||
cc_library(
|
||||
name = "resources",
|
||||
srcs = ["resources.cc"],
|
||||
hdrs = ["resources.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "resource_ops",
|
||||
srcs = ["resource_ops.cc"],
|
||||
deps = [
|
||||
":resources",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:tensor_forest_ops_op_lib",
|
||||
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "prediction_ops",
|
||||
srcs = ["prediction_ops.cc"],
|
||||
deps = [
|
||||
":resources",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:tensor_forest_ops_op_lib",
|
||||
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "tensor_forest_ops",
|
||||
deps = [
|
||||
":prediction_ops",
|
||||
":resource_ops",
|
||||
],
|
||||
)
|
93
tensorflow/core/kernels/tensor_forest/prediction_ops.cc
Normal file
93
tensorflow/core/kernels/tensor_forest/prediction_ops.cc
Normal file
@ -0,0 +1,93 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/tensor_forest/resources.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorForestTreePredictOp : public OpKernel {
|
||||
public:
|
||||
explicit TensorForestTreePredictOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("logits_dimension", &logits_dimension_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
TensorForestTreeResource* decision_tree_resource;
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&decision_tree_resource));
|
||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||
core::ScopedUnref unref_me(decision_tree_resource);
|
||||
|
||||
const Tensor* dense_features_t = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input("dense_features", &dense_features_t));
|
||||
|
||||
auto dense_features = dense_features_t->matrix<float>();
|
||||
const int32 batch_size = dense_features_t->dim_size(0);
|
||||
|
||||
Tensor* output_predictions = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, {batch_size, logits_dimension_},
|
||||
&output_predictions));
|
||||
auto out = output_predictions->matrix<float>();
|
||||
|
||||
if (decision_tree_resource->get_size() <= 0) {
|
||||
out.setZero();
|
||||
return;
|
||||
}
|
||||
auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
||||
const int32 num_threads = worker_threads->num_threads;
|
||||
|
||||
// TODO(yupbank): This was from contrib version.
|
||||
// This cost would probably depend on the depth of the tree we have.
|
||||
// We will need to run it on a number of trees of diff depth
|
||||
// and see the num of cpu cycles
|
||||
const int64 cost_per_traverse = 500;
|
||||
auto traverse = [this, &out, &dense_features, decision_tree_resource,
|
||||
batch_size](int64 start, int64 end) {
|
||||
DCHECK_LE(start, end) << "Start exceeding End";
|
||||
DCHECK_LE(end, batch_size) << "End exceeding batch size";
|
||||
for (int example_id = start; example_id < end; ++example_id) {
|
||||
const int32 leaf_id =
|
||||
decision_tree_resource->TraverseTree(example_id, &dense_features);
|
||||
set_output_value(example_id, leaf_id, decision_tree_resource, &out);
|
||||
}
|
||||
};
|
||||
Shard(num_threads, worker_threads->workers, batch_size, cost_per_traverse,
|
||||
traverse);
|
||||
};
|
||||
|
||||
void set_output_value(const int32 example_id, const int32 leaf_id,
|
||||
const TensorForestTreeResource* decision_tree_resource,
|
||||
TTypes<float>::Matrix* out) const {
|
||||
for (int j = 0; j < logits_dimension_; ++j) {
|
||||
const float logit = decision_tree_resource->get_prediction(leaf_id, j);
|
||||
(*out)(example_id, j) = logit;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int32 logits_dimension_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TensorForestTreePredict").Device(DEVICE_CPU),
|
||||
TensorForestTreePredictOp);
|
||||
|
||||
} // namespace tensorflow
|
136
tensorflow/core/kernels/tensor_forest/resource_ops.cc
Normal file
136
tensorflow/core/kernels/tensor_forest/resource_ops.cc
Normal file
@ -0,0 +1,136 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
|
||||
#include "tensorflow/core/kernels/tensor_forest/resources.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorForestCreateTreeVariableOp : public OpKernel {
|
||||
public:
|
||||
explicit TensorForestCreateTreeVariableOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor* tree_config_t;
|
||||
OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
|
||||
|
||||
auto* const result = new TensorForestTreeResource();
|
||||
|
||||
if (!result->InitFromSerialized(tree_config_t->scalar<string>()())) {
|
||||
result->Unref();
|
||||
OP_REQUIRES(context, false,
|
||||
errors::InvalidArgument("Unable to parse tree config."));
|
||||
}
|
||||
|
||||
// Only create one, if one does not exist already. Report status for all
|
||||
// other exceptions.
|
||||
auto status = CreateResource(context, HandleFromInput(context, 0), result);
|
||||
if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
|
||||
OP_REQUIRES(context, false, status);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Op for serializing a model.
|
||||
class TensorForestTreeSerializeOp : public OpKernel {
|
||||
public:
|
||||
explicit TensorForestTreeSerializeOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
TensorForestTreeResource* decision_tree_resource;
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&decision_tree_resource));
|
||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||
core::ScopedUnref unref_me(decision_tree_resource);
|
||||
Tensor* output_config_t = nullptr;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_output(0, TensorShape(), &output_config_t));
|
||||
output_config_t->scalar<string>()() =
|
||||
decision_tree_resource->decision_tree().SerializeAsString();
|
||||
}
|
||||
};
|
||||
|
||||
// Op for deserializing a tree variable from a checkpoint.
|
||||
class TensorForestTreeDeserializeOp : public OpKernel {
|
||||
public:
|
||||
explicit TensorForestTreeDeserializeOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {}
|
||||
void Compute(OpKernelContext* context) override {
|
||||
TensorForestTreeResource* decision_tree_resource;
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&decision_tree_resource));
|
||||
|
||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||
core::ScopedUnref unref_me(decision_tree_resource);
|
||||
|
||||
const Tensor* tree_config_t;
|
||||
OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
|
||||
|
||||
// Deallocate all the previous objects on the resource.
|
||||
decision_tree_resource->Reset();
|
||||
|
||||
if (!decision_tree_resource->InitFromSerialized(
|
||||
tree_config_t->scalar<string>()())) {
|
||||
OP_REQUIRES(context, false,
|
||||
errors::InvalidArgument("Unable to parse tree config."));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Op for getting tree size.
|
||||
class TensorForestTreeSizeOp : public OpKernel {
|
||||
public:
|
||||
explicit TensorForestTreeSizeOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
TensorForestTreeResource* decision_tree_resource;
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&decision_tree_resource));
|
||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||
core::ScopedUnref unref_me(decision_tree_resource);
|
||||
Tensor* output_t = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, TensorShape(), &output_t));
|
||||
output_t->scalar<int32>()() = decision_tree_resource->get_size();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_RESOURCE_HANDLE_KERNEL(TensorForestTreeResource);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("TensorForestTreeIsInitializedOp").Device(DEVICE_CPU),
|
||||
IsResourceInitialized<TensorForestTreeResource>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("TensorForestCreateTreeVariable").Device(DEVICE_CPU),
|
||||
TensorForestCreateTreeVariableOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TensorForestTreeSerialize").Device(DEVICE_CPU),
|
||||
TensorForestTreeSerializeOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TensorForestTreeDeserialize").Device(DEVICE_CPU),
|
||||
TensorForestTreeDeserializeOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TensorForestTreeSize").Device(DEVICE_CPU),
|
||||
TensorForestTreeSizeOp);
|
||||
|
||||
} // namespace tensorflow
|
71
tensorflow/core/kernels/tensor_forest/resources.cc
Normal file
71
tensorflow/core/kernels/tensor_forest/resources.cc
Normal file
@ -0,0 +1,71 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/tensor_forest/resources.h"
|
||||
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const boosted_trees::Tree& TensorForestTreeResource::decision_tree() const {
|
||||
return *decision_tree_;
|
||||
}
|
||||
|
||||
const int32 TensorForestTreeResource::get_size() const {
|
||||
return decision_tree_->nodes_size();
|
||||
}
|
||||
|
||||
TensorForestTreeResource::TensorForestTreeResource()
|
||||
: decision_tree_(
|
||||
protobuf::Arena::CreateMessage<boosted_trees::Tree>(&arena_)) {}
|
||||
|
||||
const float TensorForestTreeResource::get_prediction(
|
||||
const int32 id, const int32 dimension_id) const {
|
||||
return decision_tree_->nodes(id).leaf().vector().value(dimension_id);
|
||||
}
|
||||
|
||||
const int32 TensorForestTreeResource::TraverseTree(
|
||||
const int32 example_id,
|
||||
const TTypes<float>::ConstMatrix* dense_data) const {
|
||||
using boosted_trees::Node;
|
||||
using boosted_trees::Tree;
|
||||
int32 current_id = 0;
|
||||
while (true) {
|
||||
const Node& current = decision_tree_->nodes(current_id);
|
||||
if (current.has_leaf()) {
|
||||
return current_id;
|
||||
}
|
||||
DCHECK_EQ(current.node_case(), Node::kDenseSplit);
|
||||
const auto& split = current.dense_split();
|
||||
|
||||
if ((*dense_data)(example_id, split.feature_id()) <= split.threshold()) {
|
||||
current_id = split.left_id();
|
||||
} else {
|
||||
current_id = split.right_id();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool TensorForestTreeResource::InitFromSerialized(const string& serialized) {
|
||||
return ParseProtoUnlimited(decision_tree_, serialized);
|
||||
}
|
||||
|
||||
void TensorForestTreeResource::Reset() {
|
||||
arena_.Reset();
|
||||
DCHECK_EQ(0, arena_.SpaceAllocated());
|
||||
decision_tree_ = protobuf::Arena::CreateMessage<boosted_trees::Tree>(&arena_);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
65
tensorflow/core/kernels/tensor_forest/resources.h
Normal file
65
tensorflow/core/kernels/tensor_forest/resources.h
Normal file
@ -0,0 +1,65 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_FOREST_RESOURCES_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_TENSOR_FOREST_RESOURCES_H_
|
||||
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Forward declaration for proto class Tree.
|
||||
namespace boosted_trees {
|
||||
class Tree;
|
||||
} // namespace boosted_trees
|
||||
|
||||
// Keep a tree ensemble in memory for efficient evaluation and mutation.
|
||||
class TensorForestTreeResource : public ResourceBase {
|
||||
public:
|
||||
TensorForestTreeResource();
|
||||
|
||||
string DebugString() override {
|
||||
return strings::StrCat("TensorForestTree[size=", get_size(), "]");
|
||||
}
|
||||
|
||||
mutex* get_mutex() { return &mu_; }
|
||||
|
||||
bool InitFromSerialized(const string& serialized);
|
||||
|
||||
// Resets the resource and frees the proto.
|
||||
// Caller needs to hold the mutex lock while calling this.
|
||||
void Reset();
|
||||
|
||||
const int32 get_size() const;
|
||||
|
||||
const boosted_trees::Tree& decision_tree() const;
|
||||
|
||||
const float get_prediction(const int32 id, const int32 dimension_id) const;
|
||||
|
||||
const int32 TraverseTree(const int32 example_id,
|
||||
const TTypes<float>::ConstMatrix* dense_data) const;
|
||||
|
||||
protected:
|
||||
mutex mu_;
|
||||
protobuf::Arena arena_;
|
||||
boosted_trees::Tree* decision_tree_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CORE_KERNELS_TENSOR_FOREST_RESOURCES_H_
|
79
tensorflow/core/ops/tensor_forest_ops.cc
Normal file
79
tensorflow/core/ops/tensor_forest_ops.cc
Normal file
@ -0,0 +1,79 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_RESOURCE_HANDLE_OP(TensorForestTreeResource);
|
||||
|
||||
REGISTER_OP("TensorForestTreeIsInitializedOp")
|
||||
.Input("tree_handle: resource")
|
||||
.Output("is_initialized: bool")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused_input;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
|
||||
c->set_output(0, c->Scalar());
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("TensorForestCreateTreeVariable")
|
||||
.Input("tree_handle: resource")
|
||||
.Input("tree_config: string")
|
||||
.SetShapeFn(tensorflow::shape_inference::NoOutputs);
|
||||
|
||||
REGISTER_OP("TensorForestTreeSerialize")
|
||||
.Input("tree_handle: resource")
|
||||
.Output("tree_config: string")
|
||||
.SetShapeFn(tensorflow::shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("TensorForestTreeDeserialize")
|
||||
.Input("tree_handle: resource")
|
||||
.Input("tree_config: string")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused_input;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("TensorForestTreeSize")
|
||||
.Input("tree_handle: resource")
|
||||
.Output("tree_size: int32")
|
||||
.SetShapeFn(tensorflow::shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("TensorForestTreePredict")
|
||||
.Attr("logits_dimension: int")
|
||||
.Input("tree_handle: resource")
|
||||
.Input("dense_features: float")
|
||||
.Output("logits: float")
|
||||
.SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle shape_handle;
|
||||
shape_inference::DimensionHandle batch_size = c->UnknownDim();
|
||||
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &shape_handle));
|
||||
|
||||
batch_size = c->Dim(shape_handle, 0);
|
||||
|
||||
int logits_dimension;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
|
||||
c->set_output(0, c->Matrix(batch_size, logits_dimension));
|
||||
return Status::OK();
|
||||
});
|
||||
} // namespace tensorflow
|
@ -131,6 +131,7 @@ py_library(
|
||||
":subscribe",
|
||||
":summary",
|
||||
":tensor_array_ops",
|
||||
":tensor_forest_ops",
|
||||
":test_ops", # TODO: Break testing code out into separate rule.
|
||||
":tf_cluster",
|
||||
":tf_item",
|
||||
@ -1622,6 +1623,14 @@ tf_gen_op_wrapper_private_py(
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_private_py(
|
||||
name = "tensor_forest_ops_gen",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/core:tensor_forest_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_private_py(
|
||||
name = "summary_ops_gen",
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
@ -1946,6 +1955,19 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tensor_forest_ops",
|
||||
srcs = ["ops/tensor_forest_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":framework",
|
||||
":ops",
|
||||
":tensor_forest_ops_gen",
|
||||
":training",
|
||||
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "sets",
|
||||
srcs = [
|
||||
|
103
tensorflow/python/ops/tensor_forest_ops.py
Normal file
103
tensorflow/python/ops/tensor_forest_ops.py
Normal file
@ -0,0 +1,103 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Ops for tensor_forest."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import ops
|
||||
from tensorflow.python.ops import gen_tensor_forest_ops
|
||||
from tensorflow.python.ops import resources
|
||||
from tensorflow.python.training import saver
|
||||
|
||||
|
||||
class TreeVariableSaveable(saver.BaseSaverBuilder.SaveableObject):
|
||||
"""Resource that holds a tree."""
|
||||
|
||||
def __init__(self, type_name, name, container, config, resource_handle_func,
|
||||
create_op_func, is_initialized_op_func, serialize_op_func,
|
||||
deserialize_op_func):
|
||||
|
||||
with ops.name_scope(name, type_name) as name:
|
||||
self._resource_handle = resource_handle_func(
|
||||
container, shared_name=name, name=name)
|
||||
|
||||
self._is_initialized_op = is_initialized_op_func(self._resource_handle)
|
||||
tensor = serialize_op_func(self._resource_handle)
|
||||
self._create_op = create_op_func(self._resource_handle, config)
|
||||
# slice_spec is useful for saving a slice from a variable.
|
||||
# It's not meaningful the tree variable. So we just pass an empty
|
||||
# value.
|
||||
slice_spec = ''
|
||||
specs = [saver.BaseSaverBuilder.SaveSpec(tensor, slice_spec, name)]
|
||||
super(TreeVariableSaveable, self).__init__(self._resource_handle, specs,
|
||||
name)
|
||||
|
||||
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self)
|
||||
|
||||
resources.register_resource(self._resource_handle, self._create_op,
|
||||
self._is_initialized_op)
|
||||
self._deserialize_op_func = deserialize_op_func
|
||||
|
||||
def restore(self, restored_tensors, unused_restored_shapes):
|
||||
"""Restores the associated tree from 'restored_tensors'.
|
||||
|
||||
Args:
|
||||
restored_tensors: the tensors that were loaded from a checkpoint.
|
||||
unused_restored_shapes: the shapes this object should conform to after
|
||||
restore. Not meaningful for trees.
|
||||
|
||||
Returns:
|
||||
The operation that restores the state of the tree variable.
|
||||
"""
|
||||
with ops.control_dependencies([self._create_op]):
|
||||
return self._deserialize_op_func(
|
||||
self._resource_handle,
|
||||
restored_tensors[0],
|
||||
)
|
||||
|
||||
@property
|
||||
def resource(self):
|
||||
return self._resource_handle
|
||||
|
||||
|
||||
def tree_variable(tree_config, name, container=None):
|
||||
return TreeVariableSaveable(
|
||||
'TreeVariable', name, container, tree_config,
|
||||
gen_tensor_forest_ops.tensor_forest_tree_resource_handle_op,
|
||||
gen_tensor_forest_ops.tensor_forest_create_tree_variable,
|
||||
gen_tensor_forest_ops.tensor_forest_tree_is_initialized_op,
|
||||
gen_tensor_forest_ops.tensor_forest_tree_serialize,
|
||||
gen_tensor_forest_ops.tensor_forest_tree_deserialize).resource
|
||||
|
||||
|
||||
class ForestVariables(object):
|
||||
"""Resource that holds all trees from a forest."""
|
||||
|
||||
def __init__(self, params, tree_configs=None):
|
||||
|
||||
self._variables = []
|
||||
|
||||
for i in range(params.n_trees):
|
||||
tree_config = ''
|
||||
if tree_configs is not None:
|
||||
tree_config = tree_configs[i]
|
||||
self._variables.append(tree_variable(
|
||||
tree_config,
|
||||
'tree-%s' % i,
|
||||
))
|
||||
|
||||
def __getitem__(self, t):
|
||||
return self._variables[t]
|
Loading…
Reference in New Issue
Block a user