Merge pull request #21803 from yupbank:add-tensor-forest-classification-inference
PiperOrigin-RevId: 222128700
This commit is contained in:
commit
68900419ba
@ -1052,6 +1052,7 @@ tf_gen_op_libs(
|
|||||||
"batch_ops",
|
"batch_ops",
|
||||||
"bitwise_ops",
|
"bitwise_ops",
|
||||||
"boosted_trees_ops",
|
"boosted_trees_ops",
|
||||||
|
"tensor_forest_ops",
|
||||||
"candidate_sampling_ops",
|
"candidate_sampling_ops",
|
||||||
"checkpoint_ops",
|
"checkpoint_ops",
|
||||||
"collective_ops",
|
"collective_ops",
|
||||||
@ -1201,6 +1202,7 @@ cc_library(
|
|||||||
":batch_ops_op_lib",
|
":batch_ops_op_lib",
|
||||||
":bitwise_ops_op_lib",
|
":bitwise_ops_op_lib",
|
||||||
":boosted_trees_ops_op_lib",
|
":boosted_trees_ops_op_lib",
|
||||||
|
":tensor_forest_ops_op_lib",
|
||||||
":candidate_sampling_ops_op_lib",
|
":candidate_sampling_ops_op_lib",
|
||||||
":checkpoint_ops_op_lib",
|
":checkpoint_ops_op_lib",
|
||||||
":collective_ops_op_lib",
|
":collective_ops_op_lib",
|
||||||
@ -1354,6 +1356,7 @@ cc_library(
|
|||||||
"//tensorflow/core/kernels:batch_kernels",
|
"//tensorflow/core/kernels:batch_kernels",
|
||||||
"//tensorflow/core/kernels:bincount_op",
|
"//tensorflow/core/kernels:bincount_op",
|
||||||
"//tensorflow/core/kernels:boosted_trees_ops",
|
"//tensorflow/core/kernels:boosted_trees_ops",
|
||||||
|
"//tensorflow/core/kernels:tensor_forest_ops",
|
||||||
"//tensorflow/core/kernels:candidate_sampler_ops",
|
"//tensorflow/core/kernels:candidate_sampler_ops",
|
||||||
"//tensorflow/core/kernels:checkpoint_ops",
|
"//tensorflow/core/kernels:checkpoint_ops",
|
||||||
"//tensorflow/core/kernels:collective_ops",
|
"//tensorflow/core/kernels:collective_ops",
|
||||||
|
|||||||
@ -0,0 +1,17 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "TensorForestCreateTreeVariable"
|
||||||
|
visibility: HIDDEN
|
||||||
|
in_arg {
|
||||||
|
name: "tree_handle"
|
||||||
|
description: <<END
|
||||||
|
Handle to the tree resource to be created.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "tree_config"
|
||||||
|
description: <<END
|
||||||
|
Serialized proto string of the boosted_trees.Tree.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Creates a tree resource and returns a handle to it."
|
||||||
|
}
|
||||||
@ -0,0 +1,17 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "TensorForestTreeDeserialize"
|
||||||
|
visibility: HIDDEN
|
||||||
|
in_arg {
|
||||||
|
name: "tree_handle"
|
||||||
|
description: <<END
|
||||||
|
Handle to the tree resource to be restored.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "tree_config"
|
||||||
|
description: <<END
|
||||||
|
Serialied proto string of the boosted_trees.Tree proto.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Deserializes a proto into the tree handle"
|
||||||
|
}
|
||||||
@ -0,0 +1,17 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "TensorForestTreeIsInitializedOp"
|
||||||
|
visibility: HIDDEN
|
||||||
|
in_arg {
|
||||||
|
name: "tree_handle"
|
||||||
|
description: <<END
|
||||||
|
Handle to the tree.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "is_initialized"
|
||||||
|
description: <<END
|
||||||
|
Whether the tree is initialized.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Checks whether a tree has been initialized."
|
||||||
|
}
|
||||||
@ -0,0 +1,29 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "TensorForestTreePredict"
|
||||||
|
visibility: HIDDEN
|
||||||
|
attr {
|
||||||
|
name: "logits_dimension"
|
||||||
|
description: <<END
|
||||||
|
Scalar, dimension of the logits.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "tree_handle"
|
||||||
|
description: <<END
|
||||||
|
Handle to the tree resource.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "dense_features"
|
||||||
|
description: <<END
|
||||||
|
Rank 2 dense features tensor.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "logits"
|
||||||
|
description: <<END
|
||||||
|
The logits predictions from the tree for each instance in the batch.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Output the logits for the given input data"
|
||||||
|
}
|
||||||
@ -0,0 +1,5 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "TensorForestTreeResourceHandleOp"
|
||||||
|
visibility: HIDDEN
|
||||||
|
summary: "Creates a handle to a TensorForestTreeResource"
|
||||||
|
}
|
||||||
@ -0,0 +1,17 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "TensorForestTreeSerialize"
|
||||||
|
visibility: HIDDEN
|
||||||
|
in_arg {
|
||||||
|
name: "tree_handle"
|
||||||
|
description: <<END
|
||||||
|
Handle to the tree resource to be serialized.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "tree_config"
|
||||||
|
description: <<END
|
||||||
|
Serialied proto string of the tree resource.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Serializes the tree handle to a proto"
|
||||||
|
}
|
||||||
@ -0,0 +1,17 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "TensorForestTreeSize"
|
||||||
|
visibility: HIDDEN
|
||||||
|
in_arg {
|
||||||
|
name: "tree_handle"
|
||||||
|
description: <<END
|
||||||
|
Handle to the tree resource.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "tree_size"
|
||||||
|
description: <<END
|
||||||
|
The size of the tree.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Get the number of nodes in a tree"
|
||||||
|
}
|
||||||
@ -6735,6 +6735,13 @@ tf_kernel_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "tensor_forest_ops",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core/kernels/tensor_forest:tensor_forest_ops",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "dataset_ops",
|
name = "dataset_ops",
|
||||||
deps = [
|
deps = [
|
||||||
|
|||||||
@ -2,7 +2,10 @@
|
|||||||
# OpKernels for boosted trees ops.
|
# OpKernels for boosted trees ops.
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = ["//tensorflow:internal"],
|
default_visibility = [
|
||||||
|
"//tensorflow:__subpackages__",
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|||||||
@ -12,6 +12,7 @@ message Node {
|
|||||||
Leaf leaf = 1;
|
Leaf leaf = 1;
|
||||||
BucketizedSplit bucketized_split = 2;
|
BucketizedSplit bucketized_split = 2;
|
||||||
CategoricalSplit categorical_split = 3;
|
CategoricalSplit categorical_split = 3;
|
||||||
|
DenseSplit dense_split = 4;
|
||||||
}
|
}
|
||||||
NodeMetadata metadata = 777;
|
NodeMetadata metadata = 777;
|
||||||
}
|
}
|
||||||
@ -70,6 +71,19 @@ message CategoricalSplit {
|
|||||||
int32 right_id = 4;
|
int32 right_id = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(nponomareva): move out of boosted_trees and rename to trees.proto
|
||||||
|
message DenseSplit {
|
||||||
|
// Float feature column and split threshold describing
|
||||||
|
// the rule feature <= threshold.
|
||||||
|
int32 feature_id = 1;
|
||||||
|
float threshold = 2;
|
||||||
|
|
||||||
|
// Node children indexing into a contiguous
|
||||||
|
// vector of nodes starting from the root.
|
||||||
|
int32 left_id = 3;
|
||||||
|
int32 right_id = 4;
|
||||||
|
}
|
||||||
|
|
||||||
// Tree describes a list of connected nodes.
|
// Tree describes a list of connected nodes.
|
||||||
// Node 0 must be the root and can carry any payload including a leaf
|
// Node 0 must be the root and can carry any payload including a leaf
|
||||||
// in the case of representing the bias.
|
// in the case of representing the bias.
|
||||||
|
|||||||
53
tensorflow/core/kernels/tensor_forest/BUILD
Normal file
53
tensorflow/core/kernels/tensor_forest/BUILD
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# Description:
|
||||||
|
# OpKernels for tensor forest ops.
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//tensorflow:internal"],
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "resources",
|
||||||
|
srcs = ["resources.cc"],
|
||||||
|
hdrs = ["resources.h"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "resource_ops",
|
||||||
|
srcs = ["resource_ops.cc"],
|
||||||
|
deps = [
|
||||||
|
":resources",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:tensor_forest_ops_op_lib",
|
||||||
|
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "prediction_ops",
|
||||||
|
srcs = ["prediction_ops.cc"],
|
||||||
|
deps = [
|
||||||
|
":resources",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:tensor_forest_ops_op_lib",
|
||||||
|
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "tensor_forest_ops",
|
||||||
|
deps = [
|
||||||
|
":prediction_ops",
|
||||||
|
":resource_ops",
|
||||||
|
],
|
||||||
|
)
|
||||||
93
tensorflow/core/kernels/tensor_forest/prediction_ops.cc
Normal file
93
tensorflow/core/kernels/tensor_forest/prediction_ops.cc
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/kernels/tensor_forest/resources.h"
|
||||||
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class TensorForestTreePredictOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit TensorForestTreePredictOp(OpKernelConstruction* context)
|
||||||
|
: OpKernel(context) {
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->GetAttr("logits_dimension", &logits_dimension_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
TensorForestTreeResource* decision_tree_resource;
|
||||||
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
|
&decision_tree_resource));
|
||||||
|
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||||
|
core::ScopedUnref unref_me(decision_tree_resource);
|
||||||
|
|
||||||
|
const Tensor* dense_features_t = nullptr;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->input("dense_features", &dense_features_t));
|
||||||
|
|
||||||
|
auto dense_features = dense_features_t->matrix<float>();
|
||||||
|
const int32 batch_size = dense_features_t->dim_size(0);
|
||||||
|
|
||||||
|
Tensor* output_predictions = nullptr;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->allocate_output(0, {batch_size, logits_dimension_},
|
||||||
|
&output_predictions));
|
||||||
|
auto out = output_predictions->matrix<float>();
|
||||||
|
|
||||||
|
if (decision_tree_resource->get_size() <= 0) {
|
||||||
|
out.setZero();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
||||||
|
const int32 num_threads = worker_threads->num_threads;
|
||||||
|
|
||||||
|
// TODO(yupbank): This was from contrib version.
|
||||||
|
// This cost would probably depend on the depth of the tree we have.
|
||||||
|
// We will need to run it on a number of trees of diff depth
|
||||||
|
// and see the num of cpu cycles
|
||||||
|
const int64 cost_per_traverse = 500;
|
||||||
|
auto traverse = [this, &out, &dense_features, decision_tree_resource,
|
||||||
|
batch_size](int64 start, int64 end) {
|
||||||
|
DCHECK_LE(start, end) << "Start exceeding End";
|
||||||
|
DCHECK_LE(end, batch_size) << "End exceeding batch size";
|
||||||
|
for (int example_id = start; example_id < end; ++example_id) {
|
||||||
|
const int32 leaf_id =
|
||||||
|
decision_tree_resource->TraverseTree(example_id, &dense_features);
|
||||||
|
set_output_value(example_id, leaf_id, decision_tree_resource, &out);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Shard(num_threads, worker_threads->workers, batch_size, cost_per_traverse,
|
||||||
|
traverse);
|
||||||
|
};
|
||||||
|
|
||||||
|
void set_output_value(const int32 example_id, const int32 leaf_id,
|
||||||
|
const TensorForestTreeResource* decision_tree_resource,
|
||||||
|
TTypes<float>::Matrix* out) const {
|
||||||
|
for (int j = 0; j < logits_dimension_; ++j) {
|
||||||
|
const float logit = decision_tree_resource->get_prediction(leaf_id, j);
|
||||||
|
(*out)(example_id, j) = logit;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int32 logits_dimension_;
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("TensorForestTreePredict").Device(DEVICE_CPU),
|
||||||
|
TensorForestTreePredictOp);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
136
tensorflow/core/kernels/tensor_forest/resource_ops.cc
Normal file
136
tensorflow/core/kernels/tensor_forest/resource_ops.cc
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
|
||||||
|
#include "tensorflow/core/kernels/tensor_forest/resources.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class TensorForestCreateTreeVariableOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit TensorForestCreateTreeVariableOp(OpKernelConstruction* context)
|
||||||
|
: OpKernel(context) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
const Tensor* tree_config_t;
|
||||||
|
OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
|
||||||
|
|
||||||
|
auto* const result = new TensorForestTreeResource();
|
||||||
|
|
||||||
|
if (!result->InitFromSerialized(tree_config_t->scalar<string>()())) {
|
||||||
|
result->Unref();
|
||||||
|
OP_REQUIRES(context, false,
|
||||||
|
errors::InvalidArgument("Unable to parse tree config."));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only create one, if one does not exist already. Report status for all
|
||||||
|
// other exceptions.
|
||||||
|
auto status = CreateResource(context, HandleFromInput(context, 0), result);
|
||||||
|
if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
|
||||||
|
OP_REQUIRES(context, false, status);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Op for serializing a model.
|
||||||
|
class TensorForestTreeSerializeOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit TensorForestTreeSerializeOp(OpKernelConstruction* context)
|
||||||
|
: OpKernel(context) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
TensorForestTreeResource* decision_tree_resource;
|
||||||
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
|
&decision_tree_resource));
|
||||||
|
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||||
|
core::ScopedUnref unref_me(decision_tree_resource);
|
||||||
|
Tensor* output_config_t = nullptr;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
context, context->allocate_output(0, TensorShape(), &output_config_t));
|
||||||
|
output_config_t->scalar<string>()() =
|
||||||
|
decision_tree_resource->decision_tree().SerializeAsString();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Op for deserializing a tree variable from a checkpoint.
|
||||||
|
class TensorForestTreeDeserializeOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit TensorForestTreeDeserializeOp(OpKernelConstruction* context)
|
||||||
|
: OpKernel(context) {}
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
TensorForestTreeResource* decision_tree_resource;
|
||||||
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
|
&decision_tree_resource));
|
||||||
|
|
||||||
|
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||||
|
core::ScopedUnref unref_me(decision_tree_resource);
|
||||||
|
|
||||||
|
const Tensor* tree_config_t;
|
||||||
|
OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
|
||||||
|
|
||||||
|
// Deallocate all the previous objects on the resource.
|
||||||
|
decision_tree_resource->Reset();
|
||||||
|
|
||||||
|
if (!decision_tree_resource->InitFromSerialized(
|
||||||
|
tree_config_t->scalar<string>()())) {
|
||||||
|
OP_REQUIRES(context, false,
|
||||||
|
errors::InvalidArgument("Unable to parse tree config."));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Op for getting tree size.
|
||||||
|
class TensorForestTreeSizeOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit TensorForestTreeSizeOp(OpKernelConstruction* context)
|
||||||
|
: OpKernel(context) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
TensorForestTreeResource* decision_tree_resource;
|
||||||
|
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||||
|
&decision_tree_resource));
|
||||||
|
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||||
|
core::ScopedUnref unref_me(decision_tree_resource);
|
||||||
|
Tensor* output_t = nullptr;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->allocate_output(0, TensorShape(), &output_t));
|
||||||
|
output_t->scalar<int32>()() = decision_tree_resource->get_size();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_RESOURCE_HANDLE_KERNEL(TensorForestTreeResource);
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(
|
||||||
|
Name("TensorForestTreeIsInitializedOp").Device(DEVICE_CPU),
|
||||||
|
IsResourceInitialized<TensorForestTreeResource>);
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(
|
||||||
|
Name("TensorForestCreateTreeVariable").Device(DEVICE_CPU),
|
||||||
|
TensorForestCreateTreeVariableOp);
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("TensorForestTreeSerialize").Device(DEVICE_CPU),
|
||||||
|
TensorForestTreeSerializeOp);
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("TensorForestTreeDeserialize").Device(DEVICE_CPU),
|
||||||
|
TensorForestTreeDeserializeOp);
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("TensorForestTreeSize").Device(DEVICE_CPU),
|
||||||
|
TensorForestTreeSizeOp);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
71
tensorflow/core/kernels/tensor_forest/resources.cc
Normal file
71
tensorflow/core/kernels/tensor_forest/resources.cc
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/tensor_forest/resources.h"
|
||||||
|
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
|
||||||
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
const boosted_trees::Tree& TensorForestTreeResource::decision_tree() const {
|
||||||
|
return *decision_tree_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int32 TensorForestTreeResource::get_size() const {
|
||||||
|
return decision_tree_->nodes_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorForestTreeResource::TensorForestTreeResource()
|
||||||
|
: decision_tree_(
|
||||||
|
protobuf::Arena::CreateMessage<boosted_trees::Tree>(&arena_)) {}
|
||||||
|
|
||||||
|
const float TensorForestTreeResource::get_prediction(
|
||||||
|
const int32 id, const int32 dimension_id) const {
|
||||||
|
return decision_tree_->nodes(id).leaf().vector().value(dimension_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int32 TensorForestTreeResource::TraverseTree(
|
||||||
|
const int32 example_id,
|
||||||
|
const TTypes<float>::ConstMatrix* dense_data) const {
|
||||||
|
using boosted_trees::Node;
|
||||||
|
using boosted_trees::Tree;
|
||||||
|
int32 current_id = 0;
|
||||||
|
while (true) {
|
||||||
|
const Node& current = decision_tree_->nodes(current_id);
|
||||||
|
if (current.has_leaf()) {
|
||||||
|
return current_id;
|
||||||
|
}
|
||||||
|
DCHECK_EQ(current.node_case(), Node::kDenseSplit);
|
||||||
|
const auto& split = current.dense_split();
|
||||||
|
|
||||||
|
if ((*dense_data)(example_id, split.feature_id()) <= split.threshold()) {
|
||||||
|
current_id = split.left_id();
|
||||||
|
} else {
|
||||||
|
current_id = split.right_id();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TensorForestTreeResource::InitFromSerialized(const string& serialized) {
|
||||||
|
return ParseProtoUnlimited(decision_tree_, serialized);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TensorForestTreeResource::Reset() {
|
||||||
|
arena_.Reset();
|
||||||
|
DCHECK_EQ(0, arena_.SpaceAllocated());
|
||||||
|
decision_tree_ = protobuf::Arena::CreateMessage<boosted_trees::Tree>(&arena_);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
65
tensorflow/core/kernels/tensor_forest/resources.h
Normal file
65
tensorflow/core/kernels/tensor_forest/resources.h
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_FOREST_RESOURCES_H_
|
||||||
|
#define TENSORFLOW_CORE_KERNELS_TENSOR_FOREST_RESOURCES_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Forward declaration for proto class Tree.
|
||||||
|
namespace boosted_trees {
|
||||||
|
class Tree;
|
||||||
|
} // namespace boosted_trees
|
||||||
|
|
||||||
|
// Keep a tree ensemble in memory for efficient evaluation and mutation.
|
||||||
|
class TensorForestTreeResource : public ResourceBase {
|
||||||
|
public:
|
||||||
|
TensorForestTreeResource();
|
||||||
|
|
||||||
|
string DebugString() override {
|
||||||
|
return strings::StrCat("TensorForestTree[size=", get_size(), "]");
|
||||||
|
}
|
||||||
|
|
||||||
|
mutex* get_mutex() { return &mu_; }
|
||||||
|
|
||||||
|
bool InitFromSerialized(const string& serialized);
|
||||||
|
|
||||||
|
// Resets the resource and frees the proto.
|
||||||
|
// Caller needs to hold the mutex lock while calling this.
|
||||||
|
void Reset();
|
||||||
|
|
||||||
|
const int32 get_size() const;
|
||||||
|
|
||||||
|
const boosted_trees::Tree& decision_tree() const;
|
||||||
|
|
||||||
|
const float get_prediction(const int32 id, const int32 dimension_id) const;
|
||||||
|
|
||||||
|
const int32 TraverseTree(const int32 example_id,
|
||||||
|
const TTypes<float>::ConstMatrix* dense_data) const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
mutex mu_;
|
||||||
|
protobuf::Arena arena_;
|
||||||
|
boosted_trees::Tree* decision_tree_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
#endif // TENSORFLOW_CORE_KERNELS_TENSOR_FOREST_RESOURCES_H_
|
||||||
79
tensorflow/core/ops/tensor_forest_ops.cc
Normal file
79
tensorflow/core/ops/tensor_forest_ops.cc
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
REGISTER_RESOURCE_HANDLE_OP(TensorForestTreeResource);
|
||||||
|
|
||||||
|
REGISTER_OP("TensorForestTreeIsInitializedOp")
|
||||||
|
.Input("tree_handle: resource")
|
||||||
|
.Output("is_initialized: bool")
|
||||||
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
|
shape_inference::ShapeHandle unused_input;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
|
||||||
|
c->set_output(0, c->Scalar());
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("TensorForestCreateTreeVariable")
|
||||||
|
.Input("tree_handle: resource")
|
||||||
|
.Input("tree_config: string")
|
||||||
|
.SetShapeFn(tensorflow::shape_inference::NoOutputs);
|
||||||
|
|
||||||
|
REGISTER_OP("TensorForestTreeSerialize")
|
||||||
|
.Input("tree_handle: resource")
|
||||||
|
.Output("tree_config: string")
|
||||||
|
.SetShapeFn(tensorflow::shape_inference::ScalarShape);
|
||||||
|
|
||||||
|
REGISTER_OP("TensorForestTreeDeserialize")
|
||||||
|
.Input("tree_handle: resource")
|
||||||
|
.Input("tree_config: string")
|
||||||
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
|
shape_inference::ShapeHandle unused_input;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("TensorForestTreeSize")
|
||||||
|
.Input("tree_handle: resource")
|
||||||
|
.Output("tree_size: int32")
|
||||||
|
.SetShapeFn(tensorflow::shape_inference::ScalarShape);
|
||||||
|
|
||||||
|
REGISTER_OP("TensorForestTreePredict")
|
||||||
|
.Attr("logits_dimension: int")
|
||||||
|
.Input("tree_handle: resource")
|
||||||
|
.Input("dense_features: float")
|
||||||
|
.Output("logits: float")
|
||||||
|
.SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
|
||||||
|
shape_inference::ShapeHandle shape_handle;
|
||||||
|
shape_inference::DimensionHandle batch_size = c->UnknownDim();
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &shape_handle));
|
||||||
|
|
||||||
|
batch_size = c->Dim(shape_handle, 0);
|
||||||
|
|
||||||
|
int logits_dimension;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
|
||||||
|
c->set_output(0, c->Matrix(batch_size, logits_dimension));
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
} // namespace tensorflow
|
||||||
@ -131,6 +131,7 @@ py_library(
|
|||||||
":subscribe",
|
":subscribe",
|
||||||
":summary",
|
":summary",
|
||||||
":tensor_array_ops",
|
":tensor_array_ops",
|
||||||
|
":tensor_forest_ops",
|
||||||
":test_ops", # TODO: Break testing code out into separate rule.
|
":test_ops", # TODO: Break testing code out into separate rule.
|
||||||
":tf_cluster",
|
":tf_cluster",
|
||||||
":tf_item",
|
":tf_item",
|
||||||
@ -1622,6 +1623,14 @@ tf_gen_op_wrapper_private_py(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_gen_op_wrapper_private_py(
|
||||||
|
name = "tensor_forest_ops_gen",
|
||||||
|
visibility = ["//tensorflow:internal"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:tensor_forest_ops_op_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_gen_op_wrapper_private_py(
|
tf_gen_op_wrapper_private_py(
|
||||||
name = "summary_ops_gen",
|
name = "summary_ops_gen",
|
||||||
visibility = ["//tensorflow:__subpackages__"],
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
@ -1946,6 +1955,19 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "tensor_forest_ops",
|
||||||
|
srcs = ["ops/tensor_forest_ops.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":framework",
|
||||||
|
":ops",
|
||||||
|
":tensor_forest_ops_gen",
|
||||||
|
":training",
|
||||||
|
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "sets",
|
name = "sets",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|||||||
103
tensorflow/python/ops/tensor_forest_ops.py
Normal file
103
tensorflow/python/ops/tensor_forest_ops.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Ops for tensor_forest."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python import ops
|
||||||
|
from tensorflow.python.ops import gen_tensor_forest_ops
|
||||||
|
from tensorflow.python.ops import resources
|
||||||
|
from tensorflow.python.training import saver
|
||||||
|
|
||||||
|
|
||||||
|
class TreeVariableSaveable(saver.BaseSaverBuilder.SaveableObject):
|
||||||
|
"""Resource that holds a tree."""
|
||||||
|
|
||||||
|
def __init__(self, type_name, name, container, config, resource_handle_func,
|
||||||
|
create_op_func, is_initialized_op_func, serialize_op_func,
|
||||||
|
deserialize_op_func):
|
||||||
|
|
||||||
|
with ops.name_scope(name, type_name) as name:
|
||||||
|
self._resource_handle = resource_handle_func(
|
||||||
|
container, shared_name=name, name=name)
|
||||||
|
|
||||||
|
self._is_initialized_op = is_initialized_op_func(self._resource_handle)
|
||||||
|
tensor = serialize_op_func(self._resource_handle)
|
||||||
|
self._create_op = create_op_func(self._resource_handle, config)
|
||||||
|
# slice_spec is useful for saving a slice from a variable.
|
||||||
|
# It's not meaningful the tree variable. So we just pass an empty
|
||||||
|
# value.
|
||||||
|
slice_spec = ''
|
||||||
|
specs = [saver.BaseSaverBuilder.SaveSpec(tensor, slice_spec, name)]
|
||||||
|
super(TreeVariableSaveable, self).__init__(self._resource_handle, specs,
|
||||||
|
name)
|
||||||
|
|
||||||
|
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self)
|
||||||
|
|
||||||
|
resources.register_resource(self._resource_handle, self._create_op,
|
||||||
|
self._is_initialized_op)
|
||||||
|
self._deserialize_op_func = deserialize_op_func
|
||||||
|
|
||||||
|
def restore(self, restored_tensors, unused_restored_shapes):
|
||||||
|
"""Restores the associated tree from 'restored_tensors'.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
restored_tensors: the tensors that were loaded from a checkpoint.
|
||||||
|
unused_restored_shapes: the shapes this object should conform to after
|
||||||
|
restore. Not meaningful for trees.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The operation that restores the state of the tree variable.
|
||||||
|
"""
|
||||||
|
with ops.control_dependencies([self._create_op]):
|
||||||
|
return self._deserialize_op_func(
|
||||||
|
self._resource_handle,
|
||||||
|
restored_tensors[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def resource(self):
|
||||||
|
return self._resource_handle
|
||||||
|
|
||||||
|
|
||||||
|
def tree_variable(tree_config, name, container=None):
|
||||||
|
return TreeVariableSaveable(
|
||||||
|
'TreeVariable', name, container, tree_config,
|
||||||
|
gen_tensor_forest_ops.tensor_forest_tree_resource_handle_op,
|
||||||
|
gen_tensor_forest_ops.tensor_forest_create_tree_variable,
|
||||||
|
gen_tensor_forest_ops.tensor_forest_tree_is_initialized_op,
|
||||||
|
gen_tensor_forest_ops.tensor_forest_tree_serialize,
|
||||||
|
gen_tensor_forest_ops.tensor_forest_tree_deserialize).resource
|
||||||
|
|
||||||
|
|
||||||
|
class ForestVariables(object):
|
||||||
|
"""Resource that holds all trees from a forest."""
|
||||||
|
|
||||||
|
def __init__(self, params, tree_configs=None):
|
||||||
|
|
||||||
|
self._variables = []
|
||||||
|
|
||||||
|
for i in range(params.n_trees):
|
||||||
|
tree_config = ''
|
||||||
|
if tree_configs is not None:
|
||||||
|
tree_config = tree_configs[i]
|
||||||
|
self._variables.append(tree_variable(
|
||||||
|
tree_config,
|
||||||
|
'tree-%s' % i,
|
||||||
|
))
|
||||||
|
|
||||||
|
def __getitem__(self, t):
|
||||||
|
return self._variables[t]
|
||||||
Loading…
x
Reference in New Issue
Block a user