[tf.data] Moving tf.data test-only transformations to one module and marking the corresponding target test-only.
This CL also migrates all tf.data Python test targets from using `py_test` to using `tf_py_test` to avoid repetitive specification of `python_version` and `srcs_version`. PiperOrigin-RevId: 280590324 Change-Id: I648c2bf0b7c5bde888241d39866b0c51dcb1fd5d
This commit is contained in:
		
							parent
							
								
									919253c654
								
							
						
					
					
						commit
						e5402b6883
					
				@ -1,4 +1,4 @@
 | 
			
		||||
load("//tensorflow:tensorflow.bzl", "py_test")
 | 
			
		||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
 | 
			
		||||
 | 
			
		||||
package(
 | 
			
		||||
    default_visibility = ["//tensorflow:internal"],
 | 
			
		||||
@ -7,21 +7,22 @@ package(
 | 
			
		||||
 | 
			
		||||
exports_files(["LICENSE"])
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "meta_benchmark",
 | 
			
		||||
    srcs = ["meta_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:session",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "benchmark_base",
 | 
			
		||||
    srcs = ["benchmark_base.py"],
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:session",
 | 
			
		||||
@ -30,75 +31,63 @@ py_library(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "batch_benchmark",
 | 
			
		||||
    srcs = ["batch_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        ":benchmark_base",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/python:sparse_tensor",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "filter_benchmark",
 | 
			
		||||
    srcs = ["filter_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        ":benchmark_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "from_tensor_slices_benchmark",
 | 
			
		||||
    srcs = ["from_tensor_slices_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        ":benchmark_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "list_files_benchmark",
 | 
			
		||||
    srcs = ["list_files_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        ":benchmark_base",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python:framework_ops",
 | 
			
		||||
        "//tensorflow/python:session",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "map_benchmark",
 | 
			
		||||
    srcs = ["map_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        ":benchmark_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "range_benchmark",
 | 
			
		||||
    srcs = ["range_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        ":benchmark_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
    ],
 | 
			
		||||
 | 
			
		||||
@ -21,7 +21,7 @@ import timeit
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.client import session
 | 
			
		||||
from tensorflow.python.data.experimental.ops import sleep
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.eager import context
 | 
			
		||||
from tensorflow.python.platform import test
 | 
			
		||||
@ -61,7 +61,7 @@ class MetaBenchmark(test.Benchmark):
 | 
			
		||||
    dataset = self.setup_fast_dataset()
 | 
			
		||||
    self.iters = 1000
 | 
			
		||||
    # sleep for 1e-3s per iteration
 | 
			
		||||
    return dataset.apply(sleep.sleep(1000))
 | 
			
		||||
    return dataset.apply(testing.sleep(1000))
 | 
			
		||||
 | 
			
		||||
  def benchmark_slow_dataset_with_only_cpp_iterations(self):
 | 
			
		||||
    dataset = self.setup_slow_dataset()
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
load("//tensorflow:tensorflow.bzl", "py_test")
 | 
			
		||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
 | 
			
		||||
 | 
			
		||||
package(
 | 
			
		||||
    default_visibility = ["//tensorflow:internal"],
 | 
			
		||||
@ -7,59 +7,51 @@ package(
 | 
			
		||||
 | 
			
		||||
exports_files(["LICENSE"])
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "autotune_benchmark",
 | 
			
		||||
    srcs = ["autotune_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:math_ops",
 | 
			
		||||
        "//tensorflow/python:session",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "choose_fastest_benchmark",
 | 
			
		||||
    srcs = ["choose_fastest_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:framework_ops",
 | 
			
		||||
        "//tensorflow/python:math_ops",
 | 
			
		||||
        "//tensorflow/python:session",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "choose_fastest_branch_benchmark",
 | 
			
		||||
    srcs = ["choose_fastest_branch_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:framework_ops",
 | 
			
		||||
        "//tensorflow/python:math_ops",
 | 
			
		||||
        "//tensorflow/python:session",
 | 
			
		||||
        "//tensorflow/python/data/benchmarks:benchmark_base",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:sleep",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "csv_dataset_benchmark",
 | 
			
		||||
    srcs = ["csv_dataset_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = ["no_pip"],
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:parsing_ops",
 | 
			
		||||
        "//tensorflow/python:platform",
 | 
			
		||||
@ -67,16 +59,15 @@ py_test(
 | 
			
		||||
        "//tensorflow/python:session",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:readers",
 | 
			
		||||
        "//tensorflow/python/data/ops:readers",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
    ],
 | 
			
		||||
    tags = ["no_pip"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "map_and_batch_benchmark",
 | 
			
		||||
    srcs = ["map_and_batch_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/core:protos_all_py",
 | 
			
		||||
        "//tensorflow/python:array_ops",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
@ -87,16 +78,13 @@ py_test(
 | 
			
		||||
        "//tensorflow/python:session",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:batching",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "map_defun_benchmark",
 | 
			
		||||
    srcs = ["map_defun_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//tensorflow/python:array_ops",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:dtypes",
 | 
			
		||||
@ -108,12 +96,11 @@ py_test(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "map_vectorization_benchmark",
 | 
			
		||||
    srcs = ["map_vectorization_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/core:protos_all_py",
 | 
			
		||||
        "//tensorflow/python:array_ops",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
@ -124,17 +111,15 @@ py_test(
 | 
			
		||||
        "//tensorflow/python:session",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//tensorflow/python/data/util:nest",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "matching_files_benchmark",
 | 
			
		||||
    size = "small",
 | 
			
		||||
    srcs = ["matching_files_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/python:array_ops",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:dtypes",
 | 
			
		||||
@ -142,61 +127,54 @@ py_test(
 | 
			
		||||
        "//tensorflow/python:util",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:matching_files",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "optimize_benchmark",
 | 
			
		||||
    srcs = ["optimize_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:framework_ops",
 | 
			
		||||
        "//tensorflow/python:math_ops",
 | 
			
		||||
        "//tensorflow/python:session",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "parallel_interleave_benchmark",
 | 
			
		||||
    srcs = ["parallel_interleave_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:math_ops",
 | 
			
		||||
        "//tensorflow/python:session",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:interleave_ops",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:sleep",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "rejection_resample_benchmark",
 | 
			
		||||
    srcs = ["rejection_resample_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = ["no_pip"],
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "@six_archive//:six",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:resampling",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "@six_archive//:six",
 | 
			
		||||
    ],
 | 
			
		||||
    tags = ["no_pip"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "snapshot_dataset_benchmark",
 | 
			
		||||
    srcs = ["snapshot_dataset_benchmark.py"],
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/python:array_ops",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
@ -207,16 +185,14 @@ py_test(
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:snapshot",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "unbatch_benchmark",
 | 
			
		||||
    srcs = ["unbatch_benchmark.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/python:array_ops",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:dtypes",
 | 
			
		||||
@ -224,6 +200,5 @@ py_test(
 | 
			
		||||
        "//tensorflow/python:session",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:batching",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -19,7 +19,7 @@ from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.benchmarks import benchmark_base
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import sleep
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -78,7 +78,8 @@ class ChooseFastestBranchBenchmark(benchmark_base.DatasetBenchmarkBase):
 | 
			
		||||
  def benchmark_with_input_skew(self):
 | 
			
		||||
 | 
			
		||||
    def make_dataset(time_us, num_elements):
 | 
			
		||||
      return dataset_ops.Dataset.range(num_elements).apply(sleep.sleep(time_us))
 | 
			
		||||
      return dataset_ops.Dataset.range(num_elements).apply(
 | 
			
		||||
          testing.sleep(time_us))
 | 
			
		||||
 | 
			
		||||
    # Dataset with 100 elements that emulates performance characteristics of a
 | 
			
		||||
    # file-based dataset stored in remote storage, where the first element
 | 
			
		||||
@ -87,10 +88,10 @@ class ChooseFastestBranchBenchmark(benchmark_base.DatasetBenchmarkBase):
 | 
			
		||||
                                 0).concatenate(make_dataset(1, 100)).take(100)
 | 
			
		||||
 | 
			
		||||
    def slow_branch(dataset):
 | 
			
		||||
      return dataset.apply(sleep.sleep(10000))
 | 
			
		||||
      return dataset.apply(testing.sleep(10000))
 | 
			
		||||
 | 
			
		||||
    def fast_branch(dataset):
 | 
			
		||||
      return dataset.apply(sleep.sleep(10))
 | 
			
		||||
      return dataset.apply(testing.sleep(10))
 | 
			
		||||
 | 
			
		||||
    def benchmark(dataset, name):
 | 
			
		||||
      self.run_and_report_benchmark(
 | 
			
		||||
 | 
			
		||||
@ -22,7 +22,7 @@ import time
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import interleave_ops
 | 
			
		||||
from tensorflow.python.data.experimental.ops import sleep
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import ops
 | 
			
		||||
from tensorflow.python.platform import test
 | 
			
		||||
@ -52,7 +52,7 @@ def _make_fake_dataset_fn(initial_delay_us, remainder_delay_us):
 | 
			
		||||
    def make_dataset(time_us, num_elements):
 | 
			
		||||
      dataset = dataset_ops.Dataset.range(num_elements)
 | 
			
		||||
      if time_us > 0:
 | 
			
		||||
        dataset = dataset.apply(sleep.sleep(time_us))
 | 
			
		||||
        dataset = dataset.apply(testing.sleep(time_us))
 | 
			
		||||
      return dataset
 | 
			
		||||
 | 
			
		||||
    if not initial_delay_us:
 | 
			
		||||
 | 
			
		||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -17,7 +17,7 @@ from __future__ import absolute_import
 | 
			
		||||
from __future__ import division
 | 
			
		||||
from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import errors
 | 
			
		||||
@ -26,11 +26,11 @@ from tensorflow.python.platform import test
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@test_util.run_all_in_graph_and_eager_modes
 | 
			
		||||
class AssertNextDatasetTest(test_base.DatasetTestBase):
 | 
			
		||||
class AssertNextTest(test_base.DatasetTestBase):
 | 
			
		||||
 | 
			
		||||
  def testAssertNext(self):
 | 
			
		||||
    dataset = dataset_ops.Dataset.from_tensors(0).apply(
 | 
			
		||||
        optimization.assert_next(["Map"])).map(lambda x: x)
 | 
			
		||||
        testing.assert_next(["Map"])).map(lambda x: x)
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
    dataset = dataset.with_options(options)
 | 
			
		||||
@ -38,7 +38,7 @@ class AssertNextDatasetTest(test_base.DatasetTestBase):
 | 
			
		||||
 | 
			
		||||
  def testAssertNextInvalid(self):
 | 
			
		||||
    dataset = dataset_ops.Dataset.from_tensors(0).apply(
 | 
			
		||||
        optimization.assert_next(["Whoops"])).map(lambda x: x)
 | 
			
		||||
        testing.assert_next(["Whoops"])).map(lambda x: x)
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
    dataset = dataset.with_options(options)
 | 
			
		||||
@ -51,7 +51,7 @@ class AssertNextDatasetTest(test_base.DatasetTestBase):
 | 
			
		||||
 | 
			
		||||
  def testAssertNextShort(self):
 | 
			
		||||
    dataset = dataset_ops.Dataset.from_tensors(0).apply(
 | 
			
		||||
        optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
 | 
			
		||||
        testing.assert_next(["Map", "Whoops"])).map(lambda x: x)
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
    options.experimental_optimization.autotune = False
 | 
			
		||||
@ -23,7 +23,7 @@ from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_
 | 
			
		||||
from tensorflow.python.data.experimental.ops import distribute
 | 
			
		||||
from tensorflow.python.data.experimental.ops import distribute_options
 | 
			
		||||
from tensorflow.python.data.experimental.ops import interleave_ops
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.experimental.ops import readers
 | 
			
		||||
from tensorflow.python.data.experimental.ops import unique
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
@ -391,7 +391,7 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
 | 
			
		||||
    # Tests that Rebatch is a passthrough op.
 | 
			
		||||
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
 | 
			
		||||
    dataset = dataset.apply(
 | 
			
		||||
        optimization.assert_next(["Shard", "FlatMap", "BatchV2", "Rebatch"]))
 | 
			
		||||
        testing.assert_next(["Shard", "FlatMap", "BatchV2", "Rebatch"]))
 | 
			
		||||
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
 | 
			
		||||
    dataset = dataset.batch(5)
 | 
			
		||||
    dataset = distribute._RebatchDataset(dataset, num_replicas=1)
 | 
			
		||||
 | 
			
		||||
@ -19,7 +19,7 @@ from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from absl.testing import parameterized
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import errors
 | 
			
		||||
@ -33,7 +33,7 @@ class ModelDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
  def testAutotuneOption(self):
 | 
			
		||||
    dataset = dataset_ops.Dataset.from_tensors(0)
 | 
			
		||||
    dataset = dataset.map(lambda x: x).apply(
 | 
			
		||||
        optimization.assert_next(["Model"]))
 | 
			
		||||
        testing.assert_next(["Model"]))
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
    options.experimental_optimization.autotune = True
 | 
			
		||||
@ -0,0 +1,56 @@
 | 
			
		||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
# ==============================================================================
 | 
			
		||||
"""Tests for `tf.data.experimental.non_serializable()`."""
 | 
			
		||||
from __future__ import absolute_import
 | 
			
		||||
from __future__ import division
 | 
			
		||||
from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import test_util
 | 
			
		||||
from tensorflow.python.platform import test
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@test_util.run_all_in_graph_and_eager_modes
 | 
			
		||||
class NonSerializableTest(test_base.DatasetTestBase):
 | 
			
		||||
 | 
			
		||||
  def testNonSerializable(self):
 | 
			
		||||
    dataset = dataset_ops.Dataset.from_tensors(0)
 | 
			
		||||
    dataset = dataset.apply(testing.assert_next(["FiniteSkip"]))
 | 
			
		||||
    dataset = dataset.skip(0)  # Should not be removed by noop elimination
 | 
			
		||||
    dataset = dataset.apply(testing.non_serializable())
 | 
			
		||||
    dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
 | 
			
		||||
    dataset = dataset.skip(0)  # Should be removed by noop elimination
 | 
			
		||||
    dataset = dataset.cache()
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
    options.experimental_optimization.noop_elimination = True
 | 
			
		||||
    dataset = dataset.with_options(options)
 | 
			
		||||
    self.assertDatasetProduces(dataset, expected_output=[0])
 | 
			
		||||
 | 
			
		||||
  def testNonSerializableAsDirectInput(self):
 | 
			
		||||
    """Tests that non-serializable dataset can be OptimizeDataset's input."""
 | 
			
		||||
    dataset = dataset_ops.Dataset.from_tensors(0)
 | 
			
		||||
    dataset = dataset.apply(testing.non_serializable())
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
    options.experimental_optimization.noop_elimination = True
 | 
			
		||||
    dataset = dataset.with_options(options)
 | 
			
		||||
    self.assertDatasetProduces(dataset, expected_output=[0])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
  test.main()
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
load("//tensorflow:tensorflow.bzl", "py_test")
 | 
			
		||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
 | 
			
		||||
 | 
			
		||||
package(
 | 
			
		||||
    default_visibility = ["//tensorflow:internal"],
 | 
			
		||||
@ -7,77 +7,64 @@ package(
 | 
			
		||||
 | 
			
		||||
exports_files(["LICENSE"])
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "assert_next_dataset_test",
 | 
			
		||||
    size = "medium",
 | 
			
		||||
    srcs = ["assert_next_dataset_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "inject_prefetch_test",
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "choose_fastest_dataset_test",
 | 
			
		||||
    size = "small",
 | 
			
		||||
    srcs = ["inject_prefetch_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
    srcs = ["choose_fastest_dataset_test.py"],
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
    ],
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "filter_fusion_test",
 | 
			
		||||
    size = "medium",
 | 
			
		||||
    srcs = ["filter_fusion_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:constant_op",
 | 
			
		||||
        "//tensorflow/python:dtypes",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python:math_ops",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization_options",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
    ],
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "filter_with_random_uniform_fusion_test",
 | 
			
		||||
    size = "medium",
 | 
			
		||||
    srcs = ["filter_with_random_uniform_fusion_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:constant_op",
 | 
			
		||||
        "//tensorflow/python:dtypes",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python:math_ops",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization_options",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
    ],
 | 
			
		||||
    tags = [
 | 
			
		||||
        "manual",
 | 
			
		||||
        "no_oss",
 | 
			
		||||
@ -85,32 +72,14 @@ py_test(
 | 
			
		||||
        "no_windows",
 | 
			
		||||
        "notap",  # TODO(b/131229793)
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:constant_op",
 | 
			
		||||
        "//tensorflow/python:dtypes",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python:math_ops",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization_options",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "hoist_random_uniform_test",
 | 
			
		||||
    size = "small",
 | 
			
		||||
    srcs = ["hoist_random_uniform_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:constant_op",
 | 
			
		||||
        "//tensorflow/python:control_flow_ops",
 | 
			
		||||
@ -119,113 +88,103 @@ py_test(
 | 
			
		||||
        "//tensorflow/python:framework_ops",
 | 
			
		||||
        "//tensorflow/python:math_ops",
 | 
			
		||||
        "//tensorflow/python:random_ops",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization_options",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
    ],
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "latency_all_edges_test",
 | 
			
		||||
    size = "small",
 | 
			
		||||
    srcs = ["latency_all_edges_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:stats_aggregator",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:stats_ops",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "map_and_batch_fusion_test",
 | 
			
		||||
    srcs = ["map_and_batch_fusion_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "map_and_batch_fusion_test",
 | 
			
		||||
    srcs = ["map_and_batch_fusion_test.py"],
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization_options",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "map_and_filter_fusion_test",
 | 
			
		||||
    srcs = ["map_and_filter_fusion_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "map_and_filter_fusion_test",
 | 
			
		||||
    srcs = ["map_and_filter_fusion_test.py"],
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:constant_op",
 | 
			
		||||
        "//tensorflow/python:dtypes",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python:math_ops",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization_options",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "map_fusion_test",
 | 
			
		||||
    srcs = ["map_fusion_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "map_fusion_test",
 | 
			
		||||
    srcs = ["map_fusion_test.py"],
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization_options",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
    ],
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "map_parallelization_test",
 | 
			
		||||
    size = "small",
 | 
			
		||||
    srcs = ["map_parallelization_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:constant_op",
 | 
			
		||||
        "//tensorflow/python:control_flow_ops",
 | 
			
		||||
@ -234,27 +193,25 @@ py_test(
 | 
			
		||||
        "//tensorflow/python:framework_ops",
 | 
			
		||||
        "//tensorflow/python:math_ops",
 | 
			
		||||
        "//tensorflow/python:random_ops",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization_options",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "map_vectorization_test",
 | 
			
		||||
    size = "small",
 | 
			
		||||
    srcs = ["map_vectorization_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    shard_count = 8,
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "map_vectorization_test",
 | 
			
		||||
    size = "small",
 | 
			
		||||
    srcs = ["map_vectorization_test.py"],
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "//tensorflow/core:protos_all_py",
 | 
			
		||||
        "//tensorflow/python:array_ops",
 | 
			
		||||
        "//tensorflow/python:bitwise_ops",
 | 
			
		||||
@ -272,152 +229,54 @@ py_test(
 | 
			
		||||
        "//tensorflow/python:parsing_ops",
 | 
			
		||||
        "//tensorflow/python:sparse_tensor",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:batching",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization_options",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "choose_fastest_dataset_test",
 | 
			
		||||
    size = "small",
 | 
			
		||||
    srcs = ["choose_fastest_dataset_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    shard_count = 8,
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "choose_fastest_branch_dataset_test",
 | 
			
		||||
    size = "small",
 | 
			
		||||
    srcs = ["choose_fastest_branch_dataset_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:constant_op",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python:math_ops",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:batching",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "model_dataset_test",
 | 
			
		||||
    size = "medium",
 | 
			
		||||
    srcs = ["model_dataset_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
        "optonly",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "noop_elimination_test",
 | 
			
		||||
    size = "small",
 | 
			
		||||
    srcs = ["noop_elimination_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:constant_op",
 | 
			
		||||
        "//tensorflow/python:dtypes",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python:math_ops",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "optimize_dataset_test",
 | 
			
		||||
    size = "medium",
 | 
			
		||||
    srcs = ["optimize_dataset_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/python:array_ops",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:dtypes",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python:random_ops",
 | 
			
		||||
        "//tensorflow/python:variable_scope",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:batching",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:grouping",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization_options",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:scan_ops",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//tensorflow/python/eager:context",
 | 
			
		||||
        "//third_party/py/numpy",
 | 
			
		||||
        "@absl_py//absl/testing:parameterized",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
tf_py_test(
 | 
			
		||||
    name = "shuffle_and_repeat_fusion_test",
 | 
			
		||||
    srcs = ["shuffle_and_repeat_fusion_test.py"],
 | 
			
		||||
    python_version = "PY2",
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    additional_deps = [
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization_options",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
    ],
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_oss",
 | 
			
		||||
        "no_pip",
 | 
			
		||||
        "no_windows",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:errors",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization_options",
 | 
			
		||||
        "//tensorflow/python/data/kernel_tests:test_base",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -19,7 +19,7 @@ from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from absl.testing import parameterized
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import constant_op
 | 
			
		||||
@ -64,8 +64,8 @@ class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
  @parameterized.named_parameters(*_filter_fusion_test_cases())
 | 
			
		||||
  def testFilterFusion(self, map_function, predicates):
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(5).apply(
 | 
			
		||||
        optimization.assert_next(["Map", "Filter",
 | 
			
		||||
                                  "MemoryCacheImpl"])).map(map_function)
 | 
			
		||||
        testing.assert_next(["Map", "Filter",
 | 
			
		||||
                             "MemoryCacheImpl"])).map(map_function)
 | 
			
		||||
    for predicate in predicates:
 | 
			
		||||
      dataset = dataset.filter(predicate)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -17,7 +17,7 @@ from __future__ import absolute_import
 | 
			
		||||
from __future__ import division
 | 
			
		||||
from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import test_util
 | 
			
		||||
@ -30,7 +30,7 @@ class FilterWithRandomUniformFusionTest(test_base.DatasetTestBase):
 | 
			
		||||
 | 
			
		||||
  def testFilterWithRandomUniformFusion(self):
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(10000000).apply(
 | 
			
		||||
        optimization.assert_next(["Sampling"]))
 | 
			
		||||
        testing.assert_next(["Sampling"]))
 | 
			
		||||
    dataset = dataset.filter(lambda _: random_ops.random_uniform([]) < 0.05)
 | 
			
		||||
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
 | 
			
		||||
@ -19,7 +19,7 @@ from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from absl.testing import parameterized
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import constant_op
 | 
			
		||||
@ -81,7 +81,7 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
  @parameterized.named_parameters(*_hoist_random_uniform_test_cases())
 | 
			
		||||
  def testHoisting(self, function, will_optimize):
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(5).apply(
 | 
			
		||||
        optimization.assert_next(
 | 
			
		||||
        testing.assert_next(
 | 
			
		||||
            ["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function)
 | 
			
		||||
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
@ -100,7 +100,7 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
          [], minval=1, maxval=10, dtype=dtypes.float32, seed=42)
 | 
			
		||||
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(5).apply(
 | 
			
		||||
        optimization.assert_next(["Zip[0]", "Map"])).map(random_with_capture)
 | 
			
		||||
        testing.assert_next(["Zip[0]", "Map"])).map(random_with_capture)
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
    options.experimental_optimization.hoist_random_uniform = True
 | 
			
		||||
 | 
			
		||||
@ -18,7 +18,7 @@ from __future__ import division
 | 
			
		||||
from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.experimental.ops import stats_aggregator
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.platform import test
 | 
			
		||||
@ -29,7 +29,7 @@ class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase):
 | 
			
		||||
  def testLatencyStatsOptimization(self):
 | 
			
		||||
    aggregator = stats_aggregator.StatsAggregator()
 | 
			
		||||
    dataset = dataset_ops.Dataset.from_tensors(1).apply(
 | 
			
		||||
        optimization.assert_next(
 | 
			
		||||
        testing.assert_next(
 | 
			
		||||
            ["LatencyStats", "Map", "LatencyStats", "Prefetch",
 | 
			
		||||
             "LatencyStats"])).map(lambda x: x * x).prefetch(1)
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
 | 
			
		||||
@ -17,7 +17,7 @@ from __future__ import absolute_import
 | 
			
		||||
from __future__ import division
 | 
			
		||||
from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import test_util
 | 
			
		||||
@ -29,7 +29,7 @@ class MapAndBatchFusionTest(test_base.DatasetTestBase):
 | 
			
		||||
 | 
			
		||||
  def testMapAndBatchFusion(self):
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(10).apply(
 | 
			
		||||
        optimization.assert_next(
 | 
			
		||||
        testing.assert_next(
 | 
			
		||||
            ["MapAndBatch"])).map(lambda x: x * x).batch(10)
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
 | 
			
		||||
@ -19,7 +19,7 @@ from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from absl.testing import parameterized
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import constant_op
 | 
			
		||||
@ -80,8 +80,8 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
  @parameterized.named_parameters(*_map_and_filter_fusion_test_cases())
 | 
			
		||||
  def testMapFilterFusion(self, function, predicate):
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(10).apply(
 | 
			
		||||
        optimization.assert_next(["Map", "Filter",
 | 
			
		||||
                                  "Map"])).map(function).filter(predicate)
 | 
			
		||||
        testing.assert_next(["Map", "Filter",
 | 
			
		||||
                             "Map"])).map(function).filter(predicate)
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
    options.experimental_optimization.map_and_filter_fusion = True
 | 
			
		||||
@ -99,8 +99,7 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
 | 
			
		||||
    # We are currently not supporting functions with captured inputs.
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(10).apply(
 | 
			
		||||
        optimization.assert_next(["Map",
 | 
			
		||||
                                  "Filter"])).map(function).filter(predicate)
 | 
			
		||||
        testing.assert_next(["Map", "Filter"])).map(function).filter(predicate)
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
    options.experimental_optimization.map_and_filter_fusion = True
 | 
			
		||||
 | 
			
		||||
@ -19,7 +19,7 @@ from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from absl.testing import parameterized
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import test_util
 | 
			
		||||
@ -68,7 +68,7 @@ class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
  @parameterized.named_parameters(*_map_fusion_test_cases())
 | 
			
		||||
  def testMapFusion(self, functions):
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(5).apply(
 | 
			
		||||
        optimization.assert_next(["Map", "MemoryCacheImpl"]))
 | 
			
		||||
        testing.assert_next(["Map", "MemoryCacheImpl"]))
 | 
			
		||||
    for function in functions:
 | 
			
		||||
      dataset = dataset.map(function)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -19,7 +19,7 @@ from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from absl.testing import parameterized
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import constant_op
 | 
			
		||||
@ -55,7 +55,7 @@ class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
  def testMapParallelization(self, function, should_be_parallel):
 | 
			
		||||
    next_nodes = ["ParallelMap"] if should_be_parallel else ["Map"]
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(5).apply(
 | 
			
		||||
        optimization.assert_next(next_nodes)).map(function)
 | 
			
		||||
        testing.assert_next(next_nodes)).map(function)
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
    options.experimental_optimization.map_parallelization = True
 | 
			
		||||
@ -70,7 +70,7 @@ class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
    def fn(x):
 | 
			
		||||
      return x + captured_t
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(5).apply(
 | 
			
		||||
        optimization.assert_next(["ParallelMap"])).map(fn)
 | 
			
		||||
        testing.assert_next(["ParallelMap"])).map(fn)
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
    options.experimental_optimization.map_parallelization = True
 | 
			
		||||
@ -85,7 +85,7 @@ class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
    def fn(x):
 | 
			
		||||
      return x + captured_t
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(5).apply(
 | 
			
		||||
        optimization.assert_next(["Map"])).map(fn)
 | 
			
		||||
        testing.assert_next(["Map"])).map(fn)
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
    options.experimental_optimization.map_parallelization = True
 | 
			
		||||
 | 
			
		||||
@ -23,7 +23,7 @@ import numpy as np
 | 
			
		||||
from tensorflow.core.example import example_pb2
 | 
			
		||||
from tensorflow.core.example import feature_pb2
 | 
			
		||||
from tensorflow.python.data.experimental.ops import batching
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import constant_op
 | 
			
		||||
@ -353,7 +353,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
    map_node_name = "Map" if num_parallel_calls is None else "ParallelMap"
 | 
			
		||||
 | 
			
		||||
    def _make_dataset(node_names):
 | 
			
		||||
      dataset = base_dataset.apply(optimization.assert_next(node_names))
 | 
			
		||||
      dataset = base_dataset.apply(testing.assert_next(node_names))
 | 
			
		||||
      dataset = dataset.map(map_fn, num_parallel_calls)
 | 
			
		||||
      dataset = dataset.batch(100)
 | 
			
		||||
      options = dataset_ops.Options()
 | 
			
		||||
@ -416,7 +416,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
    base_dataset = base_dataset.with_options(options)
 | 
			
		||||
 | 
			
		||||
    def _make_dataset(node_names):
 | 
			
		||||
      dataset = base_dataset.apply(optimization.assert_next(node_names))
 | 
			
		||||
      dataset = base_dataset.apply(testing.assert_next(node_names))
 | 
			
		||||
      dataset = dataset.apply(batching.map_and_batch(map_fn, 100))
 | 
			
		||||
      return dataset
 | 
			
		||||
 | 
			
		||||
@ -464,7 +464,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
    apply_fn_2 = make_apply_fn(fuse_second)
 | 
			
		||||
 | 
			
		||||
    def make_dataset(node_names):
 | 
			
		||||
      dataset = base_dataset.apply(optimization.assert_next(node_names))
 | 
			
		||||
      dataset = base_dataset.apply(testing.assert_next(node_names))
 | 
			
		||||
      dataset = apply_fn_1(dataset)
 | 
			
		||||
      dataset = apply_fn_2(dataset)
 | 
			
		||||
      return dataset
 | 
			
		||||
 | 
			
		||||
@ -17,7 +17,7 @@ from __future__ import absolute_import
 | 
			
		||||
from __future__ import division
 | 
			
		||||
from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import constant_op
 | 
			
		||||
@ -37,7 +37,7 @@ class NoopEliminationTest(test_base.DatasetTestBase):
 | 
			
		||||
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(5)
 | 
			
		||||
    dataset = dataset.apply(
 | 
			
		||||
        optimization.assert_next(
 | 
			
		||||
        testing.assert_next(
 | 
			
		||||
            ["FiniteRepeat", "FiniteSkip", "Prefetch", "MemoryCacheImpl"]))
 | 
			
		||||
    dataset = dataset.repeat(some_tensor).skip(5).take(-1).skip(0).repeat(
 | 
			
		||||
        1).prefetch(0).prefetch(1).cache()
 | 
			
		||||
 | 
			
		||||
@ -18,7 +18,7 @@ from __future__ import division
 | 
			
		||||
from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from tensorflow.python import tf2
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.eager import context
 | 
			
		||||
@ -37,7 +37,7 @@ class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase):
 | 
			
		||||
      expected = "ShuffleAndRepeat"
 | 
			
		||||
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(10).apply(
 | 
			
		||||
        optimization.assert_next([expected])).shuffle(10).repeat(2)
 | 
			
		||||
        testing.assert_next([expected])).shuffle(10).repeat(2)
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
    options.experimental_optimization.shuffle_and_repeat_fusion = True
 | 
			
		||||
 | 
			
		||||
@ -24,8 +24,8 @@ import numpy as np
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import batching
 | 
			
		||||
from tensorflow.python.data.experimental.ops import grouping
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import scan_ops
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.experimental.ops import threadpool
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
@ -146,7 +146,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
 | 
			
		||||
    def flat_map_fn(_):
 | 
			
		||||
      dataset = dataset_ops.Dataset.from_tensors(0)
 | 
			
		||||
      dataset = dataset.apply(optimization.assert_next(["MemoryCacheImpl"]))
 | 
			
		||||
      dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
 | 
			
		||||
      dataset = dataset.skip(0)  # Should be removed by noop elimination
 | 
			
		||||
      dataset = dataset.cache()
 | 
			
		||||
      return dataset
 | 
			
		||||
@ -163,7 +163,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
 | 
			
		||||
    def flat_map_fn(_):
 | 
			
		||||
      dataset = dataset_ops.Dataset.from_tensors(0)
 | 
			
		||||
      dataset = dataset.apply(optimization.assert_next(["MapAndBatch"]))
 | 
			
		||||
      dataset = dataset.apply(testing.assert_next(["MapAndBatch"]))
 | 
			
		||||
      # Should be fused by map and batch fusion
 | 
			
		||||
      dataset = dataset.map(lambda x: x)
 | 
			
		||||
      dataset = dataset.batch(1)
 | 
			
		||||
@ -194,30 +194,6 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
        expected_output=[list(range(10))],
 | 
			
		||||
        requires_initialization=True)
 | 
			
		||||
 | 
			
		||||
  def testOptimizationNonSerializable(self):
 | 
			
		||||
    dataset = dataset_ops.Dataset.from_tensors(0)
 | 
			
		||||
    dataset = dataset.apply(optimization.assert_next(["FiniteSkip"]))
 | 
			
		||||
    dataset = dataset.skip(0)  # Should not be removed by noop elimination
 | 
			
		||||
    dataset = dataset.apply(optimization.non_serializable())
 | 
			
		||||
    dataset = dataset.apply(optimization.assert_next(["MemoryCacheImpl"]))
 | 
			
		||||
    dataset = dataset.skip(0)  # Should be removed by noop elimination
 | 
			
		||||
    dataset = dataset.cache()
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
    options.experimental_optimization.noop_elimination = True
 | 
			
		||||
    dataset = dataset.with_options(options)
 | 
			
		||||
    self.assertDatasetProduces(dataset, expected_output=[0])
 | 
			
		||||
 | 
			
		||||
  def testOptimizationNonSerializableAsDirectInput(self):
 | 
			
		||||
    """Tests that non-serializable dataset can be OptimizeDataset's input."""
 | 
			
		||||
    dataset = dataset_ops.Dataset.from_tensors(0)
 | 
			
		||||
    dataset = dataset.apply(optimization.non_serializable())
 | 
			
		||||
    options = dataset_ops.Options()
 | 
			
		||||
    options.experimental_optimization.apply_default_optimizations = False
 | 
			
		||||
    options.experimental_optimization.noop_elimination = True
 | 
			
		||||
    dataset = dataset.with_options(options)
 | 
			
		||||
    self.assertDatasetProduces(dataset, expected_output=[0])
 | 
			
		||||
 | 
			
		||||
  @parameterized.named_parameters(_generate_captured_refvar_test_cases())
 | 
			
		||||
  @test_util.run_v1_only("RefVariables are not supported in eager mode.")
 | 
			
		||||
  def testSkipEagerOptimizationWithCapturedRefVar(self, dataset_fn):
 | 
			
		||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -21,7 +21,7 @@ import time
 | 
			
		||||
 | 
			
		||||
from absl.testing import parameterized
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import sleep
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import combinations
 | 
			
		||||
@ -36,7 +36,7 @@ class SleepTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
    self.skipTest("b/123597912")
 | 
			
		||||
    sleep_microseconds = 100
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(10).apply(
 | 
			
		||||
        sleep.sleep(sleep_microseconds))
 | 
			
		||||
        testing.sleep(sleep_microseconds))
 | 
			
		||||
    next_element = self.getNext(dataset)
 | 
			
		||||
    start_time = time.time()
 | 
			
		||||
    for i in range(10):
 | 
			
		||||
@ -50,7 +50,7 @@ class SleepTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
  def testSleepCancellation(self):
 | 
			
		||||
    sleep_microseconds = int(1e6) * 1000
 | 
			
		||||
    ds = dataset_ops.Dataset.range(1)
 | 
			
		||||
    ds = ds.apply(sleep.sleep(sleep_microseconds))
 | 
			
		||||
    ds = ds.apply(testing.sleep(sleep_microseconds))
 | 
			
		||||
    ds = ds.prefetch(1)
 | 
			
		||||
    get_next = self.getNext(ds, requires_initialization=True)
 | 
			
		||||
 | 
			
		||||
@ -67,7 +67,7 @@ class SleepTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
 | 
			
		||||
    sleep_microseconds = int(1e6) * 1000
 | 
			
		||||
    ds_sleep = dataset_ops.Dataset.range(1)
 | 
			
		||||
    ds_sleep = ds.apply(sleep.sleep(sleep_microseconds))
 | 
			
		||||
    ds_sleep = ds.apply(testing.sleep(sleep_microseconds))
 | 
			
		||||
 | 
			
		||||
    ds = ds.concatenate(ds_sleep)
 | 
			
		||||
    ds = ds.prefetch(1)
 | 
			
		||||
 | 
			
		||||
@ -332,16 +332,6 @@ py_library(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "sleep",
 | 
			
		||||
    srcs = ["sleep.py"],
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/python:experimental_dataset_ops_gen",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "snapshot",
 | 
			
		||||
    srcs = [
 | 
			
		||||
@ -405,6 +395,18 @@ py_library(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "testing",
 | 
			
		||||
    testonly = 1,
 | 
			
		||||
    srcs = ["testing.py"],
 | 
			
		||||
    srcs_version = "PY2AND3",
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/python:experimental_dataset_ops_gen",
 | 
			
		||||
        "//tensorflow/python:framework_ops",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "threading_options",
 | 
			
		||||
    srcs = ["threading_options.py"],
 | 
			
		||||
@ -475,7 +477,6 @@ py_library(
 | 
			
		||||
        ":resampling",
 | 
			
		||||
        ":scan_ops",
 | 
			
		||||
        ":shuffle_ops",
 | 
			
		||||
        ":sleep",
 | 
			
		||||
        ":snapshot",
 | 
			
		||||
        ":stats_ops",
 | 
			
		||||
        ":take_while_ops",
 | 
			
		||||
 | 
			
		||||
@ -18,32 +18,9 @@ from __future__ import division
 | 
			
		||||
from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import dtypes
 | 
			
		||||
from tensorflow.python.framework import ops
 | 
			
		||||
from tensorflow.python.ops import gen_experimental_dataset_ops
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to
 | 
			
		||||
# account for indexing) and transformation sequence.
 | 
			
		||||
def assert_next(transformations):
 | 
			
		||||
  """A transformation that asserts which transformations happen next.
 | 
			
		||||
 | 
			
		||||
  Args:
 | 
			
		||||
    transformations: A `tf.string` vector `tf.Tensor` identifying the
 | 
			
		||||
      transformations that are expected to happen next.
 | 
			
		||||
 | 
			
		||||
  Returns:
 | 
			
		||||
    A `Dataset` transformation function, which can be passed to
 | 
			
		||||
    `tf.data.Dataset.apply`.
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
  def _apply_fn(dataset):
 | 
			
		||||
    """Function from `Dataset` to `Dataset` that applies the transformation."""
 | 
			
		||||
    return _AssertNextDataset(dataset, transformations)
 | 
			
		||||
 | 
			
		||||
  return _apply_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def model():
 | 
			
		||||
  """A transformation that models performance.
 | 
			
		||||
 | 
			
		||||
@ -59,21 +36,6 @@ def model():
 | 
			
		||||
  return _apply_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def non_serializable():
 | 
			
		||||
  """A non-serializable identity transformation.
 | 
			
		||||
 | 
			
		||||
  Returns:
 | 
			
		||||
    A `Dataset` transformation function, which can be passed to
 | 
			
		||||
    `tf.data.Dataset.apply`.
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
  def _apply_fn(dataset):
 | 
			
		||||
    """Function from `Dataset` to `Dataset` that applies the transformation."""
 | 
			
		||||
    return _NonSerializableDataset(dataset)
 | 
			
		||||
 | 
			
		||||
  return _apply_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def optimize(optimizations=None):
 | 
			
		||||
  """A transformation that applies optimizations.
 | 
			
		||||
 | 
			
		||||
@ -94,37 +56,6 @@ def optimize(optimizations=None):
 | 
			
		||||
  return _apply_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _AssertNextDataset(dataset_ops.UnaryUnchangedStructureDataset):
 | 
			
		||||
  """A `Dataset` that asserts which transformations happen next."""
 | 
			
		||||
 | 
			
		||||
  def __init__(self, input_dataset, transformations):
 | 
			
		||||
    """See `assert_next()` for details."""
 | 
			
		||||
    self._input_dataset = input_dataset
 | 
			
		||||
    if transformations is None:
 | 
			
		||||
      raise ValueError("At least one transformation should be specified")
 | 
			
		||||
    self._transformations = ops.convert_to_tensor(
 | 
			
		||||
        transformations, dtype=dtypes.string, name="transformations")
 | 
			
		||||
    variant_tensor = (
 | 
			
		||||
        gen_experimental_dataset_ops.assert_next_dataset(
 | 
			
		||||
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
 | 
			
		||||
            self._transformations,
 | 
			
		||||
            **self._flat_structure))
 | 
			
		||||
    super(_AssertNextDataset, self).__init__(input_dataset, variant_tensor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset):
 | 
			
		||||
  """A `Dataset` that performs non-serializable identity transformation."""
 | 
			
		||||
 | 
			
		||||
  def __init__(self, input_dataset):
 | 
			
		||||
    """See `non_serializable()` for details."""
 | 
			
		||||
    self._input_dataset = input_dataset
 | 
			
		||||
    variant_tensor = (
 | 
			
		||||
        gen_experimental_dataset_ops.non_serializable_dataset(
 | 
			
		||||
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
 | 
			
		||||
            **self._flat_structure))
 | 
			
		||||
    super(_NonSerializableDataset, self).__init__(input_dataset, variant_tensor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _ChooseFastestDataset(dataset_ops.DatasetV2):
 | 
			
		||||
  """A `Dataset` that merges two input datasets."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										123
									
								
								tensorflow/python/data/experimental/ops/testing.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										123
									
								
								tensorflow/python/data/experimental/ops/testing.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,123 @@
 | 
			
		||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
# ==============================================================================
 | 
			
		||||
"""Experimental API for testing of tf.data."""
 | 
			
		||||
from __future__ import absolute_import
 | 
			
		||||
from __future__ import division
 | 
			
		||||
from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import dtypes
 | 
			
		||||
from tensorflow.python.framework import ops
 | 
			
		||||
from tensorflow.python.ops import gen_experimental_dataset_ops
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to
 | 
			
		||||
# account for indexing) and transformation sequence.
 | 
			
		||||
def assert_next(transformations):
 | 
			
		||||
  """A transformation that asserts which transformations happen next.
 | 
			
		||||
 | 
			
		||||
  Args:
 | 
			
		||||
    transformations: A `tf.string` vector `tf.Tensor` identifying the
 | 
			
		||||
      transformations that are expected to happen next.
 | 
			
		||||
 | 
			
		||||
  Returns:
 | 
			
		||||
    A `Dataset` transformation function, which can be passed to
 | 
			
		||||
    `tf.data.Dataset.apply`.
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
  def _apply_fn(dataset):
 | 
			
		||||
    """Function from `Dataset` to `Dataset` that applies the transformation."""
 | 
			
		||||
    return _AssertNextDataset(dataset, transformations)
 | 
			
		||||
 | 
			
		||||
  return _apply_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def non_serializable():
 | 
			
		||||
  """A non-serializable identity transformation.
 | 
			
		||||
 | 
			
		||||
  Returns:
 | 
			
		||||
    A `Dataset` transformation function, which can be passed to
 | 
			
		||||
    `tf.data.Dataset.apply`.
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
  def _apply_fn(dataset):
 | 
			
		||||
    """Function from `Dataset` to `Dataset` that applies the transformation."""
 | 
			
		||||
    return _NonSerializableDataset(dataset)
 | 
			
		||||
 | 
			
		||||
  return _apply_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def sleep(sleep_microseconds):
 | 
			
		||||
  """Sleeps for `sleep_microseconds` before producing each input element.
 | 
			
		||||
 | 
			
		||||
  Args:
 | 
			
		||||
    sleep_microseconds: The number of microseconds to sleep before producing an
 | 
			
		||||
      input element.
 | 
			
		||||
 | 
			
		||||
  Returns:
 | 
			
		||||
    A `Dataset` transformation function, which can be passed to
 | 
			
		||||
    `tf.data.Dataset.apply`.
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
  def _apply_fn(dataset):
 | 
			
		||||
    return _SleepDataset(dataset, sleep_microseconds)
 | 
			
		||||
 | 
			
		||||
  return _apply_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _AssertNextDataset(dataset_ops.UnaryUnchangedStructureDataset):
 | 
			
		||||
  """A `Dataset` that asserts which transformations happen next."""
 | 
			
		||||
 | 
			
		||||
  def __init__(self, input_dataset, transformations):
 | 
			
		||||
    """See `assert_next()` for details."""
 | 
			
		||||
    self._input_dataset = input_dataset
 | 
			
		||||
    if transformations is None:
 | 
			
		||||
      raise ValueError("At least one transformation should be specified")
 | 
			
		||||
    self._transformations = ops.convert_to_tensor(
 | 
			
		||||
        transformations, dtype=dtypes.string, name="transformations")
 | 
			
		||||
    variant_tensor = (
 | 
			
		||||
        gen_experimental_dataset_ops.experimental_assert_next_dataset(
 | 
			
		||||
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
 | 
			
		||||
            self._transformations,
 | 
			
		||||
            **self._flat_structure))
 | 
			
		||||
    super(_AssertNextDataset, self).__init__(input_dataset, variant_tensor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset):
 | 
			
		||||
  """A `Dataset` that performs non-serializable identity transformation."""
 | 
			
		||||
 | 
			
		||||
  def __init__(self, input_dataset):
 | 
			
		||||
    """See `non_serializable()` for details."""
 | 
			
		||||
    self._input_dataset = input_dataset
 | 
			
		||||
    variant_tensor = (
 | 
			
		||||
        gen_experimental_dataset_ops.experimental_non_serializable_dataset(
 | 
			
		||||
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
 | 
			
		||||
            **self._flat_structure))
 | 
			
		||||
    super(_NonSerializableDataset, self).__init__(input_dataset, variant_tensor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _SleepDataset(dataset_ops.UnaryUnchangedStructureDataset):
 | 
			
		||||
  """A `Dataset` that sleeps before producing each upstream element."""
 | 
			
		||||
 | 
			
		||||
  def __init__(self, input_dataset, sleep_microseconds):
 | 
			
		||||
    self._input_dataset = input_dataset
 | 
			
		||||
    self._sleep_microseconds = sleep_microseconds
 | 
			
		||||
    variant_tensor = gen_experimental_dataset_ops.sleep_dataset(
 | 
			
		||||
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
 | 
			
		||||
        self._sleep_microseconds,
 | 
			
		||||
        **self._flat_structure)
 | 
			
		||||
    super(_SleepDataset, self).__init__(input_dataset, variant_tensor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -413,7 +413,7 @@ cuda_py_test(
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
        "//tensorflow/python/data/ops:multi_device_iterator_ops",
 | 
			
		||||
        "//tensorflow/python/data/ops:iterator_ops",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization_options",
 | 
			
		||||
        "//tensorflow/python:array_ops",
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
@ -555,7 +555,7 @@ cuda_py_test(
 | 
			
		||||
        "//tensorflow/python:client_testlib",
 | 
			
		||||
        "//tensorflow/python:dtypes",
 | 
			
		||||
        "//tensorflow/python:math_ops",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:optimization",
 | 
			
		||||
        "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
        "//tensorflow/python:sparse_tensor",
 | 
			
		||||
        "//tensorflow/python/data/ops:dataset_ops",
 | 
			
		||||
    ],
 | 
			
		||||
 | 
			
		||||
@ -23,7 +23,7 @@ import numpy as np
 | 
			
		||||
 | 
			
		||||
from tensorflow.core.protobuf import config_pb2
 | 
			
		||||
from tensorflow.python.client import session
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.data.ops import multi_device_iterator_ops
 | 
			
		||||
@ -323,7 +323,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
 | 
			
		||||
  @combinations.generate(skip_v2_test_combinations())
 | 
			
		||||
  def testOptimization(self):
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(10)
 | 
			
		||||
    dataset = dataset.apply(optimization.assert_next(["MemoryCacheImpl"]))
 | 
			
		||||
    dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
 | 
			
		||||
    dataset = dataset.skip(0)  # this should be optimized away
 | 
			
		||||
    dataset = dataset.cache()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -22,7 +22,7 @@ import time
 | 
			
		||||
from absl.testing import parameterized
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import optimization
 | 
			
		||||
from tensorflow.python.data.experimental.ops import testing
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.eager import function
 | 
			
		||||
@ -239,7 +239,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
  @combinations.generate(test_base.default_test_combinations())
 | 
			
		||||
  def testOptions(self):
 | 
			
		||||
    dataset = dataset_ops.Dataset.range(5)
 | 
			
		||||
    dataset = dataset.apply(optimization.assert_next(["MapAndBatch"]))
 | 
			
		||||
    dataset = dataset.apply(testing.assert_next(["MapAndBatch"]))
 | 
			
		||||
    dataset = dataset.map(lambda x: x).batch(5)
 | 
			
		||||
    self.evaluate(dataset.reduce(0, lambda state, value: state))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -82,6 +82,7 @@ DEPENDENCY_BLACKLIST = [
 | 
			
		||||
    "//tensorflow/core:image_testdata",
 | 
			
		||||
    "//tensorflow/core:lmdb_testdata",
 | 
			
		||||
    "//tensorflow/core/kernels/cloud:bigquery_reader_ops",
 | 
			
		||||
    "//tensorflow/python/data/experimental/ops:testing",
 | 
			
		||||
    "//tensorflow/python/debug:grpc_tensorflow_server.par",
 | 
			
		||||
    "//tensorflow/python/feature_column:vocabulary_testdata",
 | 
			
		||||
    "//tensorflow/python:framework/test_file_system.so",
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user