Make non-iterable input to stratified_sample
produce better error message.
PiperOrigin-RevId: 163735979
This commit is contained in:
parent
122750a879
commit
b663c98991
@ -150,7 +150,7 @@ def stratified_sample(tensors,
|
|||||||
tensors: List of tensors for data. All tensors are either one item or a
|
tensors: List of tensors for data. All tensors are either one item or a
|
||||||
batch, according to enqueue_many.
|
batch, according to enqueue_many.
|
||||||
labels: Tensor for label of data. Label is a single integer or a batch,
|
labels: Tensor for label of data. Label is a single integer or a batch,
|
||||||
depending on enqueue_many. It is not a one-hot vector.
|
depending on `enqueue_many`. It is not a one-hot vector.
|
||||||
target_probs: Target class proportions in batch. An object whose type has a
|
target_probs: Target class proportions in batch. An object whose type has a
|
||||||
registered Tensor conversion function.
|
registered Tensor conversion function.
|
||||||
batch_size: Size of batch to be returned.
|
batch_size: Size of batch to be returned.
|
||||||
@ -164,9 +164,10 @@ def stratified_sample(tensors,
|
|||||||
examples and for the final queue with the proper class proportions.
|
examples and for the final queue with the proper class proportions.
|
||||||
name: Optional prefix for ops created by this function.
|
name: Optional prefix for ops created by this function.
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: enqueue_many is True and labels doesn't have a batch
|
ValueError: If `tensors` isn't iterable.
|
||||||
dimension, or if enqueue_many is False and labels isn't a scalar.
|
ValueError: `enqueue_many` is True and labels doesn't have a batch
|
||||||
ValueError: enqueue_many is True, and batch dimension on data and labels
|
dimension, or if `enqueue_many` is False and labels isn't a scalar.
|
||||||
|
ValueError: `enqueue_many` is True, and batch dimension on data and labels
|
||||||
don't match.
|
don't match.
|
||||||
ValueError: if probs don't sum to one.
|
ValueError: if probs don't sum to one.
|
||||||
ValueError: if a zero initial probability class has a nonzero target
|
ValueError: if a zero initial probability class has a nonzero target
|
||||||
@ -188,7 +189,7 @@ def stratified_sample(tensors,
|
|||||||
# Run batch through network.
|
# Run batch through network.
|
||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(name, 'stratified_sample', tensors + [labels]):
|
with ops.name_scope(name, 'stratified_sample', list(tensors) + [labels]):
|
||||||
tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
|
tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
|
||||||
labels = ops.convert_to_tensor(labels)
|
labels = ops.convert_to_tensor(labels)
|
||||||
target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32)
|
target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user