Replace tf.{data.experimental,contrib.data}.unbatch() with Dataset.unbatch().

PiperOrigin-RevId: 274174768
This commit is contained in:
Derek Murray 2019-10-11 08:09:07 -07:00 committed by TensorFlower Gardener
parent fdc2e8a4f5
commit 9b82752179
9 changed files with 12 additions and 19 deletions

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmarks for `tf.data.experimental.unbatch()`."""
"""Benchmarks for `tf.data.Dataset.unbatch()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -22,7 +22,6 @@ import time
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -31,7 +30,7 @@ from tensorflow.python.platform import test
class UnbatchBenchmark(test.Benchmark):
"""Benchmarks for `tf.data.experimental.unbatch()`."""
"""Benchmarks for `tf.data.Dataset.unbatch()`."""
def benchmark_native_unbatch(self):
batch_sizes = [1, 2, 5, 10, 20, 50]
@ -40,7 +39,7 @@ class UnbatchBenchmark(test.Benchmark):
dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
dataset = dataset.batch(batch_size_placeholder)
dataset = dataset.apply(batching.unbatch())
dataset = dataset.unbatch()
dataset = dataset.skip(elems_per_trial)
options = dataset_ops.Options()
options.experimental_optimization.apply_default_optimizations = False

View File

@ -19,7 +19,6 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
@ -120,7 +119,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
def branch(dataset):
return dataset.apply(batching.unbatch())
return dataset.unbatch()
choose_fastest = optimization._ChooseFastestBranchDataset(
dataset, [branch, branch],
@ -134,7 +133,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
def branch(dataset):
return dataset.apply(batching.unbatch())
return dataset.unbatch()
def make_dataset():
return optimization._ChooseFastestBranchDataset(

View File

@ -20,7 +20,6 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
@ -99,7 +98,7 @@ class ChooseFastestBranchDatasetSerializationTest(
dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
def branch(dataset):
return dataset.apply(batching.unbatch())
return dataset.unbatch()
return optimization._ChooseFastestBranchDataset(
dataset, [branch, branch],

View File

@ -21,7 +21,6 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
@ -39,7 +38,7 @@ class UnbatchDatasetSerializationTest(
np.array(multiplier) * np.arange(tensor_slice_len))
return dataset_ops.Dataset.from_tensor_slices(components).batch(
batch_size).apply(batching.unbatch())
batch_size).unbatch()
@combinations.generate(test_base.default_test_combinations())
def testCore(self):

View File

@ -270,7 +270,7 @@ def unbatch():
# of a dataset.
a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
a.apply(tf.data.experimental.unbatch()) == {
a.unbatch() == {
'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
```

View File

@ -19,7 +19,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.experimental.ops import scan_ops
from tensorflow.python.data.ops import dataset_ops
@ -182,7 +181,7 @@ def _estimate_initial_dist_ds(
initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
.apply(scan_ops.scan(initial_examples_per_class_seen,
update_estimate_and_tile))
.apply(batching.unbatch()))
.unbatch())
return initial_dist_ds

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.unbatch()`."""
"""Tests for `tf.data.Dataset.unbatch()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_lib
@ -457,7 +456,7 @@ def experimental_tpu_predict_loop(model,
padding_handler.update_mask)
dataset = dataset.map(padding_handler.pad_batch)
dataset = dataset.apply(batching.unbatch())
dataset = dataset.unbatch()
# Upon this point, it is guaranteed that the dataset does not
# have partial batches. Thus, we set `drop_remainder=True` to
# get static shape information about the elements in the dataset.

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
@ -194,6 +193,6 @@ def StreamingFilesDataset(files,
if batch_transfer_size:
# Undo the batching used during the transfer.
output_dataset = output_dataset.apply(batching.unbatch()).prefetch(1)
output_dataset = output_dataset.unbatch().prefetch(1)
return output_dataset