[tf.data] Delete unused code left over from tf.contrib.data.batch_and_drop_remainder()
.
PiperOrigin-RevId: 207491969
This commit is contained in:
parent
02ae1e2e78
commit
91f75c8ca7
@ -31,7 +31,6 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -439,48 +438,6 @@ def unbatch():
|
||||
return _apply_fn
|
||||
|
||||
|
||||
def _filter_irregular_batches(batch_size):
|
||||
"""Transformation that filters out batches that are not of size batch_size."""
|
||||
|
||||
def _apply_fn(dataset):
|
||||
"""Function from `Dataset` to `Dataset` that applies the transformation."""
|
||||
tensor_batch_size = ops.convert_to_tensor(
|
||||
batch_size, dtype=dtypes.int64, name="batch_size")
|
||||
|
||||
flattened = _RestructuredDataset(
|
||||
dataset,
|
||||
tuple(nest.flatten(dataset.output_types)),
|
||||
output_classes=tuple(nest.flatten(dataset.output_classes)))
|
||||
|
||||
def _predicate(*xs):
|
||||
"""Return `True` if this element is a full batch."""
|
||||
# Extract the dynamic batch size from the first component of the flattened
|
||||
# batched element.
|
||||
first_component = xs[0]
|
||||
first_component_batch_size = array_ops.shape(
|
||||
first_component, out_type=dtypes.int64)[0]
|
||||
|
||||
return math_ops.equal(first_component_batch_size, tensor_batch_size)
|
||||
|
||||
filtered = flattened.filter(_predicate)
|
||||
|
||||
maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)
|
||||
|
||||
def _set_first_dimension(shape):
|
||||
return shape.merge_with(
|
||||
tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))
|
||||
|
||||
known_shapes = nest.map_structure(_set_first_dimension,
|
||||
dataset.output_shapes)
|
||||
return _RestructuredDataset(
|
||||
filtered,
|
||||
dataset.output_types,
|
||||
known_shapes,
|
||||
output_classes=dataset.output_classes)
|
||||
|
||||
return _apply_fn
|
||||
|
||||
|
||||
@deprecation.deprecated(
|
||||
None, "Use `tf.data.Dataset.batch(..., drop_remainder=True)`.")
|
||||
def batch_and_drop_remainder(batch_size):
|
||||
|
Loading…
Reference in New Issue
Block a user