Add shape check for MakeQuantileSummariesOp.

PiperOrigin-RevId: 161698801
This commit is contained in:
A. Unique TensorFlower 2017-07-12 12:18:38 -07:00 committed by TensorFlower Gardener
parent 2195db6d86
commit 9aa0dcbf28
3 changed files with 97 additions and 66 deletions
tensorflow/contrib/boosted_trees

View File

@ -382,6 +382,11 @@ class MakeQuantileSummariesOp : public OpKernel {
sparse_float_feature_values_list[sparse_index].flat<float>();
const auto sparse_indices =
sparse_float_feature_indices_list[sparse_index].matrix<int64>();
const auto dense_shape =
sparse_float_feature_shapes_list[sparse_index].flat<int64>();
OP_REQUIRES(context, batch_size == dense_shape(0),
errors::InvalidArgument(
"Sparse column shape doesn't match the batch size."));
QuantileStream stream(epsilon_, batch_size + 1);
// Run quantile summary generation.
const int64 num_sparse_rows =

View File

@ -20,6 +20,7 @@
namespace tensorflow {
namespace gtflow {
using shape_inference::InferenceContext;
using shape_inference::DimensionHandle;
using shape_inference::ShapeHandle;
REGISTER_RESOURCE_HANDLE_OP(QuantileStreamResource);
@ -172,7 +173,17 @@ REGISTER_OP("MakeQuantileSummaries")
int num_sparse_features;
TF_RETURN_IF_ERROR(
c->GetAttr("num_sparse_features", &num_sparse_features));
ShapeHandle example_weights_shape;
int example_weights_index = num_dense_features + num_sparse_features * 3;
TF_RETURN_IF_ERROR(c->WithRank(c->input(example_weights_index), 2,
&example_weights_shape));
for (int i = 0; i < num_dense_features; ++i) {
ShapeHandle dense_feature_shape;
DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &dense_feature_shape));
TF_RETURN_IF_ERROR(c->Merge(c->Dim(dense_feature_shape, 0),
c->Dim(example_weights_shape, 0),
&unused_dim));
c->set_output(i, c->Scalar());
}
for (int i = 0; i < num_sparse_features; ++i) {
@ -193,7 +204,8 @@ sparse_float_feature_values: List of rank 1 tensors containing the sparse float
feature values.
sparse_float_feature_shapes: List of rank 1 tensors containing the shape of the
float feature.
example_weights: Rank 1 tensor containing the example weight tensor.
example_weights: Rank 2 (N, 1) tensor of per-example weights. Should match
dense and sparse features shape.
dense_summaries: A list of serialized QuantileSummaryState for dense columns.
sparse_summaries: A list of serialized QuantileSummaryState for sparse columns.
)doc");

View File

@ -26,9 +26,12 @@ import numpy as np
from tensorflow.contrib.boosted_trees.proto.quantiles_pb2 import QuantileConfig
from tensorflow.contrib.boosted_trees.python.ops import quantile_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resources
from tensorflow.python.platform import googletest
from tensorflow.python.training import saver
@ -56,12 +59,15 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
| 5 | 1 | 5 | 6
"""
dense_float_tensor_0 = np.array([1, 2, 3, 4, 4, 5])
sparse_indices_0 = np.array(
[[1, 0], [2, 0], [3, 0], [4, 0], [5, 0]], dtype=np.int64)
sparse_values_0 = np.array([2, 3, 4, 5, 6])
sparse_shape_0 = np.array([6, 1])
example_weights = np.array([10, 1, 1, 1, 1, 1])
dense_float_tensor_0 = constant_op.constant(
[1, 2, 3, 4, 4, 5], dtype=dtypes.float32)
sparse_indices_0 = constant_op.constant(
[[1, 0], [2, 0], [3, 0], [4, 0], [5, 0]], dtype=dtypes.int64)
sparse_values_0 = constant_op.constant(
[2, 3, 4, 5, 6], dtype=dtypes.float32)
sparse_shape_0 = constant_op.constant([6, 1], dtype=dtypes.int64)
example_weights = constant_op.constant(
[10, 1, 1, 1, 1, 1], dtype=dtypes.float32)
with self.test_session():
config = self._gen_config(0.33, 3)
@ -78,40 +84,38 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
def testStreamingQuantileBuckets(self):
"""Sets up the quantile summary op test as follows.
Create a batch of 6 examples having a dense and sparse features.
The data looks like this
| Instance | instance weights | Dense 0
| 0 | 10 | 1
| 1 | 1 | 2
| 2 | 1 | 3
| 3 | 1 | 4
| 4 | 1 | 4
| 5 | 1 | 5
100 batches of data are added to the accumulator. The batches are in the form:
[0 1 .. 99]
[100 101 .. 199]
...
[9900 9901 .. 9999]
All the batches have 1 for all the example weights.
"""
dense_float_tensor_0 = np.array([1, 2, 3, 4, 4, 5])
example_weights = np.array([10, 1, 1, 1, 1, 1])
with self.test_session() as sess:
accumulator = quantile_ops.QuantileAccumulator(
init_stamp_token=0, num_quantiles=3, epsilon=0.33, name="q1")
init_stamp_token=0, num_quantiles=3, epsilon=0.01, name="q1")
resources.initialize_resources(resources.shared_resources()).run()
weight_placeholder = array_ops.placeholder(dtypes.float32)
dense_placeholder = array_ops.placeholder(dtypes.float32)
update = accumulator.add_summary(
stamp_token=0,
column=dense_placeholder,
example_weights=weight_placeholder)
with self.test_session() as sess:
for i in range(100):
dense_float = np.linspace(
i * 100, (i + 1) * 100 - 1, num=100).reshape(-1, 1)
sess.run(update, {
dense_placeholder: dense_float,
weight_placeholder: np.ones(shape=(100, 1), dtype=np.float32)
})
are_ready_noflush, _, = (accumulator.get_buckets(stamp_token=0))
update = accumulator.add_summary(
stamp_token=0,
column=dense_float_tensor_0,
example_weights=example_weights)
with ops.control_dependencies([are_ready_noflush, update]):
reset = accumulator.flush(stamp_token=0, next_stamp_token=1)
with ops.control_dependencies([reset]):
are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
buckets, are_ready_noflush, are_ready_flush = (sess.run(
[buckets, are_ready_noflush, are_ready_flush]))
self.assertEqual(False, are_ready_noflush)
with self.test_session() as sess:
sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1))
are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
buckets, are_ready_flush = (sess.run([buckets, are_ready_flush]))
self.assertEqual(True, are_ready_flush)
self.assertAllEqual([1, 3, 5], buckets)
self.assertAllEqual([0, 3335., 6671., 9999.], buckets)
def testSaveRestoreBeforeFlush(self):
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
@ -124,11 +128,13 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
save = saver.Saver()
resources.initialize_resources(resources.shared_resources()).run()
sparse_indices_0 = np.array(
[[1, 0], [2, 0], [3, 0], [4, 0], [5, 0]], dtype=np.int64)
sparse_values_0 = [2.0, 3.0, 4.0, 5.0, 6.0]
sparse_shape_0 = np.array([6, 1])
example_weights = np.array([10, 1, 1, 1, 1, 1])
sparse_indices_0 = constant_op.constant(
[[1, 0], [2, 0], [3, 0], [4, 0], [5, 0]], dtype=dtypes.int64)
sparse_values_0 = constant_op.constant(
[2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtypes.float32)
sparse_shape_0 = constant_op.constant([6, 1], dtype=dtypes.int64)
example_weights = constant_op.constant(
[10, 1, 1, 1, 1, 1], dtype=dtypes.float32, shape=[6, 1])
update = accumulator.add_summary(
stamp_token=0,
column=sparse_tensor.SparseTensor(sparse_indices_0, sparse_values_0,
@ -173,8 +179,10 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
save = saver.Saver()
resources.initialize_resources(resources.shared_resources()).run()
example_weights = np.array([10, 1, 1, 1, 1, 1])
dense_float_tensor_0 = np.array([1, 2, 3, 4, 4, 5])
example_weights = constant_op.constant(
[10, 1, 1, 1, 1, 1], dtype=dtypes.float32, shape=[6, 1])
dense_float_tensor_0 = constant_op.constant(
[1, 2, 3, 4, 4, 5], dtype=dtypes.float32, shape=[6, 1])
update = accumulator.add_summary(
stamp_token=0,
column=dense_float_tensor_0,
@ -206,9 +214,11 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
Creates array dividing range [0, 1] to 1<<16 elements equally spaced
with weight of 1.0.
"""
dense_float_tensor_0 = np.array([(1.0 * i) / math.pow(
2.0, 16) for i in range(0, int(math.pow(2, 16)) + 1)])
example_weights = np.array([1] * (int(math.pow(2, 16)) + 1))
dense_float_tensor_0 = constant_op.constant(
[(1.0 * i) / math.pow(2.0, 16)
for i in range(0, int(math.pow(2, 16)) + 1)])
example_weights = constant_op.constant(
[1] * (int(math.pow(2, 16)) + 1), dtype=dtypes.float32)
config = self._gen_config(0.1, 10)
with self.test_session():
@ -228,10 +238,12 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
Creates array dividing range [0, 1] to 1<<16 elements equally spaced
with weight same as the value.
"""
dense_float_tensor_0 = np.array([(1.0 * i) / math.pow(
2.0, 16) for i in range(0, int(math.pow(2, 16)) + 1)])
example_weights = np.array([(1.0 * i) / math.pow(2.0, 16)
for i in range(0, int(math.pow(2, 16)) + 1)])
dense_float_tensor_0 = constant_op.constant(
[(1.0 * i) / math.pow(2.0, 16)
for i in range(0, int(math.pow(2, 16)) + 1)])
example_weights = constant_op.constant(
[(1.0 * i) / math.pow(2.0, 16)
for i in range(0, int(math.pow(2, 16)) + 1)])
config = self._gen_config(0.1, 10)
@ -267,28 +279,30 @@ class QuantilesOpTest(test_util.TensorFlowTestCase):
Sparse 2: (-inf, 100]
"""
super(QuantilesOpTest, self).setUp()
self._dense_float_tensor_0 = np.array([[-0.1], [0.4], [3.2], [190]])
self._dense_float_tensor_1 = np.array([[-1], [-15], [18], [1000]])
self._dense_float_tensor_0 = constant_op.constant(
[[-0.1], [0.4], [3.2], [190]], dtype=dtypes.float32)
self._dense_float_tensor_1 = constant_op.constant(
[[-1], [-15], [18], [1000]], dtype=dtypes.float32)
# Sparse feature 0
self._sparse_indices_0 = np.array([[0, 0], [1, 0], [2, 0], [3, 0]])
self._sparse_values_0 = np.array([-2, 5.5, 16, 17.5])
self._sparse_shape_0 = np.array([4, 1])
self._sparse_indices_0 = constant_op.constant([[0, 0], [1, 0], [2, 0],
[3, 0]])
self._sparse_values_0 = constant_op.constant([-2, 5.5, 16, 17.5])
self._sparse_shape_0 = constant_op.constant([4, 1])
# Sparse feature 1
self._sparse_indices_1 = np.array([[0, 0], [2, 0], [3, 0]])
self._sparse_values_1 = np.array([0.1, 3, -3])
self._sparse_shape_1 = np.array([4, 1])
self._sparse_indices_1 = constant_op.constant([[0, 0], [2, 0], [3, 0]])
self._sparse_values_1 = constant_op.constant([0.1, 3, -3])
self._sparse_shape_1 = constant_op.constant([4, 1])
# Sparse feature 2
self._sparse_indices_2 = np.array([[1, 0], [3, 0]])
self._sparse_values_2 = np.array([2, 4])
self._sparse_shape_2 = np.array([4, 1])
self._sparse_indices_2 = constant_op.constant([[1, 0], [3, 0]])
self._sparse_values_2 = constant_op.constant([2, 4], dtype=dtypes.float32)
self._sparse_shape_2 = constant_op.constant([4, 1])
# Quantiles
self._dense_thresholds_0 = np.array([0.4, 5, 190])
self._dense_thresholds_1 = np.array([-9, 15, 1000])
self._dense_thresholds_0 = [0.4, 5, 190]
self._dense_thresholds_1 = [-9, 15, 1000]
self._sparse_thresholds_0 = np.array([5, 16, 100])
self._sparse_thresholds_1 = np.array([2, 5])
self._sparse_thresholds_2 = np.array([100])
self._sparse_thresholds_0 = [5, 16, 100]
self._sparse_thresholds_1 = [2, 5]
self._sparse_thresholds_2 = [100]
def testDenseFeaturesOnly(self):
with self.test_session():