[tf.data] Delete unused code left over from tf.contrib.data.batch_and_drop_remainder().

PiperOrigin-RevId: 207491969
This commit is contained in:
Derek Murray 2018-08-05 22:16:44 -07:00 committed by TensorFlower Gardener
parent 02ae1e2e78
commit 91f75c8ca7

View File

@ -31,7 +31,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@ -439,48 +438,6 @@ def unbatch():
return _apply_fn
def _filter_irregular_batches(batch_size):
"""Transformation that filters out batches that are not of size batch_size."""
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
tensor_batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
flattened = _RestructuredDataset(
dataset,
tuple(nest.flatten(dataset.output_types)),
output_classes=tuple(nest.flatten(dataset.output_classes)))
def _predicate(*xs):
"""Return `True` if this element is a full batch."""
# Extract the dynamic batch size from the first component of the flattened
# batched element.
first_component = xs[0]
first_component_batch_size = array_ops.shape(
first_component, out_type=dtypes.int64)[0]
return math_ops.equal(first_component_batch_size, tensor_batch_size)
filtered = flattened.filter(_predicate)
maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)
def _set_first_dimension(shape):
return shape.merge_with(
tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))
known_shapes = nest.map_structure(_set_first_dimension,
dataset.output_shapes)
return _RestructuredDataset(
filtered,
dataset.output_types,
known_shapes,
output_classes=dataset.output_classes)
return _apply_fn
@deprecation.deprecated(
None, "Use `tf.data.Dataset.batch(..., drop_remainder=True)`.")
def batch_and_drop_remainder(batch_size):