[tf.data] Cleanup of tf.data op definitions.
cl/223710135 renamed some datasets ops (adding "Experimental" prefix), which introduced a backwards compatibility issue. In particular, if a graph was saved using TensorFlow before cl/223710135, restoring the graph using TensorFlow built after cl/223710135 would fail if the graph contained any of the renamed ops. To address this issue, this CL reintroduces op definitions removed by cl/223710135 and uses `compat.forward_compatible` to transition away from ops that use the (misleading) "Experimental" prefix to the corresponding ones without the prefix. In addition, this CL removes tf.data op definitions for ops that have no kernels. PiperOrigin-RevId: 256436076
This commit is contained in:
parent
4cca406165
commit
de9c460b06
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops import readers
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import readers as core_readers
|
||||
@ -390,8 +391,12 @@ class LMDBDataset(dataset_ops.DatasetSource):
|
||||
"""
|
||||
self._filenames = ops.convert_to_tensor(
|
||||
filenames, dtype=dtypes.string, name="filenames")
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_lmdb_dataset(
|
||||
self._filenames, **self._flat_structure)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = gen_experimental_dataset_ops.lmdb_dataset(
|
||||
self._filenames, **self._flat_structure)
|
||||
else:
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_lmdb_dataset(
|
||||
self._filenames, **self._flat_structure)
|
||||
super(LMDBDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -41,12 +42,20 @@ class _SlideDataset(dataset_ops.UnaryDataset):
|
||||
input_structure = dataset_ops.get_structure(input_dataset)
|
||||
self._element_spec = nest.map_structure(
|
||||
lambda component_spec: component_spec._batch(None), input_structure) # pylint: disable=protected-access
|
||||
variant_tensor = ged_ops.experimental_sliding_window_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
window_size=self._window_size,
|
||||
window_shift=self._window_shift,
|
||||
window_stride=self._window_stride,
|
||||
**self._flat_structure)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = ged_ops.sliding_window_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
window_size=self._window_size,
|
||||
window_shift=self._window_shift,
|
||||
window_stride=self._window_stride,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = ged_ops.experimental_sliding_window_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
window_size=self._window_size,
|
||||
window_shift=self._window_shift,
|
||||
window_stride=self._window_stride,
|
||||
**self._flat_structure)
|
||||
super(_SlideDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "AssertNextDataset"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,32 @@
|
||||
op {
|
||||
graph_op_name: "AutoShardDataset"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "input_dataset"
|
||||
description: <<END
|
||||
A variant tensor representing the input dataset.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "num_workers"
|
||||
description: <<END
|
||||
A scalar representing the number of workers to distribute this dataset across.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "index"
|
||||
description: <<END
|
||||
A scalar representing the index of the current worker out of num_workers.
|
||||
END
|
||||
}
|
||||
summary: "Creates a dataset that shards the input dataset."
|
||||
description: <<END
|
||||
Creates a dataset that shards the input dataset by num_workers, returning a
|
||||
sharded dataset for the index-th worker. This attempts to automatically shard
|
||||
a dataset by examining the Dataset graph and inserting a shard op before the
|
||||
inputs to a reader Dataset (e.g. CSVDataset, TFRecordDataset).
|
||||
|
||||
This dataset will throw a NotFound error if we cannot shard the dataset
|
||||
automatically.
|
||||
END
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
op {
|
||||
graph_op_name: "BytesProducedStatsDataset"
|
||||
visibility: HIDDEN
|
||||
summary: "Records the bytes size of each element of `input_dataset` in a StatsAggregator."
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "CSVDataset"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "ChooseFastestDataset"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,21 @@
|
||||
op {
|
||||
graph_op_name: "DatasetCardinality"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "input_dataset"
|
||||
description: <<END
|
||||
A variant tensor representing the dataset to return cardinality for.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "cardinality"
|
||||
description: <<END
|
||||
The cardinality of `input_dataset`. Named constants are used to represent
|
||||
infinite and unknown cardinality.
|
||||
END
|
||||
}
|
||||
summary: "Returns the cardinality of `input_dataset`."
|
||||
description: <<END
|
||||
Returns the cardinality of `input_dataset`.
|
||||
END
|
||||
}
|
@ -0,0 +1,24 @@
|
||||
op {
|
||||
graph_op_name: "DatasetToTFRecord"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "input_dataset"
|
||||
description: <<END
|
||||
A variant tensor representing the dataset to write.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "filename"
|
||||
description: <<END
|
||||
A scalar string tensor representing the filename to use.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "compression_type"
|
||||
description: <<END
|
||||
A scalar string tensor containing either (i) the empty string (no
|
||||
compression), (ii) "ZLIB", or (iii) "GZIP".
|
||||
END
|
||||
}
|
||||
summary: "Writes the given dataset to the given file using the TFRecord format."
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
op {
|
||||
graph_op_name: "DenseToSparseBatchDataset"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "input_dataset"
|
||||
description: <<END
|
||||
A handle to an input dataset. Must have a single component.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "batch_size"
|
||||
description: <<END
|
||||
A scalar representing the number of elements to accumulate in a
|
||||
batch.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "row_shape"
|
||||
description: <<END
|
||||
A vector representing the dense shape of each row in the produced
|
||||
SparseTensor. The shape may be partially specified, using `-1` to indicate
|
||||
that a particular dimension should use the maximum size of all batch elements.
|
||||
END
|
||||
}
|
||||
summary: "Creates a dataset that batches input elements into a SparseTensor."
|
||||
}
|
@ -0,0 +1,21 @@
|
||||
op {
|
||||
graph_op_name: "DirectedInterleaveDataset"
|
||||
in_arg {
|
||||
name: "selector_input_dataset"
|
||||
description: <<END
|
||||
A dataset of scalar `DT_INT64` elements that determines which of the
|
||||
`N` data inputs should produce the next output element.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "data_input_datasets"
|
||||
description: <<END
|
||||
`N` datasets with the same type that will be interleaved according to
|
||||
the values of `selector_input_dataset`.
|
||||
END
|
||||
}
|
||||
summary: <<END
|
||||
A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
|
||||
END
|
||||
visibility: HIDDEN
|
||||
}
|
@ -1,4 +0,0 @@
|
||||
op {
|
||||
graph_op_name: "ExperimentalIdentityIndexedDataset"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -1,4 +0,0 @@
|
||||
op {
|
||||
graph_op_name: "ExperimentalIndexedDatasetGet"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -1,4 +0,0 @@
|
||||
op {
|
||||
graph_op_name: "ExperimentalIndexedDatasetMaterialize"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -1,4 +0,0 @@
|
||||
op {
|
||||
graph_op_name: "ExperimentalMaterializedIndexDatasetHandle"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,69 @@
|
||||
op {
|
||||
graph_op_name: "GroupByReducerDataset"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "input_dataset"
|
||||
description: <<END
|
||||
A variant tensor representing the input dataset.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "key_func_other_arguments"
|
||||
description: <<END
|
||||
A list of tensors, typically values that were captured when
|
||||
building a closure for `key_func`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "key_func"
|
||||
description: <<END
|
||||
A function mapping an element of `input_dataset`, concatenated
|
||||
with `key_func_other_arguments` to a scalar value of type DT_INT64.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "init_func_other_arguments"
|
||||
description: <<END
|
||||
A list of tensors, typically values that were captured when
|
||||
building a closure for `init_func`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "init_func"
|
||||
description: <<END
|
||||
A function mapping a key of type DT_INT64, concatenated with
|
||||
`init_func_other_arguments` to the initial reducer state.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "reduce_func_other_arguments"
|
||||
description: <<END
|
||||
A list of tensors, typically values that were captured when
|
||||
building a closure for `reduce_func`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "reduce_func"
|
||||
description: <<END
|
||||
A function mapping the current reducer state and an element of `input_dataset`,
|
||||
concatenated with `reduce_func_other_arguments` to a new reducer state.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "finalize_func_other_arguments"
|
||||
description: <<END
|
||||
A list of tensors, typically values that were captured when
|
||||
building a closure for `finalize_func`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "finalize_func"
|
||||
description: <<END
|
||||
A function mapping the final reducer state to an output element.
|
||||
END
|
||||
}
|
||||
summary: "Creates a dataset that computes a group-by on `input_dataset`."
|
||||
description: <<END
|
||||
Creates a dataset that computes a group-by on `input_dataset`.
|
||||
END
|
||||
}
|
@ -0,0 +1,15 @@
|
||||
op {
|
||||
graph_op_name: "GroupByWindowDataset"
|
||||
visibility: HIDDEN
|
||||
attr {
|
||||
name: "key_func"
|
||||
description: <<END
|
||||
A function mapping an element of `input_dataset`, concatenated
|
||||
with `key_func_other_arguments` to a scalar value of type DT_INT64.
|
||||
END
|
||||
}
|
||||
summary: "Creates a dataset that computes a windowed group-by on `input_dataset`."
|
||||
description: <<END
|
||||
// TODO(mrry): Support non-int64 keys.
|
||||
END
|
||||
}
|
@ -0,0 +1,8 @@
|
||||
op {
|
||||
graph_op_name: "IgnoreErrorsDataset"
|
||||
summary: <<END
|
||||
Creates a dataset that contains the elements of `input_dataset` ignoring errors.
|
||||
END
|
||||
visibility: HIDDEN
|
||||
}
|
||||
|
@ -0,0 +1,8 @@
|
||||
op {
|
||||
graph_op_name: "IteratorGetDevice"
|
||||
summary: <<END
|
||||
Returns the name of the device on which `resource` has been placed.
|
||||
END
|
||||
visibility: HIDDEN
|
||||
}
|
||||
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "LMDBDataset"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
op {
|
||||
graph_op_name: "LatencyStatsDataset"
|
||||
visibility: HIDDEN
|
||||
summary: "Records the latency of producing `input_dataset` elements in a StatsAggregator."
|
||||
}
|
@ -1,5 +1,5 @@
|
||||
op {
|
||||
graph_op_name: "ExperimentalNumaMapAndBatchDataset"
|
||||
graph_op_name: "MapAndBatchDataset"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "input_dataset"
|
||||
@ -50,9 +50,5 @@ batches `batch_size` of them.
|
||||
|
||||
Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
|
||||
to `batch_size * num_parallel_batches` copies of `f` in parallel.
|
||||
|
||||
Unlike "MapAndBatchDatasetV2", this dataset uses a NUMA-aware thread scheduling
|
||||
policy. Because it uses the single-threaded executor, it only supports the
|
||||
function-based control flow ops.
|
||||
END
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "MatchingFilesDataset"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,13 @@
|
||||
op {
|
||||
graph_op_name: "MaxIntraOpParallelismDataset"
|
||||
in_arg {
|
||||
name: "max_intra_op_parallelism"
|
||||
description: <<END
|
||||
Identifies the maximum intra-op parallelism to use.
|
||||
END
|
||||
}
|
||||
summary: <<END
|
||||
Creates a dataset that overrides the maximum intra-op parallelism.
|
||||
END
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "NonSerializableDataset"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,22 @@
|
||||
op {
|
||||
graph_op_name: "ParallelInterleaveDataset"
|
||||
visibility: HIDDEN
|
||||
attr {
|
||||
name: "f"
|
||||
description: <<END
|
||||
A function mapping elements of `input_dataset`, concatenated with
|
||||
`other_arguments`, to a Dataset variant that contains elements matching
|
||||
`output_types` and `output_shapes`.
|
||||
END
|
||||
}
|
||||
summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
|
||||
description: <<END
|
||||
The resulting dataset is similar to the `InterleaveDataset`, with the exception
|
||||
that if retrieving the next value from a dataset would cause the requester to
|
||||
block, it will skip that input dataset. This dataset is especially useful
|
||||
when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it
|
||||
allows the training step to proceed so long as some data is available.
|
||||
|
||||
!! WARNING !! This dataset is not deterministic!
|
||||
END
|
||||
}
|
@ -0,0 +1,70 @@
|
||||
op {
|
||||
graph_op_name: "ParseExampleDataset"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "dense_defaults"
|
||||
description: <<END
|
||||
A dict mapping string keys to `Tensor`s.
|
||||
The keys of the dict must match the dense_keys of the feature.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "sparse_keys"
|
||||
description: <<END
|
||||
A list of string keys in the examples features.
|
||||
The results for these keys will be returned as `SparseTensor` objects.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dense_keys"
|
||||
description: <<END
|
||||
A list of Ndense string Tensors (scalars).
|
||||
The keys expected in the Examples features associated with dense values.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "sparse_types"
|
||||
description: <<END
|
||||
A list of `DTypes` of the same length as `sparse_keys`.
|
||||
Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
|
||||
and `tf.string` (`BytesList`) are supported.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "Tdense"
|
||||
description: <<END
|
||||
A list of DTypes of the same length as `dense_keys`.
|
||||
Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
|
||||
and `tf.string` (`BytesList`) are supported.
|
||||
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dense_shapes"
|
||||
description: <<END
|
||||
List of tuples with the same length as `dense_keys`.
|
||||
The shape of the data for each dense feature referenced by `dense_keys`.
|
||||
Required for any input tensors identified by `dense_keys`. Must be
|
||||
either fully defined, or may contain an unknown first dimension.
|
||||
An unknown first dimension means the feature is treated as having
|
||||
a variable number of blocks, and the output shape along this dimension
|
||||
is considered unknown at graph build time. Padding is applied for
|
||||
minibatch elements smaller than the maximum number of blocks for the
|
||||
given feature along this dimension.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
description: <<END
|
||||
The type list for the return values.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
description: <<END
|
||||
The list of shapes being produced.
|
||||
END
|
||||
}
|
||||
summary: "Transforms `input_dataset` containing `Example` protos as vectors of DT_STRING into a dataset of `Tensor` or `SparseTensor` objects representing the parsed features."
|
||||
}
|
||||
|
@ -0,0 +1,13 @@
|
||||
op {
|
||||
graph_op_name: "PrivateThreadPoolDataset"
|
||||
in_arg {
|
||||
name: "num_threads"
|
||||
description: <<END
|
||||
Identifies the number of threads to use for the private threadpool.
|
||||
END
|
||||
}
|
||||
summary: <<END
|
||||
Creates a dataset that uses a custom thread pool to compute `input_dataset`.
|
||||
END
|
||||
visibility: HIDDEN
|
||||
}
|
19
tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt
Normal file
19
tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt
Normal file
@ -0,0 +1,19 @@
|
||||
op {
|
||||
graph_op_name: "RandomDataset"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "seed"
|
||||
description: <<END
|
||||
A scalar seed for the random number generator. If either seed or
|
||||
seed2 is set to be non-zero, the random number generator is seeded
|
||||
by the given seed. Otherwise, a random seed is used.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "seed2"
|
||||
description: <<END
|
||||
A second scalar seed to avoid seed collision.
|
||||
END
|
||||
}
|
||||
summary: "Creates a Dataset that returns pseudorandom numbers."
|
||||
}
|
@ -0,0 +1,23 @@
|
||||
op {
|
||||
graph_op_name: "RebatchDataset"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "input_dataset"
|
||||
description: <<END
|
||||
A variant tensor representing the input dataset.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "num_workers"
|
||||
description: <<END
|
||||
A scalar representing the number of workers to distribute this batch across. As
|
||||
a result of this transformation the current batch size would end up being
|
||||
divided by this parameter.
|
||||
END
|
||||
}
|
||||
summary: "Creates a dataset that changes the batch size."
|
||||
description: <<END
|
||||
Creates a dataset that changes the batch size of the dataset to current batch
|
||||
size // num_workers.
|
||||
END
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
op {
|
||||
graph_op_name: "ScanDataset"
|
||||
visibility: HIDDEN
|
||||
summary: "Creates a dataset successively reduces `f` over the elements of `input_dataset`."
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "SetStatsAggregatorDataset"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "SleepDataset"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
op {
|
||||
graph_op_name: "SlidingWindowDataset"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "window_size"
|
||||
description: <<END
|
||||
A scalar representing the number of elements in the
|
||||
sliding window.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "window_shift"
|
||||
description: <<END
|
||||
A scalar representing the steps moving the sliding window
|
||||
forward in one iteration. It must be positive.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "window_stride"
|
||||
description: <<END
|
||||
A scalar representing the stride of the input elements of the sliding window.
|
||||
It must be positive.
|
||||
END
|
||||
}
|
||||
summary: "Creates a dataset that passes a sliding window over `input_dataset`."
|
||||
}
|
23
tensorflow/core/api_def/base_api/api_def_SqlDataset.pbtxt
Normal file
23
tensorflow/core/api_def/base_api/api_def_SqlDataset.pbtxt
Normal file
@ -0,0 +1,23 @@
|
||||
op {
|
||||
graph_op_name: "SqlDataset"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "driver_name"
|
||||
description: <<END
|
||||
The database type. Currently, the only supported type is 'sqlite'.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "data_source_name"
|
||||
description: <<END
|
||||
A connection string to connect to the database.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "query"
|
||||
description: <<END
|
||||
A SQL query to execute.
|
||||
END
|
||||
}
|
||||
summary: "Creates a dataset that executes a SQL query and emits rows of the result set."
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
op {
|
||||
graph_op_name: "StatsAggregatorHandle"
|
||||
visibility: HIDDEN
|
||||
summary: "Creates a statistics manager resource."
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
op {
|
||||
graph_op_name: "StatsAggregatorSummary"
|
||||
visibility: HIDDEN
|
||||
summary: "Produces a summary of any statistics recorded by the given statistics manager."
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
op {
|
||||
graph_op_name: "TakeWhileDataset"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "other_arguments"
|
||||
description: <<END
|
||||
A list of tensors, typically values that were captured when
|
||||
building a closure for `predicate`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "predicate"
|
||||
description: <<END
|
||||
A function returning a scalar boolean.
|
||||
END
|
||||
}
|
||||
summary: "Creates a dataset that stops iteration when predicate` is false."
|
||||
description: <<END
|
||||
The `predicate` function must return a scalar boolean and accept the
|
||||
following arguments:
|
||||
|
||||
* One tensor for each component of an element of `input_dataset`.
|
||||
* One tensor for each value in `other_arguments`.
|
||||
END
|
||||
}
|
@ -0,0 +1,13 @@
|
||||
op {
|
||||
graph_op_name: "ThreadPoolDataset"
|
||||
in_arg {
|
||||
name: "thread_pool"
|
||||
description: <<END
|
||||
A resource produced by the ThreadPoolHandle op.
|
||||
END
|
||||
}
|
||||
summary: <<END
|
||||
Creates a dataset that uses a custom thread pool to compute `input_dataset`.
|
||||
END
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,35 @@
|
||||
op {
|
||||
graph_op_name: "ThreadPoolHandle"
|
||||
out_arg {
|
||||
name: "handle"
|
||||
description: <<END
|
||||
A resource that can be consumed by one or more ExperimentalThreadPoolDataset
|
||||
ops.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "num_threads"
|
||||
description: <<END
|
||||
The number of threads in the thread pool.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "max_intra_op_parallelism"
|
||||
description: <<END
|
||||
The maximum degree of parallelism to use within operations that execute on this
|
||||
threadpool.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "display_name"
|
||||
description: <<END
|
||||
A human-readable name for the threads that may be visible in some
|
||||
visualizations.
|
||||
threadpool.
|
||||
END
|
||||
}
|
||||
summary: <<END
|
||||
Creates a dataset that uses a custom thread pool to compute `input_dataset`.
|
||||
END
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
op {
|
||||
graph_op_name: "UnbatchDataset"
|
||||
visibility: HIDDEN
|
||||
summary: "A dataset that splits the elements of its input into multiple elements."
|
||||
}
|
@ -0,0 +1,8 @@
|
||||
op {
|
||||
graph_op_name: "UniqueDataset"
|
||||
summary: <<END
|
||||
Creates a dataset that contains the unique elements of `input_dataset`.
|
||||
END
|
||||
visibility: HIDDEN
|
||||
}
|
||||
|
@ -1,6 +0,0 @@
|
||||
op {
|
||||
graph_op_name: "ExperimentalIdentityIndexedDataset"
|
||||
endpoint {
|
||||
name: "data.ExperimentalIdentityIndexedDataset"
|
||||
}
|
||||
}
|
@ -1,6 +0,0 @@
|
||||
op {
|
||||
graph_op_name: "ExperimentalIndexedDatasetGet"
|
||||
endpoint {
|
||||
name: "data.ExperimentalIndexedDatasetGet"
|
||||
}
|
||||
}
|
@ -1,6 +0,0 @@
|
||||
op {
|
||||
graph_op_name: "ExperimentalIndexedDatasetMaterialize"
|
||||
endpoint {
|
||||
name: "data.ExperimentalIndexedDatasetMaterialize"
|
||||
}
|
||||
}
|
@ -1,6 +0,0 @@
|
||||
op {
|
||||
graph_op_name: "ExperimentalMaterializedIndexDatasetHandle"
|
||||
endpoint {
|
||||
name: "data.ExperimentalMaterializedIndexDatasetHandle"
|
||||
}
|
||||
}
|
@ -1,6 +0,0 @@
|
||||
op {
|
||||
graph_op_name: "ExperimentalNumaMapAndBatchDataset"
|
||||
endpoint {
|
||||
name: "data.ExperimentalNumaMapAndBatchDataset"
|
||||
}
|
||||
}
|
@ -56,6 +56,8 @@ class DatasetCardinalityOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU),
|
||||
DatasetToGraphOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("DatasetCardinality").Device(DEVICE_CPU),
|
||||
DatasetCardinalityOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalDatasetCardinality").Device(DEVICE_CPU),
|
||||
DatasetCardinalityOp);
|
||||
|
@ -155,6 +155,8 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
|
||||
AssertNextDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalAssertNextDataset").Device(DEVICE_CPU),
|
||||
AssertNextDatasetOp);
|
||||
|
@ -76,6 +76,8 @@ class AutoShardDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("AutoShardDataset").Device(DEVICE_CPU),
|
||||
AutoShardDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ExperimentalAutoShardDataset").Device(DEVICE_CPU),
|
||||
AutoShardDatasetOp);
|
||||
|
||||
|
@ -357,7 +357,8 @@ class ChooseFastestDatasetOp : public DatasetOpKernel {
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
}; // class ChooseFastestDatasetOp
|
||||
|
||||
// Register the kernel implementation for ChooseFastestDataset.
|
||||
REGISTER_KERNEL_BUILDER(Name("ChooseFastestDataset").Device(DEVICE_CPU),
|
||||
ChooseFastestDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalChooseFastestDataset").Device(DEVICE_CPU),
|
||||
ChooseFastestDatasetOp);
|
||||
|
@ -854,7 +854,7 @@ class CSVDatasetOp : public DatasetOpKernel {
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
}; // class CSVDatasetOp
|
||||
|
||||
// Register the kernel implementation for CSVDataset.
|
||||
REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ExperimentalCSVDataset").Device(DEVICE_CPU),
|
||||
CSVDatasetOp);
|
||||
|
||||
|
@ -312,6 +312,8 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("DenseToSparseBatchDataset").Device(DEVICE_CPU),
|
||||
DenseToSparseBatchDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalDenseToSparseBatchDataset").Device(DEVICE_CPU),
|
||||
DenseToSparseBatchDatasetOp);
|
||||
|
@ -277,6 +277,8 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
|
||||
DirectedInterleaveDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
|
||||
DirectedInterleaveDatasetOp);
|
||||
|
@ -412,9 +412,13 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("GroupByReducerDataset").Device(DEVICE_CPU),
|
||||
GroupByReducerDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalGroupByReducerDataset").Device(DEVICE_CPU),
|
||||
GroupByReducerDatasetOp);
|
||||
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("GroupByReducerDataset");
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalGroupByReducerDataset");
|
||||
|
||||
} // namespace
|
||||
|
@ -507,9 +507,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("GroupByWindowDataset").Device(DEVICE_CPU),
|
||||
GroupByWindowDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalGroupByWindowDataset").Device(DEVICE_CPU),
|
||||
GroupByWindowDatasetOp);
|
||||
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("GroupByWindowDataset");
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalGroupByWindowDataset");
|
||||
|
||||
} // namespace
|
||||
|
@ -140,6 +140,8 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("IgnoreErrorsDataset").Device(DEVICE_CPU),
|
||||
IgnoreErrorsDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalIgnoreErrorsDataset").Device(DEVICE_CPU),
|
||||
IgnoreErrorsDatasetOp);
|
||||
|
@ -216,6 +216,7 @@ class LMDBDatasetOp : public DatasetOpKernel {
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ExperimentalLMDBDataset").Device(DEVICE_CPU),
|
||||
LMDBDatasetOp);
|
||||
|
||||
|
@ -764,9 +764,13 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
bool preserve_cardinality_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("MapAndBatchDataset").Device(DEVICE_CPU),
|
||||
MapAndBatchDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalMapAndBatchDataset").Device(DEVICE_CPU),
|
||||
MapAndBatchDatasetOp);
|
||||
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("MapAndBatchDataset");
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalMapAndBatchDataset");
|
||||
|
||||
} // namespace
|
||||
|
@ -366,6 +366,8 @@ class MatchingFilesDatasetOp : public DatasetOpKernel {
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("MatchingFilesDataset").Device(DEVICE_CPU),
|
||||
MatchingFilesDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalMatchingFilesDataset").Device(DEVICE_CPU),
|
||||
MatchingFilesDatasetOp);
|
||||
|
@ -123,6 +123,8 @@ class NonSerializableDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("NonSerializableDataset").Device(DEVICE_CPU),
|
||||
NonSerializableDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalNonSerializableDataset").Device(DEVICE_CPU),
|
||||
NonSerializableDatasetOp);
|
||||
|
@ -1069,9 +1069,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
|
||||
ParallelInterleaveDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalParallelInterleaveDataset").Device(DEVICE_CPU),
|
||||
ParallelInterleaveDatasetOp);
|
||||
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDataset");
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalParallelInterleaveDataset");
|
||||
|
||||
} // namespace
|
||||
|
@ -397,6 +397,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::vector<std::size_t> elements_per_stride_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ParseExampleDataset").Device(DEVICE_CPU),
|
||||
ParseExampleDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalParseExampleDataset").Device(DEVICE_CPU),
|
||||
ParseExampleDatasetOp);
|
||||
|
@ -45,6 +45,8 @@ class IteratorGetDeviceOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU),
|
||||
IteratorGetDeviceOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalIteratorGetDevice").Device(DEVICE_CPU),
|
||||
IteratorGetDeviceOp);
|
||||
|
@ -154,6 +154,8 @@ class RandomDatasetOp : public DatasetOpKernel {
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RandomDataset").Device(DEVICE_CPU),
|
||||
RandomDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ExperimentalRandomDataset").Device(DEVICE_CPU),
|
||||
RandomDatasetOp);
|
||||
|
||||
|
@ -64,6 +64,8 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RebatchDataset").Device(DEVICE_CPU),
|
||||
RebatchDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ExperimentalRebatchDataset").Device(DEVICE_CPU),
|
||||
RebatchDatasetOp);
|
||||
|
||||
|
@ -290,9 +290,11 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
|
||||
bool preserve_cardinality_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ScanDataset").Device(DEVICE_CPU), ScanDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ExperimentalScanDataset").Device(DEVICE_CPU),
|
||||
ScanDatasetOp);
|
||||
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ScanDataset");
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalScanDataset");
|
||||
|
||||
} // namespace
|
||||
|
@ -211,9 +211,12 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("SetStatsAggregatorDataset").Device(DEVICE_CPU),
|
||||
SetStatsAggregatorDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalSetStatsAggregatorDataset").Device(DEVICE_CPU),
|
||||
SetStatsAggregatorDatasetOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
@ -133,8 +133,17 @@ class SleepDatasetOp : public UnaryDatasetOpKernel {
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("SleepDataset").Device(DEVICE_CPU),
|
||||
SleepDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ExperimentalSleepDataset").Device(DEVICE_CPU),
|
||||
SleepDatasetOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("SleepDataset")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("sleep_microseconds")
|
||||
.HostMemory("input_dataset")
|
||||
.HostMemory("handle"),
|
||||
SleepDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ExperimentalSleepDataset")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("sleep_microseconds")
|
||||
|
@ -302,6 +302,8 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("SlidingWindowDataset").Device(DEVICE_CPU),
|
||||
SlidingWindowDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalSlidingWindowDataset").Device(DEVICE_CPU),
|
||||
SlidingWindowDatasetOp);
|
||||
|
@ -214,6 +214,7 @@ class SqlDatasetOp : public DatasetOpKernel {
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("SqlDataset").Device(DEVICE_CPU), SqlDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ExperimentalSqlDataset").Device(DEVICE_CPU),
|
||||
SqlDatasetOp);
|
||||
|
||||
|
@ -296,14 +296,21 @@ class StatsAggregatorSetSummaryWriterOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("StatsAggregatorHandle").Device(DEVICE_CPU),
|
||||
StatsAggregatorHandleOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalStatsAggregatorHandle").Device(DEVICE_CPU),
|
||||
StatsAggregatorHandleOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("StatsAggregatorHandleV2").Device(DEVICE_CPU),
|
||||
StatsAggregatorHandleOpV2);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("StatsAggregatorSummary").Device(DEVICE_CPU),
|
||||
StatsAggregatorSummaryOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalStatsAggregatorSummary").Device(DEVICE_CPU),
|
||||
StatsAggregatorSummaryOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("StatsAggregatorSetSummaryWriter").Device(DEVICE_CPU),
|
||||
StatsAggregatorSetSummaryWriterOp);
|
||||
|
@ -258,13 +258,18 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalLatencyStatsDataset").Device(DEVICE_CPU),
|
||||
LatencyStatsDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("BytesProducedStatsDataset").Device(DEVICE_CPU),
|
||||
BytesProducedStatsDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalBytesProducedStatsDataset").Device(DEVICE_CPU),
|
||||
BytesProducedStatsDatasetOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("LatencyStatsDataset").Device(DEVICE_CPU),
|
||||
LatencyStatsDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalLatencyStatsDataset").Device(DEVICE_CPU),
|
||||
LatencyStatsDatasetOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
@ -198,8 +198,12 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::shared_ptr<FunctionMetadata> func_metadata_ = nullptr;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TakeWhileDataset").Device(DEVICE_CPU),
|
||||
TakeWhileDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ExperimentalTakeWhileDataset").Device(DEVICE_CPU),
|
||||
TakeWhileDatasetOp);
|
||||
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("TakeWhileDataset");
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalTakeWhileDataset");
|
||||
|
||||
} // namespace
|
||||
|
@ -431,14 +431,25 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("MaxIntraOpParallelismDataset").Device(DEVICE_CPU),
|
||||
MaxIntraOpParallelismDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalMaxIntraOpParallelismDataset").Device(DEVICE_CPU),
|
||||
MaxIntraOpParallelismDatasetOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("PrivateThreadPoolDataset").Device(DEVICE_CPU),
|
||||
PrivateThreadPoolDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalPrivateThreadPoolDataset").Device(DEVICE_CPU),
|
||||
PrivateThreadPoolDatasetOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ThreadPoolHandle").Device(DEVICE_CPU),
|
||||
ThreadPoolHandleOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU),
|
||||
ThreadPoolHandleOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU),
|
||||
ThreadPoolDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalThreadPoolDataset").Device(DEVICE_CPU),
|
||||
ThreadPoolDatasetOp);
|
||||
|
@ -100,6 +100,8 @@ class ToTFRecordOp : public AsyncOpKernel {
|
||||
BackgroundWorker background_worker_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU),
|
||||
ToTFRecordOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalDatasetToTFRecord").Device(DEVICE_CPU), ToTFRecordOp);
|
||||
|
||||
|
@ -221,6 +221,8 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("UnbatchDataset").Device(DEVICE_CPU),
|
||||
UnbatchDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ExperimentalUnbatchDataset").Device(DEVICE_CPU),
|
||||
UnbatchDatasetOp);
|
||||
|
||||
|
@ -221,6 +221,8 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU),
|
||||
UniqueDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ExperimentalUniqueDataset").Device(DEVICE_CPU),
|
||||
UniqueDatasetOp);
|
||||
|
||||
|
@ -25428,18 +25428,6 @@ op {
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalIdentityIndexedDataset"
|
||||
input_arg {
|
||||
name: "size"
|
||||
type: DT_UINT64
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalIgnoreErrorsDataset"
|
||||
input_arg {
|
||||
@ -25463,46 +25451,6 @@ op {
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalIndexedDatasetGet"
|
||||
input_arg {
|
||||
name: "materialized"
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
input_arg {
|
||||
name: "index"
|
||||
type: DT_UINT64
|
||||
}
|
||||
output_arg {
|
||||
name: "components"
|
||||
type_list_attr: "output_types"
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalIndexedDatasetMaterialize"
|
||||
input_arg {
|
||||
name: "dataset"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
input_arg {
|
||||
name: "materialized"
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalIteratorGetDevice"
|
||||
input_arg {
|
||||
@ -25774,34 +25722,6 @@ op {
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalMaterializedIndexDatasetHandle"
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
attr {
|
||||
name: "container"
|
||||
type: "string"
|
||||
}
|
||||
attr {
|
||||
name: "shared_name"
|
||||
type: "string"
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalMaxIntraOpParallelismDataset"
|
||||
input_arg {
|
||||
@ -25852,109 +25772,6 @@ op {
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalNumaMapAndBatchDataset"
|
||||
input_arg {
|
||||
name: "input_dataset"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
input_arg {
|
||||
name: "other_arguments"
|
||||
type_list_attr: "Targuments"
|
||||
}
|
||||
input_arg {
|
||||
name: "batch_size"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "num_parallel_calls"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "drop_remainder"
|
||||
type: DT_BOOL
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
attr {
|
||||
name: "f"
|
||||
type: "func"
|
||||
}
|
||||
attr {
|
||||
name: "Targuments"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalNumaMapAndBatchDataset"
|
||||
input_arg {
|
||||
name: "input_dataset"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
input_arg {
|
||||
name: "other_arguments"
|
||||
type_list_attr: "Targuments"
|
||||
}
|
||||
input_arg {
|
||||
name: "batch_size"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "num_parallel_calls"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "drop_remainder"
|
||||
type: DT_BOOL
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
attr {
|
||||
name: "f"
|
||||
type: "func"
|
||||
}
|
||||
attr {
|
||||
name: "Targuments"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "preserve_cardinality"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalParallelInterleaveDataset"
|
||||
input_arg {
|
||||
|
@ -23524,18 +23524,6 @@ op {
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalIdentityIndexedDataset"
|
||||
input_arg {
|
||||
name: "size"
|
||||
type: DT_UINT64
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalIgnoreErrorsDataset"
|
||||
input_arg {
|
||||
@ -23559,46 +23547,6 @@ op {
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalIndexedDatasetGet"
|
||||
input_arg {
|
||||
name: "materialized"
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
input_arg {
|
||||
name: "index"
|
||||
type: DT_UINT64
|
||||
}
|
||||
output_arg {
|
||||
name: "components"
|
||||
type_list_attr: "output_types"
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalIndexedDatasetMaterialize"
|
||||
input_arg {
|
||||
name: "dataset"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
input_arg {
|
||||
name: "materialized"
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalIteratorGetDevice"
|
||||
input_arg {
|
||||
@ -23870,34 +23818,6 @@ op {
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalMaterializedIndexDatasetHandle"
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
attr {
|
||||
name: "container"
|
||||
type: "string"
|
||||
}
|
||||
attr {
|
||||
name: "shared_name"
|
||||
type: "string"
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalMaxIntraOpParallelismDataset"
|
||||
input_arg {
|
||||
@ -23948,109 +23868,6 @@ op {
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalNumaMapAndBatchDataset"
|
||||
input_arg {
|
||||
name: "input_dataset"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
input_arg {
|
||||
name: "other_arguments"
|
||||
type_list_attr: "Targuments"
|
||||
}
|
||||
input_arg {
|
||||
name: "batch_size"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "num_parallel_calls"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "drop_remainder"
|
||||
type: DT_BOOL
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
attr {
|
||||
name: "f"
|
||||
type: "func"
|
||||
}
|
||||
attr {
|
||||
name: "Targuments"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalNumaMapAndBatchDataset"
|
||||
input_arg {
|
||||
name: "input_dataset"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
input_arg {
|
||||
name: "other_arguments"
|
||||
type_list_attr: "Targuments"
|
||||
}
|
||||
input_arg {
|
||||
name: "batch_size"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "num_parallel_calls"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "drop_remainder"
|
||||
type: DT_BOOL
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
attr {
|
||||
name: "f"
|
||||
type: "func"
|
||||
}
|
||||
attr {
|
||||
name: "Targuments"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "preserve_cardinality"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalParallelInterleaveDataset"
|
||||
input_arg {
|
||||
|
@ -17,10 +17,40 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("StatsAggregatorSetSummaryWriter")
|
||||
.Input("stats_aggregator: resource")
|
||||
.Input("summary: resource")
|
||||
.SetShapeFn(shape_inference::NoOutputs);
|
||||
REGISTER_OP("AssertNextDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("transformations: string")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// transformations should be a vector.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalAssertNextDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("transformations: string")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// transformations should be a vector.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("AutoShardDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("num_workers: int64")
|
||||
.Input("index: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalAutoShardDataset")
|
||||
.Input("input_dataset: variant")
|
||||
@ -31,6 +61,18 @@ REGISTER_OP("ExperimentalAutoShardDataset")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("BytesProducedStatsDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("tag: string")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle tag_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalBytesProducedStatsDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("tag: string")
|
||||
@ -57,6 +99,66 @@ REGISTER_OP("ChooseFastestBranchDataset")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ChooseFastestDataset")
|
||||
.Input("input_datasets: N * variant")
|
||||
.Output("handle: variant")
|
||||
.Attr("N: int >= 2")
|
||||
.Attr("num_experiments: int")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalChooseFastestDataset")
|
||||
.Input("input_datasets: N * variant")
|
||||
.Output("handle: variant")
|
||||
.Attr("N: int >= 2")
|
||||
.Attr("num_experiments: int")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("CSVDataset")
|
||||
.Input("filenames: string")
|
||||
.Input("compression_type: string")
|
||||
.Input("buffer_size: int64")
|
||||
.Input("header: bool")
|
||||
.Input("field_delim: string")
|
||||
.Input("use_quote_delim: bool")
|
||||
.Input("na_value: string")
|
||||
.Input("select_cols: int64")
|
||||
.Input("record_defaults: output_types")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list({float,double,int32,int64,string}) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// `filenames` must be a scalar or a vector.
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
|
||||
// `compression_type`, `buffer_size`, `header`, `field_delim`,
|
||||
// `use_quote_delim`, `na_value` must be scalars
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
|
||||
// `select_cols` must be a vector
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
|
||||
// `record_defaults` must be lists of scalars
|
||||
for (size_t i = 8; i < c->num_inputs(); ++i) {
|
||||
shape_inference::ShapeHandle v;
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
|
||||
if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
|
||||
return errors::InvalidArgument(
|
||||
"Shape of a default must be a length-0 or length-1 vector, or a "
|
||||
"scalar.");
|
||||
}
|
||||
}
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalCSVDataset")
|
||||
.Input("filenames: string")
|
||||
.Input("compression_type: string")
|
||||
@ -99,6 +201,11 @@ REGISTER_OP("ExperimentalCSVDataset")
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("DatasetCardinality")
|
||||
.Input("input_dataset: variant")
|
||||
.Output("cardinality: int64")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalDatasetCardinality")
|
||||
.Input("input_dataset: variant")
|
||||
.Output("cardinality: int64")
|
||||
@ -108,6 +215,13 @@ REGISTER_OP("ExperimentalDatasetCardinality")
|
||||
// implement a mechanism to determine whether `dataset` has a side-effect
|
||||
// and use it to decide whether to use a stateless or stateful version of this
|
||||
// op.
|
||||
REGISTER_OP("DatasetToTFRecord")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("filename: string")
|
||||
.Input("compression_type: string")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::NoOutputs);
|
||||
|
||||
REGISTER_OP("ExperimentalDatasetToTFRecord")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("filename: string")
|
||||
@ -115,6 +229,22 @@ REGISTER_OP("ExperimentalDatasetToTFRecord")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::NoOutputs);
|
||||
|
||||
REGISTER_OP("DenseToSparseBatchDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("batch_size: int64")
|
||||
.Input("row_shape: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// batch_size should be a scalar.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
// row_shape should be a 1-D vector.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalDenseToSparseBatchDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("batch_size: int64")
|
||||
@ -131,6 +261,15 @@ REGISTER_OP("ExperimentalDenseToSparseBatchDataset")
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("DirectedInterleaveDataset")
|
||||
.Input("selector_input_dataset: variant")
|
||||
.Input("data_input_datasets: N * variant")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.Attr("N: int >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalDirectedInterleaveDataset")
|
||||
.Input("selector_input_dataset: variant")
|
||||
.Input("data_input_datasets: N * variant")
|
||||
@ -140,6 +279,26 @@ REGISTER_OP("ExperimentalDirectedInterleaveDataset")
|
||||
.Attr("N: int >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("GroupByReducerDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("key_func_other_arguments: Tkey_func_other_arguments")
|
||||
.Input("init_func_other_arguments: Tinit_func_other_arguments")
|
||||
.Input("reduce_func_other_arguments: Treduce_func_other_arguments")
|
||||
.Input("finalize_func_other_arguments: Tfinalize_func_other_arguments")
|
||||
.Output("handle: variant")
|
||||
.Attr("key_func: func")
|
||||
.Attr("init_func: func")
|
||||
.Attr("reduce_func: func")
|
||||
.Attr("finalize_func: func")
|
||||
.Attr("Tkey_func_other_arguments: list(type) >= 0")
|
||||
.Attr("Tinit_func_other_arguments: list(type) >= 0")
|
||||
.Attr("Treduce_func_other_arguments: list(type) >= 0")
|
||||
.Attr("Tfinalize_func_other_arguments: list(type) >= 0")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalGroupByReducerDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("key_func_other_arguments: Tkey_func_other_arguments")
|
||||
@ -160,6 +319,23 @@ REGISTER_OP("ExperimentalGroupByReducerDataset")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("GroupByWindowDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("key_func_other_arguments: Tkey_func_other_arguments")
|
||||
.Input("reduce_func_other_arguments: Treduce_func_other_arguments")
|
||||
.Input(
|
||||
"window_size_func_other_arguments: Twindow_size_func_other_arguments")
|
||||
.Output("handle: variant")
|
||||
.Attr("key_func: func")
|
||||
.Attr("reduce_func: func")
|
||||
.Attr("window_size_func: func")
|
||||
.Attr("Tkey_func_other_arguments: list(type) >= 0")
|
||||
.Attr("Treduce_func_other_arguments: list(type) >= 0")
|
||||
.Attr("Twindow_size_func_other_arguments: list(type) >= 0")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalGroupByWindowDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("key_func_other_arguments: Tkey_func_other_arguments")
|
||||
@ -177,6 +353,13 @@ REGISTER_OP("ExperimentalGroupByWindowDataset")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("IgnoreErrorsDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalIgnoreErrorsDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Output("handle: variant")
|
||||
@ -184,6 +367,28 @@ REGISTER_OP("ExperimentalIgnoreErrorsDataset")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("IteratorGetDevice")
|
||||
.Input("resource: resource")
|
||||
.Output("device: string")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalIteratorGetDevice")
|
||||
.Input("resource: resource")
|
||||
.Output("device: string")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("LatencyStatsDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("tag: string")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle tag_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalLatencyStatsDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("tag: string")
|
||||
@ -196,6 +401,51 @@ REGISTER_OP("ExperimentalLatencyStatsDataset")
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("LMDBDataset")
|
||||
.Input("filenames: string")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalLMDBDataset")
|
||||
.Input("filenames: string")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("MapAndBatchDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("other_arguments: Targuments")
|
||||
.Input("batch_size: int64")
|
||||
.Input("num_parallel_calls: int64")
|
||||
.Input("drop_remainder: bool")
|
||||
.Output("handle: variant")
|
||||
.Attr("f: func")
|
||||
.Attr("Targuments: list(type) >= 0")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.Attr("preserve_cardinality: bool = false")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
// Use index from the end to retrieve the Input shapes,
|
||||
// so that to avoid guessing the length of "other_arguments".
|
||||
// batch_size, num_parallel_calls, and drop_remainder are 0-D scalars.
|
||||
shape_inference::ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(c->num_inputs() - 3), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(c->num_inputs() - 2), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
|
||||
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalMapAndBatchDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("other_arguments: Targuments")
|
||||
@ -223,14 +473,6 @@ REGISTER_OP("ExperimentalMapAndBatchDataset")
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalRebatchDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("num_workers: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalMapDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("other_arguments: Targuments")
|
||||
@ -243,6 +485,18 @@ REGISTER_OP("ExperimentalMapDataset")
|
||||
.Attr("preserve_cardinality: bool = false")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("MatchingFilesDataset")
|
||||
.Input("patterns: string")
|
||||
.Output("handle: variant")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// `patterns` must be a scalar or a vector.
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalMatchingFilesDataset")
|
||||
.Input("patterns: string")
|
||||
.Output("handle: variant")
|
||||
@ -255,6 +509,29 @@ REGISTER_OP("ExperimentalMatchingFilesDataset")
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("MaxIntraOpParallelismDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("max_intra_op_parallelism: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalMaxIntraOpParallelismDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("max_intra_op_parallelism: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("NonSerializableDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalNonSerializableDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Output("handle: variant")
|
||||
@ -262,6 +539,21 @@ REGISTER_OP("ExperimentalNonSerializableDataset")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ParallelInterleaveDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("other_arguments: Targuments")
|
||||
.Input("cycle_length: int64")
|
||||
.Input("block_length: int64")
|
||||
.Input("sloppy: bool")
|
||||
.Input("buffer_output_elements: int64")
|
||||
.Input("prefetch_input_elements: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("f: func")
|
||||
.Attr("Targuments: list(type) >= 0")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalParallelInterleaveDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("other_arguments: Targuments")
|
||||
@ -277,6 +569,23 @@ REGISTER_OP("ExperimentalParallelInterleaveDataset")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ParseExampleDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("num_parallel_calls: int64")
|
||||
.Input("dense_defaults: Tdense")
|
||||
.Output("handle: variant")
|
||||
.Attr("sparse_keys: list(string) >= 0")
|
||||
.Attr("dense_keys: list(string) >= 0")
|
||||
.Attr("sparse_types: list({float,int64,string}) >= 0")
|
||||
.Attr("Tdense: list({float,int64,string}) >= 0")
|
||||
.Attr("dense_shapes: list(shape) >= 0")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1") // Output components will be
|
||||
// sorted by key (dense_keys and
|
||||
// sparse_keys combined) here.
|
||||
.Attr("sloppy: bool = false")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalParseExampleDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("num_parallel_calls: int64")
|
||||
@ -294,6 +603,22 @@ REGISTER_OP("ExperimentalParseExampleDataset")
|
||||
.Attr("sloppy: bool = false")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("PrivateThreadPoolDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("num_threads: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalPrivateThreadPoolDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("num_threads: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalRandomDataset")
|
||||
.Input("seed: int64")
|
||||
.Input("seed2: int64")
|
||||
@ -310,6 +635,68 @@ REGISTER_OP("ExperimentalRandomDataset")
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("RandomDataset")
|
||||
.Input("seed: int64")
|
||||
.Input("seed2: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// buffer_size, seed, and seed2 should be scalars.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalRebatchDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("num_workers: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("RebatchDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("num_workers: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("SamplingDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("rate: float32")
|
||||
.Input("seed: int64")
|
||||
.Input("seed2: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// rate, seed, and seed2 should be scalars.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ScanDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("initial_state: Tstate")
|
||||
.Input("other_arguments: Targuments")
|
||||
.Output("handle: variant")
|
||||
.Attr("f: func")
|
||||
.Attr("Tstate: list(type) >= 1")
|
||||
.Attr("Targuments: list(type) >= 0")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.Attr("preserve_cardinality: bool = false")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalScanDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("initial_state: Tstate")
|
||||
@ -323,6 +710,16 @@ REGISTER_OP("ExperimentalScanDataset")
|
||||
.Attr("preserve_cardinality: bool = false")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("SetStatsAggregatorDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("stats_aggregator: resource")
|
||||
.Input("tag: string")
|
||||
.Input("counter_prefix: string")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalSetStatsAggregatorDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("stats_aggregator: resource")
|
||||
@ -333,6 +730,20 @@ REGISTER_OP("ExperimentalSetStatsAggregatorDataset")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("SleepDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("sleep_microseconds: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// Both inputs are scalar.
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 0, &unused));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalSleepDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("sleep_microseconds: int64")
|
||||
@ -347,6 +758,23 @@ REGISTER_OP("ExperimentalSleepDataset")
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("SlidingWindowDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("window_size: int64")
|
||||
.Input("window_shift: int64")
|
||||
.Input("window_stride: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// window_size, window_shift, and window_stride should be scalars.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalSlidingWindowDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("window_size: int64")
|
||||
@ -382,6 +810,24 @@ REGISTER_OP("SnapshotDataset")
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("SqlDataset")
|
||||
.Input("driver_name: string")
|
||||
.Input("data_source_name: string")
|
||||
.Input("query: string")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// driver_name, data_source_name, and query should be scalars.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalSqlDataset")
|
||||
.Input("driver_name: string")
|
||||
.Input("data_source_name: string")
|
||||
@ -400,6 +846,12 @@ REGISTER_OP("ExperimentalSqlDataset")
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("StatsAggregatorHandle")
|
||||
.Output("handle: resource")
|
||||
.SetShapeFn(shape_inference::ScalarShape)
|
||||
.Attr("container: string = ''")
|
||||
.Attr("shared_name: string = ''");
|
||||
|
||||
REGISTER_OP("ExperimentalStatsAggregatorHandle")
|
||||
.Output("handle: resource")
|
||||
.SetShapeFn(shape_inference::ScalarShape)
|
||||
@ -412,11 +864,31 @@ REGISTER_OP("StatsAggregatorHandleV2")
|
||||
.Attr("container: string = ''")
|
||||
.Attr("shared_name: string = ''");
|
||||
|
||||
REGISTER_OP("StatsAggregatorSetSummaryWriter")
|
||||
.Input("stats_aggregator: resource")
|
||||
.Input("summary: resource")
|
||||
.SetShapeFn(shape_inference::NoOutputs);
|
||||
|
||||
REGISTER_OP("StatsAggregatorSummary")
|
||||
.Input("iterator: resource")
|
||||
.Output("summary: string")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalStatsAggregatorSummary")
|
||||
.Input("iterator: resource")
|
||||
.Output("summary: string")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("TakeWhileDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("other_arguments: Targuments")
|
||||
.Output("handle: variant")
|
||||
.Attr("predicate: func")
|
||||
.Attr("Targuments: list(type) >= 0")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalTakeWhileDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("other_arguments: Targuments")
|
||||
@ -427,36 +899,9 @@ REGISTER_OP("ExperimentalTakeWhileDataset")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalUnbatchDataset")
|
||||
REGISTER_OP("ThreadPoolDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalUniqueDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalIteratorGetDevice")
|
||||
.Input("resource: resource")
|
||||
.Output("device: string")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalMaxIntraOpParallelismDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("max_intra_op_parallelism: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalPrivateThreadPoolDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("num_threads: int64")
|
||||
.Input("thread_pool: resource")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
@ -470,6 +915,15 @@ REGISTER_OP("ExperimentalThreadPoolDataset")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ThreadPoolHandle")
|
||||
.Output("handle: resource")
|
||||
.SetShapeFn(shape_inference::ScalarShape)
|
||||
.Attr("num_threads: int")
|
||||
.Attr("max_intra_op_parallelism: int = 1")
|
||||
.Attr("display_name: string")
|
||||
.Attr("container: string = ''")
|
||||
.Attr("shared_name: string = ''");
|
||||
|
||||
REGISTER_OP("ExperimentalThreadPoolHandle")
|
||||
.Output("handle: resource")
|
||||
.SetShapeFn(shape_inference::ScalarShape)
|
||||
@ -479,136 +933,32 @@ REGISTER_OP("ExperimentalThreadPoolHandle")
|
||||
.Attr("container: string = ''")
|
||||
.Attr("shared_name: string = ''");
|
||||
|
||||
REGISTER_OP("ExperimentalAssertNextDataset")
|
||||
REGISTER_OP("UnbatchDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("transformations: string")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// transformations should be a vector.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalNumaMapAndBatchDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("other_arguments: Targuments")
|
||||
.Input("batch_size: int64")
|
||||
.Input("num_parallel_calls: int64")
|
||||
.Input("drop_remainder: bool")
|
||||
.Output("handle: variant")
|
||||
.Attr("f: func")
|
||||
.Attr("Targuments: list(type) >= 0")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.Attr("preserve_cardinality: bool = false")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
// Use index from the end to retrieve the Input shapes,
|
||||
// so that to avoid guessing the length of "other_arguments".
|
||||
// batch_size, num_parallel_batches, and drop_remainder are 0-D scalars.
|
||||
shape_inference::ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(c->num_inputs() - 3), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(c->num_inputs() - 2), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
|
||||
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalLMDBDataset")
|
||||
.Input("filenames: string")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalChooseFastestDataset")
|
||||
.Input("input_datasets: N * variant")
|
||||
.Output("handle: variant")
|
||||
.Attr("N: int >= 2")
|
||||
.Attr("num_experiments: int")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalIdentityIndexedDataset")
|
||||
.Input("size: uint64")
|
||||
.Output("handle: variant")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(
|
||||
shape_inference::ScalarShape); // TODO(saeta): check input shapes.
|
||||
|
||||
REGISTER_OP("SamplingDataset")
|
||||
REGISTER_OP("ExperimentalUnbatchDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("rate: float32")
|
||||
.Input("seed: int64")
|
||||
.Input("seed2: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// rate, seed, and seed2 should be scalars.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// IndexedDataset Internals
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Creates the handle.
|
||||
REGISTER_OP("ExperimentalMaterializedIndexDatasetHandle")
|
||||
.Output("handle: resource")
|
||||
.Attr("container: string")
|
||||
.Attr("shared_name: string")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
// Actually materialize the materialize handle.
|
||||
REGISTER_OP("ExperimentalIndexedDatasetMaterialize")
|
||||
.Input("dataset: variant")
|
||||
.Input("materialized: resource")
|
||||
.SetShapeFn(shape_inference::NoOutputs);
|
||||
|
||||
namespace {
|
||||
|
||||
Status GetShapeFn(shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
|
||||
std::vector<PartialTensorShape> output_shapes;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
|
||||
if (output_shapes.size() != c->num_outputs()) {
|
||||
return errors::InvalidArgument(
|
||||
"`output_shapes` must be the same length as `output_types` (",
|
||||
output_shapes.size(), " vs. ", c->num_outputs());
|
||||
}
|
||||
for (size_t i = 0; i < output_shapes.size(); ++i) {
|
||||
shape_inference::ShapeHandle output_shape_handle;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
|
||||
output_shapes[i], &output_shape_handle));
|
||||
c->set_output(static_cast<int>(i), output_shape_handle);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
REGISTER_OP("ExperimentalIndexedDatasetGet")
|
||||
.Input("materialized: resource")
|
||||
.Input("index: uint64")
|
||||
.Output("components: output_types")
|
||||
REGISTER_OP("UniqueDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(GetShapeFn);
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalUniqueDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import convert
|
||||
from tensorflow.python.data.util import nest
|
||||
@ -246,11 +247,18 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
|
||||
dataset_ops.get_legacy_output_types(input_dataset),
|
||||
tensor_shape.vector(None).concatenate(self._row_shape))
|
||||
|
||||
variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._batch_size,
|
||||
row_shape=convert.partial_shape_to_tensor(self._row_shape),
|
||||
**self._flat_structure)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._batch_size,
|
||||
row_shape=convert.partial_shape_to_tensor(self._row_shape),
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._batch_size,
|
||||
row_shape=convert.partial_shape_to_tensor(self._row_shape),
|
||||
**self._flat_structure)
|
||||
super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
|
||||
variant_tensor)
|
||||
|
||||
@ -294,15 +302,26 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
|
||||
lambda component_spec: component_spec._batch(None),
|
||||
self._map_func.output_structure)
|
||||
# pylint: enable=protected-access
|
||||
variant_tensor = ged_ops.experimental_map_and_batch_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
f=self._map_func.function,
|
||||
batch_size=self._batch_size_t,
|
||||
num_parallel_calls=self._num_parallel_calls_t,
|
||||
drop_remainder=self._drop_remainder_t,
|
||||
preserve_cardinality=True,
|
||||
**self._flat_structure)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = ged_ops.map_and_batch_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
f=self._map_func.function,
|
||||
batch_size=self._batch_size_t,
|
||||
num_parallel_calls=self._num_parallel_calls_t,
|
||||
drop_remainder=self._drop_remainder_t,
|
||||
preserve_cardinality=True,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = ged_ops.experimental_map_and_batch_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
f=self._map_func.function,
|
||||
batch_size=self._batch_size_t,
|
||||
num_parallel_calls=self._num_parallel_calls_t,
|
||||
drop_remainder=self._drop_remainder_t,
|
||||
preserve_cardinality=True,
|
||||
**self._flat_structure)
|
||||
super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -47,4 +48,8 @@ def cardinality(dataset):
|
||||
the cardinality is infinite or unknown, the operation returns the named
|
||||
constant `INFINITE_CARDINALITY` and `UNKNOWN_CARDINALITY` respectively.
|
||||
"""
|
||||
return ged_ops.experimental_dataset_cardinality(dataset._variant_tensor) # pylint: disable=protected-access
|
||||
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
return ged_ops.dataset_cardinality(dataset._variant_tensor) # pylint: disable=protected-access
|
||||
else:
|
||||
return ged_ops.experimental_dataset_cardinality(dataset._variant_tensor) # pylint: disable=protected-access
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.data.util import structure
|
||||
@ -47,11 +48,18 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
|
||||
self._input_dataset = input_dataset
|
||||
|
||||
self._element_spec = input_dataset.element_spec
|
||||
variant_tensor = ged_ops.experimental_auto_shard_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
num_workers=num_workers,
|
||||
index=index,
|
||||
**self._flat_structure)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = ged_ops.auto_shard_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
num_workers=num_workers,
|
||||
index=index,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = ged_ops.experimental_auto_shard_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
num_workers=num_workers,
|
||||
index=index,
|
||||
**self._flat_structure)
|
||||
super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
@ -87,10 +95,16 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
|
||||
|
||||
self._element_spec = structure.convert_legacy_structure(
|
||||
input_types, output_shapes, input_classes)
|
||||
variant_tensor = ged_ops.experimental_rebatch_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
num_workers=num_workers,
|
||||
**self._flat_structure)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = ged_ops.rebatch_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
num_workers=num_workers,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = ged_ops.experimental_rebatch_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
num_workers=num_workers,
|
||||
**self._flat_structure)
|
||||
super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.ops import gen_experimental_dataset_ops
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
@ -59,8 +60,14 @@ class _IgnoreErrorsDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
def __init__(self, input_dataset):
|
||||
"""See `Dataset.ignore_errors()` for details."""
|
||||
self._input_dataset = input_dataset
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
**self._flat_structure))
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.ignore_errors_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
**self._flat_structure))
|
||||
else:
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
**self._flat_structure))
|
||||
super(_IgnoreErrorsDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.data.util import structure
|
||||
@ -253,17 +254,30 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
|
||||
self._make_init_func(reducer.init_func)
|
||||
self._make_reduce_func(reducer.reduce_func, input_dataset)
|
||||
self._make_finalize_func(reducer.finalize_func)
|
||||
variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._key_func.function.captured_inputs,
|
||||
self._init_func.function.captured_inputs,
|
||||
self._reduce_func.function.captured_inputs,
|
||||
self._finalize_func.function.captured_inputs,
|
||||
key_func=self._key_func.function,
|
||||
init_func=self._init_func.function,
|
||||
reduce_func=self._reduce_func.function,
|
||||
finalize_func=self._finalize_func.function,
|
||||
**self._flat_structure)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._key_func.function.captured_inputs,
|
||||
self._init_func.function.captured_inputs,
|
||||
self._reduce_func.function.captured_inputs,
|
||||
self._finalize_func.function.captured_inputs,
|
||||
key_func=self._key_func.function,
|
||||
init_func=self._init_func.function,
|
||||
reduce_func=self._reduce_func.function,
|
||||
finalize_func=self._finalize_func.function,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = ged_ops.group_by_reducer_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._key_func.function.captured_inputs,
|
||||
self._init_func.function.captured_inputs,
|
||||
self._reduce_func.function.captured_inputs,
|
||||
self._finalize_func.function.captured_inputs,
|
||||
key_func=self._key_func.function,
|
||||
init_func=self._init_func.function,
|
||||
reduce_func=self._reduce_func.function,
|
||||
finalize_func=self._finalize_func.function,
|
||||
**self._flat_structure)
|
||||
super(_GroupByReducerDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _make_key_func(self, key_func, input_dataset):
|
||||
@ -375,15 +389,26 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
|
||||
self._make_key_func(key_func, input_dataset)
|
||||
self._make_reduce_func(reduce_func, input_dataset)
|
||||
self._make_window_size_func(window_size_func)
|
||||
variant_tensor = ged_ops.experimental_group_by_window_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._key_func.function.captured_inputs,
|
||||
self._reduce_func.function.captured_inputs,
|
||||
self._window_size_func.function.captured_inputs,
|
||||
key_func=self._key_func.function,
|
||||
reduce_func=self._reduce_func.function,
|
||||
window_size_func=self._window_size_func.function,
|
||||
**self._flat_structure)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = ged_ops.group_by_window_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._key_func.function.captured_inputs,
|
||||
self._reduce_func.function.captured_inputs,
|
||||
self._window_size_func.function.captured_inputs,
|
||||
key_func=self._key_func.function,
|
||||
reduce_func=self._reduce_func.function,
|
||||
window_size_func=self._window_size_func.function,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = ged_ops.experimental_group_by_window_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._key_func.function.captured_inputs,
|
||||
self._reduce_func.function.captured_inputs,
|
||||
self._window_size_func.function.captured_inputs,
|
||||
key_func=self._key_func.function,
|
||||
reduce_func=self._reduce_func.function,
|
||||
window_size_func=self._window_size_func.function,
|
||||
**self._flat_structure)
|
||||
super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _make_window_size_func(self, window_size_func):
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops import random_ops
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import readers
|
||||
@ -125,11 +126,18 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# pylint: disable=protected-access
|
||||
return (
|
||||
gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
|
||||
self._selector_input._variant_tensor,
|
||||
[data_input._variant_tensor for data_input in self._data_inputs],
|
||||
**self._flat_structure))
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
return (
|
||||
gen_experimental_dataset_ops.directed_interleave_dataset(
|
||||
self._selector_input._variant_tensor,
|
||||
[data_input._variant_tensor for data_input in self._data_inputs],
|
||||
**self._flat_structure))
|
||||
else:
|
||||
return (
|
||||
gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
|
||||
self._selector_input._variant_tensor,
|
||||
[data_input._variant_tensor for data_input in self._data_inputs],
|
||||
**self._flat_structure))
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def _inputs(self):
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import structure
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -31,7 +32,11 @@ class MatchingFilesDataset(dataset_ops.DatasetSource):
|
||||
def __init__(self, patterns):
|
||||
self._patterns = ops.convert_to_tensor(
|
||||
patterns, dtype=dtypes.string, name="patterns")
|
||||
variant_tensor = ged_ops.experimental_matching_files_dataset(self._patterns)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = ged_ops.matching_files_dataset(self._patterns)
|
||||
else:
|
||||
variant_tensor = ged_ops.experimental_matching_files_dataset(
|
||||
self._patterns)
|
||||
super(MatchingFilesDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -104,11 +105,18 @@ class _AssertNextDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
raise ValueError("At least one transformation should be specified")
|
||||
self._transformations = ops.convert_to_tensor(
|
||||
transformations, dtype=dtypes.string, name="transformations")
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.experimental_assert_next_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._transformations,
|
||||
**self._flat_structure))
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.assert_next_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._transformations,
|
||||
**self._flat_structure))
|
||||
else:
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.experimental_assert_next_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._transformations,
|
||||
**self._flat_structure))
|
||||
super(_AssertNextDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
@ -118,10 +126,16 @@ class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
def __init__(self, input_dataset):
|
||||
"""See `non_serializable()` for details."""
|
||||
self._input_dataset = input_dataset
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.experimental_non_serializable_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
**self._flat_structure))
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.non_serializable_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
**self._flat_structure))
|
||||
else:
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.experimental_non_serializable_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
**self._flat_structure))
|
||||
super(_NonSerializableDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
@ -157,11 +171,18 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2):
|
||||
"""
|
||||
self._datasets = list(datasets)
|
||||
self._element_spec = self._datasets[0].element_spec
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.experimental_choose_fastest_dataset(
|
||||
[dataset._variant_tensor for dataset in self._datasets], # pylint: disable=protected-access
|
||||
num_experiments=num_experiments,
|
||||
**self._flat_structure))
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.choose_fastest_dataset(
|
||||
[dataset._variant_tensor for dataset in self._datasets], # pylint: disable=protected-access
|
||||
num_experiments=num_experiments,
|
||||
**self._flat_structure))
|
||||
else:
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.experimental_choose_fastest_dataset(
|
||||
[dataset._variant_tensor for dataset in self._datasets], # pylint: disable=protected-access
|
||||
num_experiments=num_experiments,
|
||||
**self._flat_structure))
|
||||
super(_ChooseFastestDataset, self).__init__(variant_tensor)
|
||||
|
||||
def _inputs(self):
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import structure
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -79,16 +80,28 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
|
||||
self._element_spec = structure.convert_legacy_structure(
|
||||
output_types, output_shapes, output_classes)
|
||||
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.experimental_parse_example_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._num_parallel_calls,
|
||||
self._dense_defaults,
|
||||
self._sparse_keys,
|
||||
self._dense_keys,
|
||||
self._sparse_types,
|
||||
self._dense_shapes,
|
||||
**self._flat_structure))
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.parse_example_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._num_parallel_calls,
|
||||
self._dense_defaults,
|
||||
self._sparse_keys,
|
||||
self._dense_keys,
|
||||
self._sparse_types,
|
||||
self._dense_shapes,
|
||||
**self._flat_structure))
|
||||
else:
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.experimental_parse_example_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._num_parallel_calls,
|
||||
self._dense_defaults,
|
||||
self._sparse_keys,
|
||||
self._dense_keys,
|
||||
self._sparse_types,
|
||||
self._dense_shapes,
|
||||
**self._flat_structure))
|
||||
super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import random_seed
|
||||
from tensorflow.python.data.util import structure
|
||||
@ -34,8 +35,12 @@ class RandomDatasetV2(dataset_ops.DatasetSource):
|
||||
def __init__(self, seed=None):
|
||||
"""A `Dataset` of pseudorandom values."""
|
||||
self._seed, self._seed2 = random_seed.get_seed(seed)
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_random_dataset(
|
||||
seed=self._seed, seed2=self._seed2, **self._flat_structure)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = gen_experimental_dataset_ops.random_dataset(
|
||||
seed=self._seed, seed2=self._seed2, **self._flat_structure)
|
||||
else:
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_random_dataset(
|
||||
seed=self._seed, seed2=self._seed2, **self._flat_structure)
|
||||
super(RandomDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
|
@ -23,6 +23,7 @@ import functools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.experimental.ops import error_ops
|
||||
from tensorflow.python.data.experimental.ops import interleave_ops
|
||||
@ -664,17 +665,30 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
|
||||
)
|
||||
self._element_spec = tuple(
|
||||
structure.TensorStructure(d.dtype, []) for d in self._record_defaults)
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
|
||||
filenames=self._filenames,
|
||||
record_defaults=self._record_defaults,
|
||||
buffer_size=self._buffer_size,
|
||||
header=self._header,
|
||||
output_shapes=self._flat_shapes,
|
||||
field_delim=self._field_delim,
|
||||
use_quote_delim=self._use_quote_delim,
|
||||
na_value=self._na_value,
|
||||
select_cols=self._select_cols,
|
||||
compression_type=self._compression_type)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = gen_experimental_dataset_ops.csv_dataset(
|
||||
filenames=self._filenames,
|
||||
record_defaults=self._record_defaults,
|
||||
buffer_size=self._buffer_size,
|
||||
header=self._header,
|
||||
output_shapes=self._flat_shapes,
|
||||
field_delim=self._field_delim,
|
||||
use_quote_delim=self._use_quote_delim,
|
||||
na_value=self._na_value,
|
||||
select_cols=self._select_cols,
|
||||
compression_type=self._compression_type)
|
||||
else:
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
|
||||
filenames=self._filenames,
|
||||
record_defaults=self._record_defaults,
|
||||
buffer_size=self._buffer_size,
|
||||
header=self._header,
|
||||
output_shapes=self._flat_shapes,
|
||||
field_delim=self._field_delim,
|
||||
use_quote_delim=self._use_quote_delim,
|
||||
na_value=self._na_value,
|
||||
select_cols=self._select_cols,
|
||||
compression_type=self._compression_type)
|
||||
super(CsvDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
@ -957,9 +971,14 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
|
||||
query, dtype=dtypes.string, name="query")
|
||||
self._element_spec = nest.map_structure(
|
||||
lambda dtype: structure.TensorStructure(dtype, []), output_types)
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
|
||||
self._driver_name, self._data_source_name, self._query,
|
||||
**self._flat_structure)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = gen_experimental_dataset_ops.sql_dataset(
|
||||
self._driver_name, self._data_source_name, self._query,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
|
||||
self._driver_name, self._data_source_name, self._query,
|
||||
**self._flat_structure)
|
||||
super(SqlDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.data.util import structure
|
||||
@ -122,13 +123,22 @@ class _ScanDataset(dataset_ops.UnaryDataset):
|
||||
self._scan_func = wrapped_func
|
||||
self._scan_func.function.add_to_graph(ops.get_default_graph())
|
||||
# pylint: disable=protected-access
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset(
|
||||
self._input_dataset._variant_tensor,
|
||||
structure.to_tensor_list(self._state_structure, self._initial_state),
|
||||
self._scan_func.function.captured_inputs,
|
||||
f=self._scan_func.function,
|
||||
preserve_cardinality=True,
|
||||
**self._flat_structure)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = gen_experimental_dataset_ops.scan_dataset(
|
||||
self._input_dataset._variant_tensor,
|
||||
structure.to_tensor_list(self._state_structure, self._initial_state),
|
||||
self._scan_func.function.captured_inputs,
|
||||
f=self._scan_func.function,
|
||||
preserve_cardinality=True,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset(
|
||||
self._input_dataset._variant_tensor,
|
||||
structure.to_tensor_list(self._state_structure, self._initial_state),
|
||||
self._scan_func.function.captured_inputs,
|
||||
f=self._scan_func.function,
|
||||
preserve_cardinality=True,
|
||||
**self._flat_structure)
|
||||
super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.ops import gen_experimental_dataset_ops
|
||||
|
||||
@ -27,10 +28,16 @@ class _SleepDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
def __init__(self, input_dataset, sleep_microseconds):
|
||||
self._input_dataset = input_dataset
|
||||
self._sleep_microseconds = sleep_microseconds
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_sleep_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._sleep_microseconds,
|
||||
**self._flat_structure)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = gen_experimental_dataset_ops.sleep_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._sleep_microseconds,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_sleep_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._sleep_microseconds,
|
||||
**self._flat_structure)
|
||||
super(_SleepDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import tempfile
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
|
||||
from tensorflow.python.ops import summary_ops_v2
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
@ -125,7 +126,10 @@ class StatsAggregatorV1(object):
|
||||
|
||||
def __init__(self):
|
||||
"""Creates a `StatsAggregator`."""
|
||||
self._resource = ged_ops.experimental_stats_aggregator_handle()
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
self._resource = ged_ops.stats_aggregator_handle()
|
||||
else:
|
||||
self._resource = ged_ops.experimental_stats_aggregator_handle()
|
||||
|
||||
def get_summary(self):
|
||||
"""Returns a string `tf.Tensor` that summarizes the aggregated statistics.
|
||||
@ -137,7 +141,10 @@ class StatsAggregatorV1(object):
|
||||
Returns:
|
||||
A scalar string `tf.Tensor` that summarizes the aggregated statistics.
|
||||
"""
|
||||
return ged_ops.experimental_stats_aggregator_summary(self._resource)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
return ged_ops.stats_aggregator_summary(self._resource)
|
||||
else:
|
||||
return ged_ops.experimental_stats_aggregator_summary(self._resource)
|
||||
|
||||
|
||||
# TODO(b/116314787): Change this to StatsAggregatorV2 when we have stable
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -65,10 +66,14 @@ def bytes_produced_stats(tag):
|
||||
"""
|
||||
|
||||
def _apply_fn(dataset):
|
||||
return _StatsDataset(
|
||||
dataset,
|
||||
gen_experimental_dataset_ops.experimental_bytes_produced_stats_dataset,
|
||||
tag)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
return _StatsDataset(
|
||||
dataset, gen_experimental_dataset_ops.bytes_produced_stats_dataset,
|
||||
tag)
|
||||
else:
|
||||
return _StatsDataset(
|
||||
dataset, gen_experimental_dataset_ops
|
||||
.experimental_bytes_produced_stats_dataset, tag)
|
||||
|
||||
return _apply_fn
|
||||
|
||||
@ -90,9 +95,14 @@ def latency_stats(tag):
|
||||
"""
|
||||
|
||||
def _apply_fn(dataset):
|
||||
return _StatsDataset(
|
||||
dataset,
|
||||
gen_experimental_dataset_ops.experimental_latency_stats_dataset, tag)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
return _StatsDataset(
|
||||
dataset,
|
||||
gen_experimental_dataset_ops.latency_stats_dataset, tag)
|
||||
else:
|
||||
return _StatsDataset(
|
||||
dataset,
|
||||
gen_experimental_dataset_ops.experimental_latency_stats_dataset, tag)
|
||||
|
||||
return _apply_fn
|
||||
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import structure as structure_lib
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -41,11 +42,18 @@ class _TakeWhileDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
raise ValueError("`predicate` must return a scalar boolean tensor.")
|
||||
|
||||
self._predicate = wrapped_func
|
||||
var_tensor = gen_experimental_dataset_ops.experimental_take_while_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
other_arguments=self._predicate.function.captured_inputs,
|
||||
predicate=self._predicate.function,
|
||||
**self._flat_structure)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
var_tensor = gen_experimental_dataset_ops.take_while_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
other_arguments=self._predicate.function.captured_inputs,
|
||||
predicate=self._predicate.function,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
var_tensor = gen_experimental_dataset_ops.experimental_take_while_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
other_arguments=self._predicate.function.captured_inputs,
|
||||
predicate=self._predicate.function,
|
||||
**self._flat_structure)
|
||||
super(_TakeWhileDataset, self).__init__(input_dataset, var_tensor)
|
||||
|
||||
def _functions(self):
|
||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
|
||||
@ -46,18 +47,31 @@ class PrivateThreadPool(object):
|
||||
"""Creates a `PrivateThreadPool` with the given number of threads."""
|
||||
if context.executing_eagerly():
|
||||
shared_name = _generate_shared_name("privatethreadpool")
|
||||
self._resource = ged_ops.experimental_thread_pool_handle(
|
||||
num_threads=num_threads,
|
||||
max_intra_op_parallelism=max_intra_op_parallelism,
|
||||
display_name=display_name,
|
||||
shared_name=shared_name)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
self._resource = ged_ops.thread_pool_handle(
|
||||
num_threads=num_threads,
|
||||
max_intra_op_parallelism=max_intra_op_parallelism,
|
||||
display_name=display_name,
|
||||
shared_name=shared_name)
|
||||
else:
|
||||
self._resource = ged_ops.experimental_thread_pool_handle(
|
||||
num_threads=num_threads,
|
||||
max_intra_op_parallelism=max_intra_op_parallelism,
|
||||
display_name=display_name,
|
||||
shared_name=shared_name)
|
||||
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
|
||||
handle=self._resource, handle_device=context.context().device_name)
|
||||
else:
|
||||
self._resource = ged_ops.experimental_thread_pool_handle(
|
||||
num_threads=num_threads,
|
||||
max_intra_op_parallelism=max_intra_op_parallelism,
|
||||
display_name=display_name)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
self._resource = ged_ops.thread_pool_handle(
|
||||
num_threads=num_threads,
|
||||
max_intra_op_parallelism=max_intra_op_parallelism,
|
||||
display_name=display_name)
|
||||
else:
|
||||
self._resource = ged_ops.experimental_thread_pool_handle(
|
||||
num_threads=num_threads,
|
||||
max_intra_op_parallelism=max_intra_op_parallelism,
|
||||
display_name=display_name)
|
||||
|
||||
|
||||
class _ThreadPoolDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
@ -66,10 +80,16 @@ class _ThreadPoolDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
def __init__(self, input_dataset, thread_pool):
|
||||
self._input_dataset = input_dataset
|
||||
self._thread_pool = thread_pool
|
||||
variant_tensor = ged_ops.experimental_thread_pool_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._thread_pool._resource, # pylint: disable=protected-access
|
||||
**self._flat_structure)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = ged_ops.thread_pool_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._thread_pool._resource, # pylint: disable=protected-access
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = ged_ops.experimental_thread_pool_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._thread_pool._resource, # pylint: disable=protected-access
|
||||
**self._flat_structure)
|
||||
super(_ThreadPoolDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import gen_experimental_dataset_ops
|
||||
@ -59,7 +60,12 @@ class _UniqueDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
raise TypeError(
|
||||
"`tf.data.experimental.unique()` only supports inputs with a single "
|
||||
"`tf.int32`, `tf.int64`, or `tf.string` component.")
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_unique_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
**self._flat_structure)
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
variant_tensor = gen_experimental_dataset_ops.unique_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_unique_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
**self._flat_structure)
|
||||
super(_UniqueDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import convert
|
||||
from tensorflow.python.data.util import structure
|
||||
@ -83,5 +84,9 @@ class TFRecordWriter(object):
|
||||
"produces shape {0} and types {1}".format(
|
||||
dataset_ops.get_legacy_output_shapes(dataset),
|
||||
dataset_ops.get_legacy_output_types(dataset)))
|
||||
return gen_experimental_dataset_ops.experimental_dataset_to_tf_record(
|
||||
dataset._variant_tensor, self._filename, self._compression_type) # pylint: disable=protected-access
|
||||
if compat.forward_compatible(2019, 8, 3):
|
||||
return gen_experimental_dataset_ops.dataset_to_tf_record(
|
||||
dataset._variant_tensor, self._filename, self._compression_type) # pylint: disable=protected-access
|
||||
else:
|
||||
return gen_experimental_dataset_ops.experimental_dataset_to_tf_record(
|
||||
dataset._variant_tensor, self._filename, self._compression_type) # pylint: disable=protected-access
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user