Replace tf.{data.experimental,contrib.data}.unbatch() with Dataset.unbatch().
PiperOrigin-RevId: 274174768
This commit is contained in:
parent
fdc2e8a4f5
commit
9b82752179
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Benchmarks for `tf.data.experimental.unbatch()`."""
|
"""Benchmarks for `tf.data.Dataset.unbatch()`."""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
@ -22,7 +22,6 @@ import time
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.data.experimental.ops import batching
|
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -31,7 +30,7 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
|
|
||||||
class UnbatchBenchmark(test.Benchmark):
|
class UnbatchBenchmark(test.Benchmark):
|
||||||
"""Benchmarks for `tf.data.experimental.unbatch()`."""
|
"""Benchmarks for `tf.data.Dataset.unbatch()`."""
|
||||||
|
|
||||||
def benchmark_native_unbatch(self):
|
def benchmark_native_unbatch(self):
|
||||||
batch_sizes = [1, 2, 5, 10, 20, 50]
|
batch_sizes = [1, 2, 5, 10, 20, 50]
|
||||||
@ -40,7 +39,7 @@ class UnbatchBenchmark(test.Benchmark):
|
|||||||
dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
|
dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
|
||||||
batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
|
batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
dataset = dataset.batch(batch_size_placeholder)
|
dataset = dataset.batch(batch_size_placeholder)
|
||||||
dataset = dataset.apply(batching.unbatch())
|
dataset = dataset.unbatch()
|
||||||
dataset = dataset.skip(elems_per_trial)
|
dataset = dataset.skip(elems_per_trial)
|
||||||
options = dataset_ops.Options()
|
options = dataset_ops.Options()
|
||||||
options.experimental_optimization.apply_default_optimizations = False
|
options.experimental_optimization.apply_default_optimizations = False
|
||||||
|
@ -19,7 +19,6 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.python.data.experimental.ops import batching
|
|
||||||
from tensorflow.python.data.experimental.ops import optimization
|
from tensorflow.python.data.experimental.ops import optimization
|
||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
@ -120,7 +119,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
|
|||||||
dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
|
dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
|
||||||
|
|
||||||
def branch(dataset):
|
def branch(dataset):
|
||||||
return dataset.apply(batching.unbatch())
|
return dataset.unbatch()
|
||||||
|
|
||||||
choose_fastest = optimization._ChooseFastestBranchDataset(
|
choose_fastest = optimization._ChooseFastestBranchDataset(
|
||||||
dataset, [branch, branch],
|
dataset, [branch, branch],
|
||||||
@ -134,7 +133,7 @@ class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
|
|||||||
dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
|
dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
|
||||||
|
|
||||||
def branch(dataset):
|
def branch(dataset):
|
||||||
return dataset.apply(batching.unbatch())
|
return dataset.unbatch()
|
||||||
|
|
||||||
def make_dataset():
|
def make_dataset():
|
||||||
return optimization._ChooseFastestBranchDataset(
|
return optimization._ChooseFastestBranchDataset(
|
||||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
|
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
|
||||||
from tensorflow.python.data.experimental.ops import batching
|
|
||||||
from tensorflow.python.data.experimental.ops import optimization
|
from tensorflow.python.data.experimental.ops import optimization
|
||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
@ -99,7 +98,7 @@ class ChooseFastestBranchDatasetSerializationTest(
|
|||||||
dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
|
dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
|
||||||
|
|
||||||
def branch(dataset):
|
def branch(dataset):
|
||||||
return dataset.apply(batching.unbatch())
|
return dataset.unbatch()
|
||||||
|
|
||||||
return optimization._ChooseFastestBranchDataset(
|
return optimization._ChooseFastestBranchDataset(
|
||||||
dataset, [branch, branch],
|
dataset, [branch, branch],
|
||||||
|
@ -21,7 +21,6 @@ from absl.testing import parameterized
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
|
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
|
||||||
from tensorflow.python.data.experimental.ops import batching
|
|
||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import combinations
|
from tensorflow.python.framework import combinations
|
||||||
@ -39,7 +38,7 @@ class UnbatchDatasetSerializationTest(
|
|||||||
np.array(multiplier) * np.arange(tensor_slice_len))
|
np.array(multiplier) * np.arange(tensor_slice_len))
|
||||||
|
|
||||||
return dataset_ops.Dataset.from_tensor_slices(components).batch(
|
return dataset_ops.Dataset.from_tensor_slices(components).batch(
|
||||||
batch_size).apply(batching.unbatch())
|
batch_size).unbatch()
|
||||||
|
|
||||||
@combinations.generate(test_base.default_test_combinations())
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testCore(self):
|
def testCore(self):
|
||||||
|
@ -270,7 +270,7 @@ def unbatch():
|
|||||||
# of a dataset.
|
# of a dataset.
|
||||||
a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
|
a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
|
||||||
|
|
||||||
a.apply(tf.data.experimental.unbatch()) == {
|
a.unbatch() == {
|
||||||
'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
|
'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -19,7 +19,6 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.data.experimental.ops import batching
|
|
||||||
from tensorflow.python.data.experimental.ops import interleave_ops
|
from tensorflow.python.data.experimental.ops import interleave_ops
|
||||||
from tensorflow.python.data.experimental.ops import scan_ops
|
from tensorflow.python.data.experimental.ops import scan_ops
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
@ -182,7 +181,7 @@ def _estimate_initial_dist_ds(
|
|||||||
initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
|
initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
|
||||||
.apply(scan_ops.scan(initial_examples_per_class_seen,
|
.apply(scan_ops.scan(initial_examples_per_class_seen,
|
||||||
update_estimate_and_tile))
|
update_estimate_and_tile))
|
||||||
.apply(batching.unbatch()))
|
.unbatch())
|
||||||
|
|
||||||
return initial_dist_ds
|
return initial_dist_ds
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Tests for `tf.data.experimental.unbatch()`."""
|
"""Tests for `tf.data.Dataset.unbatch()`."""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
@ -21,7 +21,6 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.data.experimental.ops import batching
|
|
||||||
from tensorflow.python.distribute import distribute_coordinator as dc
|
from tensorflow.python.distribute import distribute_coordinator as dc
|
||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.distribute import input_lib
|
from tensorflow.python.distribute import input_lib
|
||||||
@ -457,7 +456,7 @@ def experimental_tpu_predict_loop(model,
|
|||||||
padding_handler.update_mask)
|
padding_handler.update_mask)
|
||||||
|
|
||||||
dataset = dataset.map(padding_handler.pad_batch)
|
dataset = dataset.map(padding_handler.pad_batch)
|
||||||
dataset = dataset.apply(batching.unbatch())
|
dataset = dataset.unbatch()
|
||||||
# Upon this point, it is guaranteed that the dataset does not
|
# Upon this point, it is guaranteed that the dataset does not
|
||||||
# have partial batches. Thus, we set `drop_remainder=True` to
|
# have partial batches. Thus, we set `drop_remainder=True` to
|
||||||
# get static shape information about the elements in the dataset.
|
# get static shape information about the elements in the dataset.
|
||||||
|
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.data.experimental.ops import batching
|
|
||||||
from tensorflow.python.data.experimental.ops import interleave_ops
|
from tensorflow.python.data.experimental.ops import interleave_ops
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.data.ops import iterator_ops
|
from tensorflow.python.data.ops import iterator_ops
|
||||||
@ -194,6 +193,6 @@ def StreamingFilesDataset(files,
|
|||||||
|
|
||||||
if batch_transfer_size:
|
if batch_transfer_size:
|
||||||
# Undo the batching used during the transfer.
|
# Undo the batching used during the transfer.
|
||||||
output_dataset = output_dataset.apply(batching.unbatch()).prefetch(1)
|
output_dataset = output_dataset.unbatch().prefetch(1)
|
||||||
|
|
||||||
return output_dataset
|
return output_dataset
|
||||||
|
Loading…
Reference in New Issue
Block a user