BoostedTrees support float feature columns.

PiperOrigin-RevId: 220567110
This commit is contained in:
Zhenyu Tan 2018-11-07 18:52:04 -08:00 committed by TensorFlower Gardener
parent ef6e793e21
commit bb878129e1
11 changed files with 274 additions and 43 deletions

View File

@ -4,7 +4,7 @@ op {
in_arg {
name: "float_values"
description: <<END
float; List of Rank 2 Tensor each containing float values for a single feature.
float; List of Rank 1 Tensor each containing float values for a single feature.
END
}
in_arg {
@ -17,7 +17,7 @@ END
out_arg {
name: "buckets"
description: <<END
int; List of Rank 2 Tensors each containing the bucketized values for a single feature.
int; List of Rank 1 Tensors each containing the bucketized values for a single feature.
END
}
attr {

View File

@ -4,7 +4,7 @@ op {
in_arg {
name: "float_values"
description: <<END
float; List of Rank 2 Tensors each containing values for a single feature.
float; List of Rank 1 Tensors each containing values for a single feature.
END
}
in_arg {
@ -22,8 +22,8 @@ END
out_arg {
name: "summaries"
description: <<END
float; List of Rank 2 Tensors each containing the quantile summary (value, weight,
min_rank, max_rank) of a single feature.
float; List of Rank 2 Tensors each containing the quantile summary
(value, weight, min_rank, max_rank) of a single feature.
END
}
attr {
@ -35,6 +35,7 @@ END
}
summary: "Makes the summary of quantiles for the batch."
description: <<END
An op that takes a list of tensors and outputs the quantile summaries for each tensor.
An op that takes a list of tensors (one tensor per feature) and outputs the
quantile summaries for each tensor.
END
}

View File

@ -0,0 +1,26 @@
op {
graph_op_name: "BoostedTreesQuantileStreamResourceDeserialize"
visibility: HIDDEN
in_arg {
name: "quantile_stream_resource_handle"
description: <<END
resource handle referring to a QuantileStreamResource.
END
}
in_arg {
name: "bucket_boundaries"
description: <<END
float; List of Rank 1 Tensors each containing the bucket boundaries for a feature.
END
}
attr {
name: "num_streams"
description: <<END
inferred int; number of features to get bucket boundaries for.
END
}
summary: "Deserialize bucket boundaries and ready flag into current QuantileAccumulator."
description: <<END
An op that deserializes bucket boundaries and are boundaries ready flag into current QuantileAccumulator.
END
}

View File

@ -29,6 +29,7 @@
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
@ -151,8 +152,14 @@ class BoostedTreesMakeQuantileSummariesOp : public OpKernel {
const Tensor* example_weights_t;
OP_REQUIRES_OK(context,
context->input(kExampleWeightsName, &example_weights_t));
DCHECK(float_features_list.size() > 0) << "Got empty feature list";
auto example_weights = example_weights_t->flat<float>();
const int64 batch_size = example_weights.size();
const int64 weight_size = example_weights.size();
const int64 batch_size = float_features_list[0].flat<float>().size();
OP_REQUIRES(
context, weight_size == 1 || weight_size == batch_size,
errors::InvalidArgument(strings::Printf(
"Weights should be a single value or same size as features.")));
const Tensor* epsilon_t;
OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
float epsilon = epsilon_t->scalar<float>()();
@ -168,7 +175,9 @@ class BoostedTreesMakeQuantileSummariesOp : public OpKernel {
QuantileStream stream(epsilon, batch_size + 1);
// Run quantile summary generation.
for (int64 j = 0; j < batch_size; j++) {
stream.PushEntry(feature_values(j), example_weights(j));
stream.PushEntry(feature_values(j), (weight_size > 1)
? example_weights(j)
: example_weights(0));
}
stream.Finalize();
const auto summary_entry_list = stream.GetFinalSummary().GetEntryList();
@ -263,6 +272,57 @@ REGISTER_KERNEL_BUILDER(
Name("BoostedTreesQuantileStreamResourceAddSummaries").Device(DEVICE_CPU),
BoostedTreesQuantileStreamResourceAddSummariesOp);
class BoostedTreesQuantileStreamResourceDeserializeOp : public OpKernel {
public:
explicit BoostedTreesQuantileStreamResourceDeserializeOp(
OpKernelConstruction* const context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr(kNumStreamsName, &num_features_));
}
void Compute(OpKernelContext* context) override {
QuantileStreamResource* streams_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&streams_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex());
core::ScopedUnref unref_me(streams_resource);
OpInputList bucket_boundaries_list;
OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
&bucket_boundaries_list));
auto do_quantile_deserialize = [&](const int64 begin, const int64 end) {
// Iterating over all streams.
for (int64 stream_idx = begin; stream_idx < end; stream_idx++) {
const Tensor& bucket_boundaries_t = bucket_boundaries_list[stream_idx];
const auto& bucket_boundaries = bucket_boundaries_t.vec<float>();
std::vector<float> result;
result.reserve(bucket_boundaries.size());
for (size_t i = 0; i < bucket_boundaries.size(); ++i) {
result.push_back(bucket_boundaries(i));
}
streams_resource->set_boundaries(result, stream_idx);
}
};
// TODO(tanzheny): comment on the magic number.
const int64 kCostPerUnit = 500 * num_features_;
const DeviceBase::CpuWorkerThreads& worker_threads =
*context->device()->tensorflow_cpu_worker_threads();
Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
kCostPerUnit, do_quantile_deserialize);
}
private:
int64 num_features_;
};
REGISTER_KERNEL_BUILDER(
Name("BoostedTreesQuantileStreamResourceDeserialize").Device(DEVICE_CPU),
BoostedTreesQuantileStreamResourceDeserializeOp);
class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel {
public:
explicit BoostedTreesQuantileStreamResourceFlushOp(
@ -409,28 +469,29 @@ class BoostedTreesBucketizeOp : public OpKernel {
const int64 num_values = values_tensor.dim_size(0);
Tensor* output_t = nullptr;
OP_REQUIRES_OK(
context, buckets_list.allocate(
feature_idx, TensorShape({num_values, 1}), &output_t));
auto output = output_t->matrix<int32>();
OP_REQUIRES_OK(context,
buckets_list.allocate(
feature_idx, TensorShape({num_values}), &output_t));
auto output = output_t->flat<int32>();
const std::vector<float>& bucket_boundaries_vector =
GetBuckets(feature_idx, bucket_boundaries_list);
CHECK(!bucket_boundaries_vector.empty())
<< "Got empty buckets for feature " << feature_idx;
auto flat_values = values_tensor.flat<float>();
const auto& iter_begin = bucket_boundaries_vector.begin();
const auto& iter_end = bucket_boundaries_vector.end();
for (int64 instance = 0; instance < num_values; instance++) {
if (iter_begin == iter_end) {
output(instance) = 0;
continue;
}
const float value = flat_values(instance);
auto bucket_iter =
std::lower_bound(bucket_boundaries_vector.begin(),
bucket_boundaries_vector.end(), value);
if (bucket_iter == bucket_boundaries_vector.end()) {
auto bucket_iter = std::lower_bound(iter_begin, iter_end, value);
if (bucket_iter == iter_end) {
--bucket_iter;
}
const int32 bucket = static_cast<int32>(
bucket_iter - bucket_boundaries_vector.begin());
const int32 bucket = static_cast<int32>(bucket_iter - iter_begin);
// Bucket id.
output(instance, 0) = bucket;
output(instance) = bucket;
}
}
};

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
@ -400,10 +401,7 @@ REGISTER_OP("BoostedTreesMakeQuantileSummaries")
for (int i = 0; i < num_features; ++i) {
ShapeHandle feature_shape;
DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape));
TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
c->Dim(example_weights_shape, 0),
&unused_dim));
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &feature_shape));
// the columns are value, weight, min_rank, max_rank.
c->set_output(i, c->MakeShape({c->UnknownDim(), 4}));
}
@ -431,6 +429,17 @@ REGISTER_OP("BoostedTreesQuantileStreamResourceAddSummaries")
return Status::OK();
});
REGISTER_OP("BoostedTreesQuantileStreamResourceDeserialize")
.Attr("num_streams: int")
.Input("quantile_stream_resource_handle: resource")
.Input("bucket_boundaries: num_streams * float")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused_input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
return Status::OK();
});
REGISTER_OP("BoostedTreesQuantileStreamResourceFlush")
.Attr("generate_quantiles: bool = False")
.Input("quantile_stream_resource_handle: resource")
@ -470,13 +479,13 @@ REGISTER_OP("BoostedTreesBucketize")
ShapeHandle feature_shape;
DimensionHandle unused_dim;
for (int i = 0; i < num_features; i++) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &feature_shape));
TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
c->Dim(c->input(0), 0), &unused_dim));
}
// Bucketized result should have same dimension as input.
for (int i = 0; i < num_features; i++) {
c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 1}));
c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0)}));
}
return Status::OK();
});

