Merge pull request from yupbank:add-tensor-forest-classification-inference

PiperOrigin-RevId: 222128700
TensorFlower Gardener 2018-11-19 13:39:31 -08:00
commit 68900419ba
19 changed files with 769 additions and 1 deletion

View File

@ -1052,6 +1052,7 @@ tf_gen_op_libs(
"batch_ops",
"bitwise_ops",
"boosted_trees_ops",
"tensor_forest_ops",
"candidate_sampling_ops",
"checkpoint_ops",
"collective_ops",
@ -1201,6 +1202,7 @@ cc_library(
":batch_ops_op_lib",
":bitwise_ops_op_lib",
":boosted_trees_ops_op_lib",
":tensor_forest_ops_op_lib",
":candidate_sampling_ops_op_lib",
":checkpoint_ops_op_lib",
":collective_ops_op_lib",
@ -1354,6 +1356,7 @@ cc_library(
"//tensorflow/core/kernels:batch_kernels",
"//tensorflow/core/kernels:bincount_op",
"//tensorflow/core/kernels:boosted_trees_ops",
"//tensorflow/core/kernels:tensor_forest_ops",
"//tensorflow/core/kernels:candidate_sampler_ops",
"//tensorflow/core/kernels:checkpoint_ops",
"//tensorflow/core/kernels:collective_ops",

View File

@ -0,0 +1,17 @@
op {
graph_op_name: "TensorForestCreateTreeVariable"
visibility: HIDDEN
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree resource to be created.
END
}
in_arg {
name: "tree_config"
description: <<END
Serialized proto string of the boosted_trees.Tree.
END
}
summary: "Creates a tree resource and returns a handle to it."
}

View File

@ -0,0 +1,17 @@
op {
graph_op_name: "TensorForestTreeDeserialize"
visibility: HIDDEN
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree resource to be restored.
END
}
in_arg {
name: "tree_config"
description: <<END
Serialized proto string of the boosted_trees.Tree proto.
END
}
summary: "Deserializes a proto into the tree handle."
}

View File

@ -0,0 +1,17 @@
op {
graph_op_name: "TensorForestTreeIsInitializedOp"
visibility: HIDDEN
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree.
END
}
out_arg {
name: "is_initialized"
description: <<END
Whether the tree is initialized.
END
}
summary: "Checks whether a tree has been initialized."
}

View File

@ -0,0 +1,29 @@
op {
graph_op_name: "TensorForestTreePredict"
visibility: HIDDEN
attr {
name: "logits_dimension"
description: <<END
Scalar, dimension of the logits.
END
}
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree resource.
END
}
in_arg {
name: "dense_features"
description: <<END
Rank 2 dense features tensor.
END
}
out_arg {
name: "logits"
description: <<END
The logits predictions from the tree for each instance in the batch.
END
}
summary: "Outputs the logits for the given input data."
}

View File

@ -0,0 +1,5 @@
op {
graph_op_name: "TensorForestTreeResourceHandleOp"
visibility: HIDDEN
summary: "Creates a handle to a TensorForestTreeResource."
}

View File

@ -0,0 +1,17 @@
op {
graph_op_name: "TensorForestTreeSerialize"
visibility: HIDDEN
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree resource to be serialized.
END
}
out_arg {
name: "tree_config"
description: <<END
Serialized proto string of the tree resource.
END
}
summary: "Serializes the tree handle to a proto."
}

View File

@ -0,0 +1,17 @@
op {
graph_op_name: "TensorForestTreeSize"
visibility: HIDDEN
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree resource.
END
}
out_arg {
name: "tree_size"
description: <<END
The size of the tree.
END
}
summary: "Gets the number of nodes in a tree."
}

View File

@ -6735,6 +6735,13 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "tensor_forest_ops",
deps = [
"//tensorflow/core/kernels/tensor_forest:tensor_forest_ops",
],
)
tf_kernel_library(
name = "dataset_ops",
deps = [

View File

@ -2,7 +2,10 @@
# OpKernels for boosted trees ops.
package(
default_visibility = ["//tensorflow:internal"],
default_visibility = [
"//tensorflow:__subpackages__",
"//tensorflow:internal",
],
)
licenses(["notice"]) # Apache 2.0

View File

@ -12,6 +12,7 @@ message Node {
Leaf leaf = 1;
BucketizedSplit bucketized_split = 2;
CategoricalSplit categorical_split = 3;
DenseSplit dense_split = 4;
}
NodeMetadata metadata = 777;
}
@ -70,6 +71,19 @@ message CategoricalSplit {
int32 right_id = 4;
}
// TODO(nponomareva): move out of boosted_trees and rename to trees.proto
message DenseSplit {
// Float feature column and split threshold describing
// the rule feature <= threshold.
int32 feature_id = 1;
float threshold = 2;
// Node children indexing into a contiguous
// vector of nodes starting from the root.
int32 left_id = 3;
int32 right_id = 4;
}
// Tree describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
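
For reference, a minimal sketch of building a one-split tree config that uses the new DenseSplit node from Python; the boosted_trees_pb2 module path is an assumption based on the proto's location, and any serialized boosted_trees.Tree string works the same way with the new ops.

# Hypothetical sketch; the generated module path is assumed from the proto's
# location under tensorflow/core/kernels/boosted_trees/.
from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2

tree = boosted_trees_pb2.Tree()

# Node 0 is the root: route left when feature 0 <= 0.5, right otherwise.
root = tree.nodes.add()
root.dense_split.feature_id = 0
root.dense_split.threshold = 0.5
root.dense_split.left_id = 1
root.dense_split.right_id = 2

# Two leaves; each leaf vector carries `logits_dimension` values.
tree.nodes.add().leaf.vector.value.append(-1.0)
tree.nodes.add().leaf.vector.value.append(1.0)

# The new kernels consume this serialized string as `tree_config`.
tree_config = tree.SerializeToString()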

View File

@ -0,0 +1,53 @@
# Description:
# OpKernels for tensor forest ops.
package(
default_visibility = ["//tensorflow:internal"],
)
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
cc_library(
name = "resources",
srcs = ["resources.cc"],
hdrs = ["resources.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
],
)
tf_kernel_library(
name = "resource_ops",
srcs = ["resource_ops.cc"],
deps = [
":resources",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:tensor_forest_ops_op_lib",
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
],
)
tf_kernel_library(
name = "prediction_ops",
srcs = ["prediction_ops.cc"],
deps = [
":resources",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:tensor_forest_ops_op_lib",
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
],
)
tf_kernel_library(
name = "tensor_forest_ops",
deps = [
":prediction_ops",
":resource_ops",
],
)

View File

@ -0,0 +1,93 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/tensor_forest/resources.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
class TensorForestTreePredictOp : public OpKernel {
public:
explicit TensorForestTreePredictOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context,
context->GetAttr("logits_dimension", &logits_dimension_));
}
void Compute(OpKernelContext* context) override {
TensorForestTreeResource* decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const Tensor* dense_features_t = nullptr;
OP_REQUIRES_OK(context,
context->input("dense_features", &dense_features_t));
auto dense_features = dense_features_t->matrix<float>();
const int32 batch_size = dense_features_t->dim_size(0);
Tensor* output_predictions = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, {batch_size, logits_dimension_},
&output_predictions));
auto out = output_predictions->matrix<float>();
if (decision_tree_resource->get_size() <= 0) {
out.setZero();
return;
}
auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
const int32 num_threads = worker_threads->num_threads;
// TODO(yupbank): This value was carried over from the contrib version.
// The cost likely depends on the depth of the tree; we still need to
// benchmark trees of varying depths and measure the CPU cycles per traverse.
const int64 cost_per_traverse = 500;
auto traverse = [this, &out, &dense_features, decision_tree_resource,
batch_size](int64 start, int64 end) {
DCHECK_LE(start, end) << "Start exceeding End";
DCHECK_LE(end, batch_size) << "End exceeding batch size";
for (int example_id = start; example_id < end; ++example_id) {
const int32 leaf_id =
decision_tree_resource->TraverseTree(example_id, &dense_features);
set_output_value(example_id, leaf_id, decision_tree_resource, &out);
}
};
Shard(num_threads, worker_threads->workers, batch_size, cost_per_traverse,
traverse);
};
void set_output_value(const int32 example_id, const int32 leaf_id,
const TensorForestTreeResource* decision_tree_resource,
TTypes<float>::Matrix* out) const {
for (int j = 0; j < logits_dimension_; ++j) {
const float logit = decision_tree_resource->get_prediction(leaf_id, j);
(*out)(example_id, j) = logit;
}
}
private:
int32 logits_dimension_;
};
REGISTER_KERNEL_BUILDER(Name("TensorForestTreePredict").Device(DEVICE_CPU),
TensorForestTreePredictOp);
} // namespace tensorflow

View File

@ -0,0 +1,136 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
#include "tensorflow/core/kernels/tensor_forest/resources.h"
namespace tensorflow {
class TensorForestCreateTreeVariableOp : public OpKernel {
public:
explicit TensorForestCreateTreeVariableOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor* tree_config_t;
OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
auto* const result = new TensorForestTreeResource();
if (!result->InitFromSerialized(tree_config_t->scalar<string>()())) {
result->Unref();
OP_REQUIRES(context, false,
errors::InvalidArgument("Unable to parse tree config."));
}
// Only create the resource if one does not already exist. Report the status
// for any other error.
auto status = CreateResource(context, HandleFromInput(context, 0), result);
if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
OP_REQUIRES(context, false, status);
}
}
};
// Op for serializing a model.
class TensorForestTreeSerializeOp : public OpKernel {
public:
explicit TensorForestTreeSerializeOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
TensorForestTreeResource* decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
Tensor* output_config_t = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(0, TensorShape(), &output_config_t));
output_config_t->scalar<string>()() =
decision_tree_resource->decision_tree().SerializeAsString();
}
};
// Op for deserializing a tree variable from a checkpoint.
class TensorForestTreeDeserializeOp : public OpKernel {
public:
explicit TensorForestTreeDeserializeOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
TensorForestTreeResource* decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const Tensor* tree_config_t;
OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
// Deallocate all the previous objects on the resource.
decision_tree_resource->Reset();
if (!decision_tree_resource->InitFromSerialized(
tree_config_t->scalar<string>()())) {
OP_REQUIRES(context, false,
errors::InvalidArgument("Unable to parse tree config."));
}
}
};
// Op for getting tree size.
class TensorForestTreeSizeOp : public OpKernel {
public:
explicit TensorForestTreeSizeOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
TensorForestTreeResource* decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_resource));
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
Tensor* output_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape(), &output_t));
output_t->scalar<int32>()() = decision_tree_resource->get_size();
}
};
REGISTER_RESOURCE_HANDLE_KERNEL(TensorForestTreeResource);
REGISTER_KERNEL_BUILDER(
Name("TensorForestTreeIsInitializedOp").Device(DEVICE_CPU),
IsResourceInitialized<TensorForestTreeResource>);
REGISTER_KERNEL_BUILDER(
Name("TensorForestCreateTreeVariable").Device(DEVICE_CPU),
TensorForestCreateTreeVariableOp);
REGISTER_KERNEL_BUILDER(Name("TensorForestTreeSerialize").Device(DEVICE_CPU),
TensorForestTreeSerializeOp);
REGISTER_KERNEL_BUILDER(Name("TensorForestTreeDeserialize").Device(DEVICE_CPU),
TensorForestTreeDeserializeOp);
REGISTER_KERNEL_BUILDER(Name("TensorForestTreeSize").Device(DEVICE_CPU),
TensorForestTreeSizeOp);
} // namespace tensorflow
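
A rough sketch of the create/serialize/deserialize cycle these kernels implement, driven through the generated Python wrappers that tensorflow/python/ops/tensor_forest_ops.py (later in this change) relies on. An empty tree_config parses to an empty tree, so substitute a real serialized boosted_trees.Tree for meaningful output.

# Sketch only; the wrapper names below match those used in
# tensorflow/python/ops/tensor_forest_ops.py.
import tensorflow as tf
from tensorflow.python.ops import gen_tensor_forest_ops

tree_config = b''  # or a serialized boosted_trees.Tree
handle = gen_tensor_forest_ops.tensor_forest_tree_resource_handle_op(
    container='', shared_name='tree-0')
create = gen_tensor_forest_ops.tensor_forest_create_tree_variable(
    handle, tree_config)
serialized = gen_tensor_forest_ops.tensor_forest_tree_serialize(handle)
restore = gen_tensor_forest_ops.tensor_forest_tree_deserialize(
    handle, serialized)

with tf.Session() as sess:
  sess.run(create)                      # TensorForestCreateTreeVariable
  config_string = sess.run(serialized)  # TensorForestTreeSerialize
  sess.run(restore)                     # TensorForestTreeDeserialize round trip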

View File

@ -0,0 +1,71 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/tensor_forest/resources.h"
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
const boosted_trees::Tree& TensorForestTreeResource::decision_tree() const {
return *decision_tree_;
}
const int32 TensorForestTreeResource::get_size() const {
return decision_tree_->nodes_size();
}
TensorForestTreeResource::TensorForestTreeResource()
: decision_tree_(
protobuf::Arena::CreateMessage<boosted_trees::Tree>(&arena_)) {}
const float TensorForestTreeResource::get_prediction(
const int32 id, const int32 dimension_id) const {
return decision_tree_->nodes(id).leaf().vector().value(dimension_id);
}
const int32 TensorForestTreeResource::TraverseTree(
const int32 example_id,
const TTypes<float>::ConstMatrix* dense_data) const {
using boosted_trees::Node;
using boosted_trees::Tree;
int32 current_id = 0;
while (true) {
const Node& current = decision_tree_->nodes(current_id);
if (current.has_leaf()) {
return current_id;
}
DCHECK_EQ(current.node_case(), Node::kDenseSplit);
const auto& split = current.dense_split();
if ((*dense_data)(example_id, split.feature_id()) <= split.threshold()) {
current_id = split.left_id();
} else {
current_id = split.right_id();
}
}
}
bool TensorForestTreeResource::InitFromSerialized(const string& serialized) {
return ParseProtoUnlimited(decision_tree_, serialized);
}
void TensorForestTreeResource::Reset() {
arena_.Reset();
DCHECK_EQ(0, arena_.SpaceAllocated());
decision_tree_ = protobuf::Arena::CreateMessage<boosted_trees::Tree>(&arena_);
}
} // namespace tensorflow
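
For readers following the traversal, here is an illustrative pure-Python rendering of TraverseTree and get_prediction over a parsed boosted_trees.Tree. It mirrors the logic above and is not code from the change.

def traverse_tree(tree, example):
  """Returns the leaf node index reached by `example`, a 1-D feature list."""
  node_id = 0
  while True:
    node = tree.nodes[node_id]
    if node.HasField('leaf'):
      return node_id
    split = node.dense_split
    # Route left when feature <= threshold, right otherwise.
    if example[split.feature_id] <= split.threshold:
      node_id = split.left_id
    else:
      node_id = split.right_id


def get_prediction(tree, leaf_id, logits_dimension):
  """Reads the per-dimension logits stored in a leaf's value vector."""
  leaf = tree.nodes[leaf_id].leaf
  return [leaf.vector.value[j] for j in range(logits_dimension)]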

View File

@ -0,0 +1,65 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_FOREST_RESOURCES_H_
#define TENSORFLOW_CORE_KERNELS_TENSOR_FOREST_RESOURCES_H_
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
// Forward declaration for proto class Tree.
namespace boosted_trees {
class Tree;
} // namespace boosted_trees
// Keep a tree ensemble in memory for efficient evaluation and mutation.
class TensorForestTreeResource : public ResourceBase {
public:
TensorForestTreeResource();
string DebugString() override {
return strings::StrCat("TensorForestTree[size=", get_size(), "]");
}
mutex* get_mutex() { return &mu_; }
bool InitFromSerialized(const string& serialized);
// Resets the resource and frees the proto.
// Caller needs to hold the mutex lock while calling this.
void Reset();
const int32 get_size() const;
const boosted_trees::Tree& decision_tree() const;
const float get_prediction(const int32 id, const int32 dimension_id) const;
const int32 TraverseTree(const int32 example_id,
const TTypes<float>::ConstMatrix* dense_data) const;
protected:
mutex mu_;
protobuf::Arena arena_;
boosted_trees::Tree* decision_tree_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_TENSOR_FOREST_RESOURCES_H_

View File

@ -0,0 +1,79 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
REGISTER_RESOURCE_HANDLE_OP(TensorForestTreeResource);
REGISTER_OP("TensorForestTreeIsInitializedOp")
.Input("tree_handle: resource")
.Output("is_initialized: bool")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused_input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
c->set_output(0, c->Scalar());
return Status::OK();
});
REGISTER_OP("TensorForestCreateTreeVariable")
.Input("tree_handle: resource")
.Input("tree_config: string")
.SetShapeFn(tensorflow::shape_inference::NoOutputs);
REGISTER_OP("TensorForestTreeSerialize")
.Input("tree_handle: resource")
.Output("tree_config: string")
.SetShapeFn(tensorflow::shape_inference::ScalarShape);
REGISTER_OP("TensorForestTreeDeserialize")
.Input("tree_handle: resource")
.Input("tree_config: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused_input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
return Status::OK();
});
REGISTER_OP("TensorForestTreeSize")
.Input("tree_handle: resource")
.Output("tree_size: int32")
.SetShapeFn(tensorflow::shape_inference::ScalarShape);
REGISTER_OP("TensorForestTreePredict")
.Attr("logits_dimension: int")
.Input("tree_handle: resource")
.Input("dense_features: float")
.Output("logits: float")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle shape_handle;
shape_inference::DimensionHandle batch_size = c->UnknownDim();
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &shape_handle));
batch_size = c->Dim(shape_handle, 0);
int logits_dimension;
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
c->set_output(0, c->Matrix(batch_size, logits_dimension));
return Status::OK();
});
} // namespace tensorflow
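
The predict shape function above requires a rank-2 dense_features input and yields [batch_size, logits_dimension] logits. A hedged usage sketch through the generated wrappers follows; tensor_forest_tree_predict is the assumed snake_case wrapper name for TensorForestTreePredict, and with an empty tree config the kernel returns all-zero logits.

# Sketch; `tensor_forest_tree_predict` is the assumed generated wrapper name.
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gen_tensor_forest_ops

handle = gen_tensor_forest_ops.tensor_forest_tree_resource_handle_op(
    container='', shared_name='tree-predict-demo')
create = gen_tensor_forest_ops.tensor_forest_create_tree_variable(
    handle, b'')  # empty config -> empty tree -> zero logits
features = tf.constant(np.array([[0.2], [0.9]], dtype=np.float32))
logits = gen_tensor_forest_ops.tensor_forest_tree_predict(
    handle, features, logits_dimension=1)

with tf.Session() as sess:
  sess.run(create)
  print(sess.run(logits))  # shape [2, 1], one row of logits per example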

View File

@ -131,6 +131,7 @@ py_library(
":subscribe",
":summary",
":tensor_array_ops",
":tensor_forest_ops",
":test_ops", # TODO: Break testing code out into separate rule.
":tf_cluster",
":tf_item",
@ -1622,6 +1623,14 @@ tf_gen_op_wrapper_private_py(
],
)
tf_gen_op_wrapper_private_py(
name = "tensor_forest_ops_gen",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/core:tensor_forest_ops_op_lib",
],
)
tf_gen_op_wrapper_private_py(
name = "summary_ops_gen",
visibility = ["//tensorflow:__subpackages__"],
@ -1946,6 +1955,19 @@ py_library(
],
)
py_library(
name = "tensor_forest_ops",
srcs = ["ops/tensor_forest_ops.py"],
srcs_version = "PY2AND3",
deps = [
":framework",
":ops",
":tensor_forest_ops_gen",
":training",
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
],
)
py_library(
name = "sets",
srcs = [

View File

@ -0,0 +1,103 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for tensor_forest."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_tensor_forest_ops
from tensorflow.python.ops import resources
from tensorflow.python.training import saver
class TreeVariableSaveable(saver.BaseSaverBuilder.SaveableObject):
"""Resource that holds a tree."""
def __init__(self, type_name, name, container, config, resource_handle_func,
create_op_func, is_initialized_op_func, serialize_op_func,
deserialize_op_func):
with ops.name_scope(name, type_name) as name:
self._resource_handle = resource_handle_func(
container, shared_name=name, name=name)
self._is_initialized_op = is_initialized_op_func(self._resource_handle)
tensor = serialize_op_func(self._resource_handle)
self._create_op = create_op_func(self._resource_handle, config)
# slice_spec is useful for saving a slice from a variable.
# It's not meaningful for the tree variable, so we just pass an empty
# value.
slice_spec = ''
specs = [saver.BaseSaverBuilder.SaveSpec(tensor, slice_spec, name)]
super(TreeVariableSaveable, self).__init__(self._resource_handle, specs,
name)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self)
resources.register_resource(self._resource_handle, self._create_op,
self._is_initialized_op)
self._deserialize_op_func = deserialize_op_func
def restore(self, restored_tensors, unused_restored_shapes):
"""Restores the associated tree from 'restored_tensors'.
Args:
restored_tensors: the tensors that were loaded from a checkpoint.
unused_restored_shapes: the shapes this object should conform to after
restore. Not meaningful for trees.
Returns:
The operation that restores the state of the tree variable.
"""
with ops.control_dependencies([self._create_op]):
return self._deserialize_op_func(
self._resource_handle,
restored_tensors[0],
)
@property
def resource(self):
return self._resource_handle
def tree_variable(tree_config, name, container=None):
return TreeVariableSaveable(
'TreeVariable', name, container, tree_config,
gen_tensor_forest_ops.tensor_forest_tree_resource_handle_op,
gen_tensor_forest_ops.tensor_forest_create_tree_variable,
gen_tensor_forest_ops.tensor_forest_tree_is_initialized_op,
gen_tensor_forest_ops.tensor_forest_tree_serialize,
gen_tensor_forest_ops.tensor_forest_tree_deserialize).resource
class ForestVariables(object):
"""Resource that holds all trees from a forest."""
def __init__(self, params, tree_configs=None):
self._variables = []
for i in range(params.n_trees):
tree_config = ''
if tree_configs is not None:
tree_config = tree_configs[i]
self._variables.append(tree_variable(
tree_config,
'tree-%s' % i,
))
def __getitem__(self, t):
return self._variables[t]
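
A minimal usage sketch for the classes above. ForestParams is a hypothetical stand-in for whatever params object supplies n_trees (the only attribute ForestVariables reads), and initialization goes through the shared-resources collection that register_resource populates.

# Sketch; `ForestParams` is a hypothetical params object with an `n_trees`
# attribute.
import collections
import tensorflow as tf
from tensorflow.python.ops import resources
from tensorflow.python.ops import tensor_forest_ops

ForestParams = collections.namedtuple('ForestParams', ['n_trees'])

forest = tensor_forest_ops.ForestVariables(ForestParams(n_trees=2))
tree0_handle = forest[0]  # resource handle tensor for the first tree

with tf.Session() as sess:
  # register_resource above adds each tree's create op to the shared
  # resources collection, so this initializes every tree in the forest.
  sess.run(resources.initialize_resources(resources.shared_resources()))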