Make non-iterable input to stratified_sample produce better error message.

PiperOrigin-RevId: 163735979
This commit is contained in:
A. Unique TensorFlower 2017-07-31 13:05:12 -07:00 committed by TensorFlower Gardener
parent 122750a879
commit b663c98991

View File

@ -150,7 +150,7 @@ def stratified_sample(tensors,
tensors: List of tensors for data. All tensors are either one item or a
batch, according to enqueue_many.
labels: Tensor for label of data. Label is a single integer or a batch,
depending on enqueue_many. It is not a one-hot vector.
depending on `enqueue_many`. It is not a one-hot vector.
target_probs: Target class proportions in batch. An object whose type has a
registered Tensor conversion function.
batch_size: Size of batch to be returned.
@ -164,9 +164,10 @@ def stratified_sample(tensors,
examples and for the final queue with the proper class proportions.
name: Optional prefix for ops created by this function.
Raises:
ValueError: enqueue_many is True and labels doesn't have a batch
dimension, or if enqueue_many is False and labels isn't a scalar.
ValueError: enqueue_many is True, and batch dimension on data and labels
ValueError: If `tensors` isn't iterable.
ValueError: `enqueue_many` is True and labels doesn't have a batch
dimension, or if `enqueue_many` is False and labels isn't a scalar.
ValueError: `enqueue_many` is True, and batch dimension on data and labels
don't match.
ValueError: if probs don't sum to one.
ValueError: if a zero initial probability class has a nonzero target
@ -188,7 +189,7 @@ def stratified_sample(tensors,
# Run batch through network.
...
"""
with ops.name_scope(name, 'stratified_sample', tensors + [labels]):
with ops.name_scope(name, 'stratified_sample', list(tensors) + [labels]):
tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
labels = ops.convert_to_tensor(labels)
target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32)