Make non-iterable input to stratified_sample produce better error message.

PiperOrigin-RevId: 163735979
This commit is contained in:
A. Unique TensorFlower 2017-07-31 13:05:12 -07:00 committed by TensorFlower Gardener
parent 122750a879
commit b663c98991

View File

@ -150,7 +150,7 @@ def stratified_sample(tensors,
tensors: List of tensors for data. All tensors are either one item or a tensors: List of tensors for data. All tensors are either one item or a
batch, according to enqueue_many. batch, according to enqueue_many.
labels: Tensor for label of data. Label is a single integer or a batch, labels: Tensor for label of data. Label is a single integer or a batch,
depending on enqueue_many. It is not a one-hot vector. depending on `enqueue_many`. It is not a one-hot vector.
target_probs: Target class proportions in batch. An object whose type has a target_probs: Target class proportions in batch. An object whose type has a
registered Tensor conversion function. registered Tensor conversion function.
batch_size: Size of batch to be returned. batch_size: Size of batch to be returned.
@ -164,9 +164,10 @@ def stratified_sample(tensors,
examples and for the final queue with the proper class proportions. examples and for the final queue with the proper class proportions.
name: Optional prefix for ops created by this function. name: Optional prefix for ops created by this function.
Raises: Raises:
ValueError: enqueue_many is True and labels doesn't have a batch ValueError: If `tensors` isn't iterable.
dimension, or if enqueue_many is False and labels isn't a scalar. ValueError: `enqueue_many` is True and labels doesn't have a batch
ValueError: enqueue_many is True, and batch dimension on data and labels dimension, or if `enqueue_many` is False and labels isn't a scalar.
ValueError: `enqueue_many` is True, and batch dimension on data and labels
don't match. don't match.
ValueError: if probs don't sum to one. ValueError: if probs don't sum to one.
ValueError: if a zero initial probability class has a nonzero target ValueError: if a zero initial probability class has a nonzero target
@ -188,7 +189,7 @@ def stratified_sample(tensors,
# Run batch through network. # Run batch through network.
... ...
""" """
with ops.name_scope(name, 'stratified_sample', tensors + [labels]): with ops.name_scope(name, 'stratified_sample', list(tensors) + [labels]):
tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors) tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
labels = ops.convert_to_tensor(labels) labels = ops.convert_to_tensor(labels)
target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32) target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32)