View File

@ -18,14 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import boosted_trees_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_handle_op as resource_handle_op
from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantile_stream_resource_initialized as resource_initialized
from tensorflow.python.platform import googletest
from tensorflow.python.training import saver
class QuantileOpsTest(test_util.TensorFlowTestCase):
@ -57,18 +64,16 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
| 5 | 1 | 2.2 | 0.8
"""
self._feature_0 = constant_op.constant(
[[1.2], [12.1], [0.3], [0.5], [0.6], [2.2]], dtype=dtypes.float32)
self._feature_1 = constant_op.constant(
[[2.3], [1.2], [1.1], [2.6], [3.2], [0.8]], dtype=dtypes.float32)
self._feature_0_boundaries = constant_op.constant(
[0.3, 0.6, 1.2, 12.1], dtype=dtypes.float32)
self._feature_1_boundaries = constant_op.constant(
[0.8, 1.2, 2.3, 3.2], dtype=dtypes.float32)
self._feature_0_quantiles = constant_op.constant(
[[2], [3], [0], [1], [1], [3]], dtype=dtypes.int32)
self._feature_1_quantiles = constant_op.constant(
[[2], [1], [1], [3], [3], [0]], dtype=dtypes.int32)
self._feature_0 = constant_op.constant([1.2, 12.1, 0.3, 0.5, 0.6, 2.2],
dtype=dtypes.float32)
self._feature_1 = constant_op.constant([2.3, 1.2, 1.1, 2.6, 3.2, 0.8],
dtype=dtypes.float32)
self._feature_0_boundaries = np.array([0.3, 0.6, 1.2, 12.1])
self._feature_1_boundaries = np.array([0.8, 1.2, 2.3, 3.2])
self._feature_0_quantiles = constant_op.constant([2, 3, 0, 1, 1, 3],
dtype=dtypes.int32)
self._feature_1_quantiles = constant_op.constant([2, 1, 1, 3, 3, 0],
dtype=dtypes.int32)
self._example_weights = constant_op.constant(
[10, 1, 1, 1, 1, 1], dtype=dtypes.float32)
@ -135,6 +140,69 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
def testSaveRestoreAfterFlush(self):
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
with self.test_session() as sess:
accumulator = boosted_trees_ops.QuantileAccumulator(
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
save = saver.Saver()
resources.initialize_resources(resources.shared_resources()).run()
buckets = accumulator.get_bucket_boundaries()
self.assertAllClose([], buckets[0].eval())
self.assertAllClose([], buckets[1].eval())
summaries = accumulator.add_summaries([self._feature_0, self._feature_1],
self._example_weights)
with ops.control_dependencies([summaries]):
flush = accumulator.flush()
sess.run(flush)
self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())
save.save(sess, save_path)
with self.test_session(graph=ops.Graph()) as sess:
accumulator = boosted_trees_ops.QuantileAccumulator(
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
save = saver.Saver()
save.restore(sess, save_path)
buckets = accumulator.get_bucket_boundaries()
self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())
def testSaveRestoreBeforeFlush(self):
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
with self.test_session() as sess:
accumulator = boosted_trees_ops.QuantileAccumulator(
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
save = saver.Saver()
resources.initialize_resources(resources.shared_resources()).run()
summaries = accumulator.add_summaries([self._feature_0, self._feature_1],
self._example_weights)
sess.run(summaries)
buckets = accumulator.get_bucket_boundaries()
self.assertAllClose([], buckets[0].eval())
self.assertAllClose([], buckets[1].eval())
save.save(sess, save_path)
sess.run(accumulator.flush())
self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())
with self.test_session(graph=ops.Graph()) as sess:
accumulator = boosted_trees_ops.QuantileAccumulator(
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
save = saver.Saver()
save.restore(sess, save_path)
buckets = accumulator.get_bucket_boundaries()
self.assertAllClose([], buckets[0].eval())
self.assertAllClose([], buckets[1].eval())
if __name__ == "__main__":
googletest.main()

View File

@ -33,10 +33,13 @@ from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_quant
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_add_summaries as quantile_add_summaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_deserialize as quantile_resource_deserialize
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_flush as quantile_flush
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_get_bucket_boundaries as get_bucket_boundaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_handle_op as quantile_resource_handle_op
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble
from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantile_stream_resource_initialized as is_quantile_resource_initialized
# pylint: enable=unused-import
from tensorflow.python.training import saver
@ -58,6 +61,69 @@ class PruningMode(object):
sorted(cls._map))))
class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for QuantileAccumulator.
The bucket boundaries are serialized and deserialized from checkpointing.
"""
def __init__(self,
epsilon,
num_streams,
num_quantiles,
name=None,
max_elements=None):
with ops.name_scope(name, 'QuantileAccumulator') as name:
self._eps = epsilon
self._num_streams = num_streams
self._num_quantiles = num_quantiles
self._resource_handle = quantile_resource_handle_op(
container='', shared_name=name, name=name)
self._create_op = create_quantile_stream_resource(self._resource_handle,
epsilon, num_streams)
is_initialized_op = is_quantile_resource_initialized(
self._resource_handle)
resources.register_resource(self._resource_handle, self._create_op,
is_initialized_op)
self._make_saveable(name)
def _make_saveable(self, name):
bucket_boundaries = get_bucket_boundaries(self._resource_handle,
self._num_streams)
slice_spec = ''
specs = []
for i in range(self._num_streams):
specs.append(
saver.BaseSaverBuilder.SaveSpec(
bucket_boundaries[i], slice_spec,
name + '_bucket_boundaries_' + str(i)))
super(QuantileAccumulator, self).__init__(self._resource_handle, specs,
name)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self)
def restore(self, restored_tensors, unused_tensor_shapes):
bucket_boundaries = restored_tensors
with ops.control_dependencies([self._create_op]):
return quantile_resource_deserialize(
self._resource_handle, bucket_boundaries=bucket_boundaries)
def add_summaries(self, float_columns, example_weights):
summaries = make_quantile_summaries(float_columns, example_weights,
self._eps)
summary_op = quantile_add_summaries(self._resource_handle, summaries)
return summary_op
def flush(self):
return quantile_flush(self._resource_handle, self._num_quantiles)
def get_bucket_boundaries(self):
return get_bucket_boundaries(self._resource_handle, self._num_streams)
@property
def resource(self):
return self._resource_handle
class _TreeEnsembleSavable(saver.BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for TreeEnsemble."""

View File

@ -22,7 +22,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\', \'quantile_sketch_epsilon\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\', \'0.01\'], "
}
member_method {
name: "eval_dir"

View File

@ -22,7 +22,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\', \'quantile_sketch_epsilon\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\', \'0.01\'], "
}
member_method {
name: "eval_dir"

View File

@ -22,7 +22,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\', \'quantile_sketch_epsilon\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\', \'0.01\'], "
}
member_method {
name: "eval_dir"

View File

@ -22,7 +22,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\', \'quantile_sketch_epsilon\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\', \'0.01\'], "
}
member_method {
name: "eval_dir"