[tf.data] Cleanup of tf.data op definitions.

cl/223710135 renamed some dataset ops (adding an "Experimental" prefix), which introduced a backward-compatibility issue: a graph saved with a version of TensorFlow from before cl/223710135 would fail to restore in a version built after it if the graph contained any of the renamed ops. To address this issue, this CL reintroduces the op definitions removed by cl/223710135 and uses `compat.forward_compatible` to transition from the ops with the (misleading) "Experimental" prefix to the corresponding ops without the prefix.
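
To illustrate, the guarded selection pattern used throughout the Python changes below looks roughly like this (a minimal sketch; `_make_lmdb_variant` is a hypothetical helper, not code from this CL):

from tensorflow.python.compat import compat
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops

def _make_lmdb_variant(filenames, flat_structure):
  # Once the forward-compatibility window has passed, emit the new op name;
  # until then, keep emitting the "Experimental"-prefixed op so that saved
  # graphs remain loadable by binaries that predate the rename.
  if compat.forward_compatible(2019, 8, 3):
    return ged_ops.lmdb_dataset(filenames, **flat_structure)
  return ged_ops.experimental_lmdb_dataset(filenames, **flat_structure)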

In addition, this CL removes tf.data op definitions for ops that have no kernels.

PiperOrigin-RevId: 256436076
Author: Jiri Simsa, 2019-07-03 14:12:29 -07:00 (committed by TensorFlower Gardener)
Commit: de9c460b06 (parent: 4cca406165)
104 changed files with 1907 additions and 792 deletions

@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
@@ -390,8 +391,12 @@ class LMDBDataset(dataset_ops.DatasetSource):
"""
self._filenames = ops.convert_to_tensor(
filenames, dtype=dtypes.string, name="filenames")
variant_tensor = gen_experimental_dataset_ops.experimental_lmdb_dataset(
self._filenames, **self._flat_structure)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = gen_experimental_dataset_ops.lmdb_dataset(
self._filenames, **self._flat_structure)
else:
variant_tensor = gen_experimental_dataset_ops.experimental_lmdb_dataset(
self._filenames, **self._flat_structure)
super(LMDBDataset, self).__init__(variant_tensor)
@property

@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -41,12 +42,20 @@ class _SlideDataset(dataset_ops.UnaryDataset):
input_structure = dataset_ops.get_structure(input_dataset)
self._element_spec = nest.map_structure(
lambda component_spec: component_spec._batch(None), input_structure) # pylint: disable=protected-access
variant_tensor = ged_ops.experimental_sliding_window_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
window_size=self._window_size,
window_shift=self._window_shift,
window_stride=self._window_stride,
**self._flat_structure)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = ged_ops.sliding_window_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
window_size=self._window_size,
window_shift=self._window_shift,
window_stride=self._window_stride,
**self._flat_structure)
else:
variant_tensor = ged_ops.experimental_sliding_window_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
window_size=self._window_size,
window_shift=self._window_shift,
window_stride=self._window_stride,
**self._flat_structure)
super(_SlideDataset, self).__init__(input_dataset, variant_tensor)
@property

@@ -0,0 +1,4 @@
op {
graph_op_name: "AssertNextDataset"
visibility: HIDDEN
}

@@ -0,0 +1,32 @@
op {
graph_op_name: "AutoShardDataset"
visibility: HIDDEN
in_arg {
name: "input_dataset"
description: <<END
A variant tensor representing the input dataset.
END
}
in_arg {
name: "num_workers"
description: <<END
A scalar representing the number of workers to distribute this dataset across.
END
}
in_arg {
name: "index"
description: <<END
A scalar representing the index of the current worker out of num_workers.
END
}
summary: "Creates a dataset that shards the input dataset."
description: <<END
Creates a dataset that shards the input dataset by num_workers, returning a
sharded dataset for the index-th worker. This attempts to automatically shard
a dataset by examining the Dataset graph and inserting a shard op before the
inputs to a reader Dataset (e.g. CSVDataset, TFRecordDataset).
This dataset will throw a NotFound error if we cannot shard the dataset
automatically.
END
}
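
For intuition, the rewrite this op performs is roughly equivalent to sharding the file dataset by hand before the reader; a minimal sketch (file pattern and worker configuration are hypothetical):

import tensorflow as tf

num_workers, index = 4, 1  # hypothetical worker configuration
files = tf.data.Dataset.list_files("/data/train-*.tfrecord", shuffle=False)
# AutoShardDataset inserts the equivalent of this shard() before the reader:
files = files.shard(num_workers, index)
dataset = tf.data.TFRecordDataset(files)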

@@ -0,0 +1,5 @@
op {
graph_op_name: "BytesProducedStatsDataset"
visibility: HIDDEN
summary: "Records the bytes size of each element of `input_dataset` in a StatsAggregator."
}

@@ -0,0 +1,4 @@
op {
graph_op_name: "CSVDataset"
visibility: HIDDEN
}

@@ -0,0 +1,4 @@
op {
graph_op_name: "ChooseFastestDataset"
visibility: HIDDEN
}

@@ -0,0 +1,21 @@
op {
graph_op_name: "DatasetCardinality"
visibility: HIDDEN
in_arg {
name: "input_dataset"
description: <<END
A variant tensor representing the dataset to return cardinality for.
END
}
out_arg {
name: "cardinality"
description: <<END
The cardinality of `input_dataset`. Named constants are used to represent
infinite and unknown cardinality.
END
}
summary: "Returns the cardinality of `input_dataset`."
description: <<END
Returns the cardinality of `input_dataset`.
END
}
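
The public wrapper for this op in this era is `tf.data.experimental.cardinality`; a minimal sketch:

import tensorflow as tf

dataset = tf.data.Dataset.range(10).batch(3)
print(tf.data.experimental.cardinality(dataset))  # 4
# The named constants cover the non-finite cases:
print(tf.data.experimental.cardinality(dataset.repeat()))
# == tf.data.experimental.INFINITE_CARDINALITY
print(tf.data.experimental.cardinality(
    dataset.filter(lambda x: tf.reduce_all(x < 100))))
# == tf.data.experimental.UNKNOWN_CARDINALITY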

@@ -0,0 +1,24 @@
op {
graph_op_name: "DatasetToTFRecord"
visibility: HIDDEN
in_arg {
name: "input_dataset"
description: <<END
A variant tensor representing the dataset to write.
END
}
in_arg {
name: "filename"
description: <<END
A scalar string tensor representing the filename to use.
END
}
in_arg {
name: "compression_type"
description: <<END
A scalar string tensor containing either (i) the empty string (no
compression), (ii) "ZLIB", or (iii) "GZIP".
END
}
summary: "Writes the given dataset to the given file using the TFRecord format."
}
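
This op backs `tf.data.experimental.TFRecordWriter`; a minimal sketch (output path hypothetical):

import tensorflow as tf

# The dataset must yield scalar DT_STRING elements (serialized records).
dataset = tf.data.Dataset.from_tensor_slices([b"rec1", b"rec2", b"rec3"])
writer = tf.data.experimental.TFRecordWriter(
    "/tmp/out.tfrecord", compression_type="GZIP")
writer.write(dataset)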

@@ -0,0 +1,26 @@
op {
graph_op_name: "DenseToSparseBatchDataset"
visibility: HIDDEN
in_arg {
name: "input_dataset"
description: <<END
A handle to an input dataset. Must have a single component.
END
}
in_arg {
name: "batch_size"
description: <<END
A scalar representing the number of elements to accumulate in a
batch.
END
}
in_arg {
name: "row_shape"
description: <<END
A vector representing the dense shape of each row in the produced
SparseTensor. The shape may be partially specified, using `-1` to indicate
that a particular dimension should use the maximum size of all batch elements.
END
}
summary: "Creates a dataset that batches input elements into a SparseTensor."
}
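
Exposed as `tf.data.experimental.dense_to_sparse_batch`; a minimal sketch:

import tensorflow as tf

# Rows of varying length: [], [1], [2 2], [3 3 3], [4 4 4 4].
dataset = tf.data.Dataset.range(5).map(lambda x: tf.fill([x], x))
dataset = dataset.apply(
    tf.data.experimental.dense_to_sparse_batch(batch_size=2, row_shape=[4]))
# Each output element is a SparseTensor with dense shape [batch, 4].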

@@ -0,0 +1,21 @@
op {
graph_op_name: "DirectedInterleaveDataset"
in_arg {
name: "selector_input_dataset"
description: <<END
A dataset of scalar `DT_INT64` elements that determines which of the
`N` data inputs should produce the next output element.
END
}
in_arg {
name: "data_input_datasets"
description: <<END
`N` datasets with the same type that will be interleaved according to
the values of `selector_input_dataset`.
END
}
summary: <<END
A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
END
visibility: HIDDEN
}

@@ -1,4 +0,0 @@
op {
graph_op_name: "ExperimentalIdentityIndexedDataset"
visibility: HIDDEN
}

@@ -1,4 +0,0 @@
op {
graph_op_name: "ExperimentalIndexedDatasetGet"
visibility: HIDDEN
}

@@ -1,4 +0,0 @@
op {
graph_op_name: "ExperimentalIndexedDatasetMaterialize"
visibility: HIDDEN
}

@@ -1,4 +0,0 @@
op {
graph_op_name: "ExperimentalMaterializedIndexDatasetHandle"
visibility: HIDDEN
}

@@ -0,0 +1,69 @@
op {
graph_op_name: "GroupByReducerDataset"
visibility: HIDDEN
in_arg {
name: "input_dataset"
description: <<END
A variant tensor representing the input dataset.
END
}
in_arg {
name: "key_func_other_arguments"
description: <<END
A list of tensors, typically values that were captured when
building a closure for `key_func`.
END
}
attr {
name: "key_func"
description: <<END
A function mapping an element of `input_dataset`, concatenated
with `key_func_other_arguments` to a scalar value of type DT_INT64.
END
}
in_arg {
name: "init_func_other_arguments"
description: <<END
A list of tensors, typically values that were captured when
building a closure for `init_func`.
END
}
attr {
name: "init_func"
description: <<END
A function mapping a key of type DT_INT64, concatenated with
`init_func_other_arguments` to the initial reducer state.
END
}
in_arg {
name: "reduce_func_other_arguments"
description: <<END
A list of tensors, typically values that were captured when
building a closure for `reduce_func`.
END
}
attr {
name: "reduce_func"
description: <<END
A function mapping the current reducer state and an element of `input_dataset`,
concatenated with `reduce_func_other_arguments` to a new reducer state.
END
}
in_arg {
name: "finalize_func_other_arguments"
description: <<END
A list of tensors, typically values that were captured when
building a closure for `finalize_func`.
END
}
attr {
name: "finalize_func"
description: <<END
A function mapping the final reducer state to an output element.
END
}
summary: "Creates a dataset that computes a group-by on `input_dataset`."
description: <<END
Creates a dataset that computes a group-by on `input_dataset`.
END
}
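
Exposed as `tf.data.experimental.group_by_reducer`; a minimal sketch that counts elements per parity class:

import tensorflow as tf

reducer = tf.data.experimental.Reducer(
    init_func=lambda key: tf.constant(0, dtype=tf.int64),
    reduce_func=lambda state, element: state + 1,
    finalize_func=lambda state: state)
dataset = tf.data.Dataset.range(10).apply(
    tf.data.experimental.group_by_reducer(
        key_func=lambda x: x % 2, reducer=reducer))
# Yields one element per key: the counts 5 (even) and 5 (odd).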

@@ -0,0 +1,15 @@
op {
graph_op_name: "GroupByWindowDataset"
visibility: HIDDEN
attr {
name: "key_func"
description: <<END
A function mapping an element of `input_dataset`, concatenated
with `key_func_other_arguments` to a scalar value of type DT_INT64.
END
}
summary: "Creates a dataset that computes a windowed group-by on `input_dataset`."
description: <<END
// TODO(mrry): Support non-int64 keys.
END
}
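
Exposed as `tf.data.experimental.group_by_window`; a minimal sketch that batches elements by parity:

import tensorflow as tf

dataset = tf.data.Dataset.range(10).apply(
    tf.data.experimental.group_by_window(
        key_func=lambda x: x % 2,  # keys must currently be DT_INT64
        reduce_func=lambda key, window: window.batch(5),
        window_size=5))
# Yields [0 2 4 6 8] and [1 3 5 7 9].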

@@ -0,0 +1,8 @@
op {
graph_op_name: "IgnoreErrorsDataset"
summary: <<END
Creates a dataset that contains the elements of `input_dataset` ignoring errors.
END
visibility: HIDDEN
}
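
Exposed as `tf.data.experimental.ignore_errors`; a minimal sketch:

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
dataset = dataset.map(lambda x: tf.debugging.check_numerics(1. / x, "error"))
dataset = dataset.apply(tf.data.experimental.ignore_errors())
# Yields 1.0, 0.5, 0.25; the element whose map raised (1/0 -> Inf) is dropped.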

@@ -0,0 +1,8 @@
op {
graph_op_name: "IteratorGetDevice"
summary: <<END
Returns the name of the device on which `resource` has been placed.
END
visibility: HIDDEN
}

@@ -0,0 +1,4 @@
op {
graph_op_name: "LMDBDataset"
visibility: HIDDEN
}

@@ -0,0 +1,5 @@
op {
graph_op_name: "LatencyStatsDataset"
visibility: HIDDEN
summary: "Records the latency of producing `input_dataset` elements in a StatsAggregator."
}

@@ -1,5 +1,5 @@
op {
graph_op_name: "ExperimentalNumaMapAndBatchDataset"
graph_op_name: "MapAndBatchDataset"
visibility: HIDDEN
in_arg {
name: "input_dataset"
@@ -50,9 +50,5 @@ batches `batch_size` of them.
Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
to `batch_size * num_parallel_batches` copies of `f` in parallel.
Unlike "MapAndBatchDatasetV2", this dataset uses a NUMA-aware thread scheduling
policy. Because it uses the single-threaded executor, it only supports the
function-based control flow ops.
END
}

@@ -0,0 +1,4 @@
op {
graph_op_name: "MatchingFilesDataset"
visibility: HIDDEN
}

@@ -0,0 +1,13 @@
op {
graph_op_name: "MaxIntraOpParallelismDataset"
in_arg {
name: "max_intra_op_parallelism"
description: <<END
Identifies the maximum intra-op parallelism to use.
END
}
summary: <<END
Creates a dataset that overrides the maximum intra-op parallelism.
END
visibility: HIDDEN
}

@@ -0,0 +1,4 @@
op {
graph_op_name: "NonSerializableDataset"
visibility: HIDDEN
}

@@ -0,0 +1,22 @@
op {
graph_op_name: "ParallelInterleaveDataset"
visibility: HIDDEN
attr {
name: "f"
description: <<END
A function mapping elements of `input_dataset`, concatenated with
`other_arguments`, to a Dataset variant that contains elements matching
`output_types` and `output_shapes`.
END
}
summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
description: <<END
The resulting dataset is similar to the `InterleaveDataset`, with the exception
that if retrieving the next value from a dataset would cause the requester to
block, it will skip that input dataset. This dataset is especially useful
when loading data from variable-latency datastores (e.g. HDFS, GCS), as it
allows the training step to proceed so long as some data is available.
!! WARNING !! This dataset is not deterministic!
END
}
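
Exposed as `tf.data.experimental.parallel_interleave`; a minimal sketch (file pattern hypothetical):

import tensorflow as tf

files = tf.data.Dataset.list_files("/data/shard-*.tfrecord")
dataset = files.apply(
    tf.data.experimental.parallel_interleave(
        lambda filename: tf.data.TFRecordDataset(filename),
        cycle_length=4,
        sloppy=True))  # sloppy=True opts into the nondeterminism noted above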

@@ -0,0 +1,70 @@
op {
graph_op_name: "ParseExampleDataset"
visibility: HIDDEN
in_arg {
name: "dense_defaults"
description: <<END
A dict mapping string keys to `Tensor`s.
The keys of the dict must match the dense_keys of the feature.
END
}
attr {
name: "sparse_keys"
description: <<END
A list of string keys in the examples features.
The results for these keys will be returned as `SparseTensor` objects.
END
}
attr {
name: "dense_keys"
description: <<END
A list of Ndense string Tensors (scalars).
The keys expected in the Examples features associated with dense values.
END
}
attr {
name: "sparse_types"
description: <<END
A list of `DTypes` of the same length as `sparse_keys`.
Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
and `tf.string` (`BytesList`) are supported.
END
}
attr {
name: "Tdense"
description: <<END
A list of DTypes of the same length as `dense_keys`.
Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
and `tf.string` (`BytesList`) are supported.
END
}
attr {
name: "dense_shapes"
description: <<END
List of tuples with the same length as `dense_keys`.
The shape of the data for each dense feature referenced by `dense_keys`.
Required for any input tensors identified by `dense_keys`. Must be
either fully defined, or may contain an unknown first dimension.
An unknown first dimension means the feature is treated as having
a variable number of blocks, and the output shape along this dimension
is considered unknown at graph build time. Padding is applied for
minibatch elements smaller than the maximum number of blocks for the
given feature along this dimension.
END
}
attr {
name: "output_types"
description: <<END
The type list for the return values.
END
}
attr {
name: "output_shapes"
description: <<END
The list of shapes being produced.
END
}
summary: "Transforms `input_dataset` containing `Example` protos as vectors of DT_STRING into a dataset of `Tensor` or `SparseTensor` objects representing the parsed features."
}

@@ -0,0 +1,13 @@
op {
graph_op_name: "PrivateThreadPoolDataset"
in_arg {
name: "num_threads"
description: <<END
Identifies the number of threads to use for the private threadpool.
END
}
summary: <<END
Creates a dataset that uses a custom thread pool to compute `input_dataset`.
END
visibility: HIDDEN
}
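
Both this op and MaxIntraOpParallelismDataset above are normally reached through `tf.data.Options` rather than constructed directly; a minimal sketch, assuming the threading options of this era:

import tensorflow as tf

dataset = tf.data.Dataset.range(1000).map(lambda x: x * x)
options = tf.data.Options()
options.experimental_threading.max_intra_op_parallelism = 1
options.experimental_threading.private_threadpool_size = 4
dataset = dataset.with_options(options)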

@@ -0,0 +1,19 @@
op {
graph_op_name: "RandomDataset"
visibility: HIDDEN
in_arg {
name: "seed"
description: <<END
A scalar seed for the random number generator. If either seed or
seed2 is set to be non-zero, the random number generator is seeded
by the given seed. Otherwise, a random seed is used.
END
}
in_arg {
name: "seed2"
description: <<END
A second scalar seed to avoid seed collision.
END
}
summary: "Creates a Dataset that returns pseudorandom numbers."
}
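
Surfaced as `tf.data.experimental.RandomDataset`; a minimal sketch:

import tensorflow as tf

# An endless stream of scalar int64 pseudorandom values; fixing the seed
# makes the stream reproducible across runs.
dataset = tf.data.experimental.RandomDataset(seed=42).take(3)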

@@ -0,0 +1,23 @@
op {
graph_op_name: "RebatchDataset"
visibility: HIDDEN
in_arg {
name: "input_dataset"
description: <<END
A variant tensor representing the input dataset.
END
}
in_arg {
name: "num_workers"
description: <<END
A scalar representing the number of workers to distribute this batch across. As
a result of this transformation, the current batch size is divided by this
parameter.
END
}
summary: "Creates a dataset that changes the batch size."
description: <<END
Creates a dataset that changes the batch size of the dataset to current batch
size // num_workers.
END
}
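
A quick check of the arithmetic:

global_batch_size, num_workers = 64, 4
per_worker_batch_size = global_batch_size // num_workers  # == 16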

@@ -0,0 +1,5 @@
op {
graph_op_name: "ScanDataset"
visibility: HIDDEN
summary: "Creates a dataset successively reduces `f` over the elements of `input_dataset`."
}

@@ -0,0 +1,4 @@
op {
graph_op_name: "SetStatsAggregatorDataset"
visibility: HIDDEN
}

@@ -0,0 +1,4 @@
op {
graph_op_name: "SleepDataset"
visibility: HIDDEN
}

@@ -0,0 +1,26 @@
op {
graph_op_name: "SlidingWindowDataset"
visibility: HIDDEN
in_arg {
name: "window_size"
description: <<END
A scalar representing the number of elements in the
sliding window.
END
}
in_arg {
name: "window_shift"
description: <<END
A scalar representing the steps moving the sliding window
forward in one iteration. It must be positive.
END
}
in_arg {
name: "window_stride"
description: <<END
A scalar representing the stride of the input elements of the sliding window.
It must be positive.
END
}
summary: "Creates a dataset that passes a sliding window over `input_dataset`."
}
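
The same semantics can be expressed with the public `window()` API; a minimal sketch with `window_size=3`, `window_shift=1`, `window_stride=1`:

import tensorflow as tf

dataset = tf.data.Dataset.range(6).window(
    size=3, shift=1, stride=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(3))
# Yields [0 1 2], [1 2 3], [2 3 4], [3 4 5].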

@@ -0,0 +1,23 @@
op {
graph_op_name: "SqlDataset"
visibility: HIDDEN
in_arg {
name: "driver_name"
description: <<END
The database type. Currently, the only supported type is 'sqlite'.
END
}
in_arg {
name: "data_source_name"
description: <<END
A connection string to connect to the database.
END
}
in_arg {
name: "query"
description: <<END
A SQL query to execute.
END
}
summary: "Creates a dataset that executes a SQL query and emits rows of the result set."
}
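
Exposed as `tf.data.experimental.SqlDataset` (database path, table, and query below are hypothetical):

import tensorflow as tf

dataset = tf.data.experimental.SqlDataset(
    driver_name="sqlite",
    data_source_name="/tmp/example.db",
    query="SELECT name, age FROM people",
    output_types=(tf.string, tf.int32))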

@@ -0,0 +1,5 @@
op {
graph_op_name: "StatsAggregatorHandle"
visibility: HIDDEN
summary: "Creates a statistics manager resource."
}

@@ -0,0 +1,5 @@
op {
graph_op_name: "StatsAggregatorSummary"
visibility: HIDDEN
summary: "Produces a summary of any statistics recorded by the given statistics manager."
}

@@ -0,0 +1,25 @@
op {
graph_op_name: "TakeWhileDataset"
visibility: HIDDEN
in_arg {
name: "other_arguments"
description: <<END
A list of tensors, typically values that were captured when
building a closure for `predicate`.
END
}
attr {
name: "predicate"
description: <<END
A function returning a scalar boolean.
END
}
summary: "Creates a dataset that stops iteration when predicate` is false."
description: <<END
The `predicate` function must return a scalar boolean and accept the
following arguments:
* One tensor for each component of an element of `input_dataset`.
* One tensor for each value in `other_arguments`.
END
}
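
Exposed as `tf.data.experimental.take_while`; a minimal sketch:

import tensorflow as tf

dataset = tf.data.Dataset.range(10).apply(
    tf.data.experimental.take_while(lambda x: x < 5))
# Yields 0, 1, 2, 3, 4; iteration stops at the first False.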

@@ -0,0 +1,13 @@
op {
graph_op_name: "ThreadPoolDataset"
in_arg {
name: "thread_pool"
description: <<END
A resource produced by the ThreadPoolHandle op.
END
}
summary: <<END
Creates a dataset that uses a custom thread pool to compute `input_dataset`.
END
visibility: HIDDEN
}

@@ -0,0 +1,35 @@
op {
graph_op_name: "ThreadPoolHandle"
out_arg {
name: "handle"
description: <<END
A resource that can be consumed by one or more ExperimentalThreadPoolDataset
ops.
END
}
attr {
name: "num_threads"
description: <<END
The number of threads in the thread pool.
END
}
attr {
name: "max_intra_op_parallelism"
description: <<END
The maximum degree of parallelism to use within operations that execute on this
threadpool.
END
}
attr {
name: "display_name"
description: <<END
A human-readable name for the threads that may be visible in some
visualizations.
END
}
summary: <<END
Creates a dataset that uses a custom thread pool to compute `input_dataset`.
END
visibility: HIDDEN
}

@@ -0,0 +1,5 @@
op {
graph_op_name: "UnbatchDataset"
visibility: HIDDEN
summary: "A dataset that splits the elements of its input into multiple elements."
}

@@ -0,0 +1,8 @@
op {
graph_op_name: "UniqueDataset"
summary: <<END
Creates a dataset that contains the unique elements of `input_dataset`.
END
visibility: HIDDEN
}
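
Exposed as `tf.data.experimental.unique`; a minimal sketch:

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 1, 3, 3, 2]).apply(
    tf.data.experimental.unique())
# Yields 1, 2, 3, in order of first occurrence.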

@@ -1,6 +0,0 @@
op {
graph_op_name: "ExperimentalIdentityIndexedDataset"
endpoint {
name: "data.ExperimentalIdentityIndexedDataset"
}
}

@@ -1,6 +0,0 @@
op {
graph_op_name: "ExperimentalIndexedDatasetGet"
endpoint {
name: "data.ExperimentalIndexedDatasetGet"
}
}

@@ -1,6 +0,0 @@
op {
graph_op_name: "ExperimentalIndexedDatasetMaterialize"
endpoint {
name: "data.ExperimentalIndexedDatasetMaterialize"
}
}

@@ -1,6 +0,0 @@
op {
graph_op_name: "ExperimentalMaterializedIndexDatasetHandle"
endpoint {
name: "data.ExperimentalMaterializedIndexDatasetHandle"
}
}

@@ -1,6 +0,0 @@
op {
graph_op_name: "ExperimentalNumaMapAndBatchDataset"
endpoint {
name: "data.ExperimentalNumaMapAndBatchDataset"
}
}

@@ -56,6 +56,8 @@ class DatasetCardinalityOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU),
DatasetToGraphOp);
REGISTER_KERNEL_BUILDER(Name("DatasetCardinality").Device(DEVICE_CPU),
DatasetCardinalityOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalDatasetCardinality").Device(DEVICE_CPU),
DatasetCardinalityOp);

@@ -155,6 +155,8 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
AssertNextDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalAssertNextDataset").Device(DEVICE_CPU),
AssertNextDatasetOp);

@@ -76,6 +76,8 @@ class AutoShardDatasetOp : public UnaryDatasetOpKernel {
}
};
REGISTER_KERNEL_BUILDER(Name("AutoShardDataset").Device(DEVICE_CPU),
AutoShardDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalAutoShardDataset").Device(DEVICE_CPU),
AutoShardDatasetOp);

@@ -357,7 +357,8 @@ class ChooseFastestDatasetOp : public DatasetOpKernel {
std::vector<PartialTensorShape> output_shapes_;
}; // class ChooseFastestDatasetOp
// Register the kernel implementation for ChooseFastestDataset.
REGISTER_KERNEL_BUILDER(Name("ChooseFastestDataset").Device(DEVICE_CPU),
ChooseFastestDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalChooseFastestDataset").Device(DEVICE_CPU),
ChooseFastestDatasetOp);

@@ -854,7 +854,7 @@ class CSVDatasetOp : public DatasetOpKernel {
std::vector<PartialTensorShape> output_shapes_;
}; // class CSVDatasetOp
// Register the kernel implementation for CSVDataset.
REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalCSVDataset").Device(DEVICE_CPU),
CSVDatasetOp);

@@ -312,6 +312,8 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
};
};
REGISTER_KERNEL_BUILDER(Name("DenseToSparseBatchDataset").Device(DEVICE_CPU),
DenseToSparseBatchDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalDenseToSparseBatchDataset").Device(DEVICE_CPU),
DenseToSparseBatchDatasetOp);

@@ -277,6 +277,8 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
};
};
REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
DirectedInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
DirectedInterleaveDatasetOp);

@@ -412,9 +412,13 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
REGISTER_KERNEL_BUILDER(Name("GroupByReducerDataset").Device(DEVICE_CPU),
GroupByReducerDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalGroupByReducerDataset").Device(DEVICE_CPU),
GroupByReducerDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("GroupByReducerDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalGroupByReducerDataset");
} // namespace

@@ -507,9 +507,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
REGISTER_KERNEL_BUILDER(Name("GroupByWindowDataset").Device(DEVICE_CPU),
GroupByWindowDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalGroupByWindowDataset").Device(DEVICE_CPU),
GroupByWindowDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("GroupByWindowDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalGroupByWindowDataset");
} // namespace

@@ -140,6 +140,8 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
};
};
REGISTER_KERNEL_BUILDER(Name("IgnoreErrorsDataset").Device(DEVICE_CPU),
IgnoreErrorsDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalIgnoreErrorsDataset").Device(DEVICE_CPU),
IgnoreErrorsDatasetOp);

@@ -216,6 +216,7 @@ class LMDBDatasetOp : public DatasetOpKernel {
};
};
REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalLMDBDataset").Device(DEVICE_CPU),
LMDBDatasetOp);

@@ -764,9 +764,13 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
bool preserve_cardinality_;
};
REGISTER_KERNEL_BUILDER(Name("MapAndBatchDataset").Device(DEVICE_CPU),
MapAndBatchDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalMapAndBatchDataset").Device(DEVICE_CPU),
MapAndBatchDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("MapAndBatchDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalMapAndBatchDataset");
} // namespace

@@ -366,6 +366,8 @@ class MatchingFilesDatasetOp : public DatasetOpKernel {
};
};
REGISTER_KERNEL_BUILDER(Name("MatchingFilesDataset").Device(DEVICE_CPU),
MatchingFilesDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalMatchingFilesDataset").Device(DEVICE_CPU),
MatchingFilesDatasetOp);

@@ -123,6 +123,8 @@ class NonSerializableDatasetOp : public UnaryDatasetOpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
REGISTER_KERNEL_BUILDER(Name("NonSerializableDataset").Device(DEVICE_CPU),
NonSerializableDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalNonSerializableDataset").Device(DEVICE_CPU),
NonSerializableDatasetOp);

@@ -1069,9 +1069,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalParallelInterleaveDataset").Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalParallelInterleaveDataset");
} // namespace

@@ -397,6 +397,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
std::vector<std::size_t> elements_per_stride_;
};
REGISTER_KERNEL_BUILDER(Name("ParseExampleDataset").Device(DEVICE_CPU),
ParseExampleDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalParseExampleDataset").Device(DEVICE_CPU),
ParseExampleDatasetOp);

@@ -45,6 +45,8 @@ class IteratorGetDeviceOp : public OpKernel {
}
};
REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU),
IteratorGetDeviceOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalIteratorGetDevice").Device(DEVICE_CPU),
IteratorGetDeviceOp);

@@ -154,6 +154,8 @@ class RandomDatasetOp : public DatasetOpKernel {
};
};
REGISTER_KERNEL_BUILDER(Name("RandomDataset").Device(DEVICE_CPU),
RandomDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalRandomDataset").Device(DEVICE_CPU),
RandomDatasetOp);

@@ -64,6 +64,8 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel {
}
};
REGISTER_KERNEL_BUILDER(Name("RebatchDataset").Device(DEVICE_CPU),
RebatchDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalRebatchDataset").Device(DEVICE_CPU),
RebatchDatasetOp);

@@ -290,9 +290,11 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
bool preserve_cardinality_;
};
REGISTER_KERNEL_BUILDER(Name("ScanDataset").Device(DEVICE_CPU), ScanDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalScanDataset").Device(DEVICE_CPU),
ScanDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("ScanDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalScanDataset");
} // namespace

@@ -211,9 +211,12 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
};
};
REGISTER_KERNEL_BUILDER(Name("SetStatsAggregatorDataset").Device(DEVICE_CPU),
SetStatsAggregatorDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalSetStatsAggregatorDataset").Device(DEVICE_CPU),
SetStatsAggregatorDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow

@@ -133,8 +133,17 @@ class SleepDatasetOp : public UnaryDatasetOpKernel {
};
};
REGISTER_KERNEL_BUILDER(Name("SleepDataset").Device(DEVICE_CPU),
SleepDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalSleepDataset").Device(DEVICE_CPU),
SleepDatasetOp);
REGISTER_KERNEL_BUILDER(Name("SleepDataset")
.Device(DEVICE_GPU)
.HostMemory("sleep_microseconds")
.HostMemory("input_dataset")
.HostMemory("handle"),
SleepDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalSleepDataset")
.Device(DEVICE_GPU)
.HostMemory("sleep_microseconds")

@@ -302,6 +302,8 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel {
};
};
REGISTER_KERNEL_BUILDER(Name("SlidingWindowDataset").Device(DEVICE_CPU),
SlidingWindowDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalSlidingWindowDataset").Device(DEVICE_CPU),
SlidingWindowDatasetOp);

@@ -214,6 +214,7 @@ class SqlDatasetOp : public DatasetOpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
REGISTER_KERNEL_BUILDER(Name("SqlDataset").Device(DEVICE_CPU), SqlDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalSqlDataset").Device(DEVICE_CPU),
SqlDatasetOp);

@@ -296,14 +296,21 @@ class StatsAggregatorSetSummaryWriterOp : public OpKernel {
}
};
REGISTER_KERNEL_BUILDER(Name("StatsAggregatorHandle").Device(DEVICE_CPU),
StatsAggregatorHandleOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalStatsAggregatorHandle").Device(DEVICE_CPU),
StatsAggregatorHandleOp);
REGISTER_KERNEL_BUILDER(Name("StatsAggregatorHandleV2").Device(DEVICE_CPU),
StatsAggregatorHandleOpV2);
REGISTER_KERNEL_BUILDER(Name("StatsAggregatorSummary").Device(DEVICE_CPU),
StatsAggregatorSummaryOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalStatsAggregatorSummary").Device(DEVICE_CPU),
StatsAggregatorSummaryOp);
REGISTER_KERNEL_BUILDER(
Name("StatsAggregatorSetSummaryWriter").Device(DEVICE_CPU),
StatsAggregatorSetSummaryWriterOp);

@@ -258,13 +258,18 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
};
};
REGISTER_KERNEL_BUILDER(
Name("ExperimentalLatencyStatsDataset").Device(DEVICE_CPU),
LatencyStatsDatasetOp);
REGISTER_KERNEL_BUILDER(Name("BytesProducedStatsDataset").Device(DEVICE_CPU),
BytesProducedStatsDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalBytesProducedStatsDataset").Device(DEVICE_CPU),
BytesProducedStatsDatasetOp);
REGISTER_KERNEL_BUILDER(Name("LatencyStatsDataset").Device(DEVICE_CPU),
LatencyStatsDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalLatencyStatsDataset").Device(DEVICE_CPU),
LatencyStatsDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow

@@ -198,8 +198,12 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
std::shared_ptr<FunctionMetadata> func_metadata_ = nullptr;
};
REGISTER_KERNEL_BUILDER(Name("TakeWhileDataset").Device(DEVICE_CPU),
TakeWhileDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalTakeWhileDataset").Device(DEVICE_CPU),
TakeWhileDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("TakeWhileDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalTakeWhileDataset");
} // namespace

@@ -431,14 +431,25 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
};
};
REGISTER_KERNEL_BUILDER(Name("MaxIntraOpParallelismDataset").Device(DEVICE_CPU),
MaxIntraOpParallelismDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalMaxIntraOpParallelismDataset").Device(DEVICE_CPU),
MaxIntraOpParallelismDatasetOp);
REGISTER_KERNEL_BUILDER(Name("PrivateThreadPoolDataset").Device(DEVICE_CPU),
PrivateThreadPoolDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalPrivateThreadPoolDataset").Device(DEVICE_CPU),
PrivateThreadPoolDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ThreadPoolHandle").Device(DEVICE_CPU),
ThreadPoolHandleOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU),
ThreadPoolHandleOp);
REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU),
ThreadPoolDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalThreadPoolDataset").Device(DEVICE_CPU),
ThreadPoolDatasetOp);

@@ -100,6 +100,8 @@ class ToTFRecordOp : public AsyncOpKernel {
BackgroundWorker background_worker_;
};
REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU),
ToTFRecordOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalDatasetToTFRecord").Device(DEVICE_CPU), ToTFRecordOp);

@@ -221,6 +221,8 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
};
};
REGISTER_KERNEL_BUILDER(Name("UnbatchDataset").Device(DEVICE_CPU),
UnbatchDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalUnbatchDataset").Device(DEVICE_CPU),
UnbatchDatasetOp);

@@ -221,6 +221,8 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
};
};
REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU),
UniqueDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalUniqueDataset").Device(DEVICE_CPU),
UniqueDatasetOp);

@@ -25428,18 +25428,6 @@ op {
minimum: 1
}
}
op {
name: "ExperimentalIdentityIndexedDataset"
input_arg {
name: "size"
type: DT_UINT64
}
output_arg {
name: "handle"
type: DT_VARIANT
}
is_stateful: true
}
op {
name: "ExperimentalIgnoreErrorsDataset"
input_arg {
@@ -25463,46 +25451,6 @@ op {
minimum: 1
}
}
op {
name: "ExperimentalIndexedDatasetGet"
input_arg {
name: "materialized"
type: DT_RESOURCE
}
input_arg {
name: "index"
type: DT_UINT64
}
output_arg {
name: "components"
type_list_attr: "output_types"
}
attr {
name: "output_types"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "output_shapes"
type: "list(shape)"
has_minimum: true
minimum: 1
}
is_stateful: true
}
op {
name: "ExperimentalIndexedDatasetMaterialize"
input_arg {
name: "dataset"
type: DT_VARIANT
}
input_arg {
name: "materialized"
type: DT_RESOURCE
}
is_stateful: true
}
op {
name: "ExperimentalIteratorGetDevice"
input_arg {
@@ -25774,34 +25722,6 @@ op {
}
is_stateful: true
}
op {
name: "ExperimentalMaterializedIndexDatasetHandle"
output_arg {
name: "handle"
type: DT_RESOURCE
}
attr {
name: "container"
type: "string"
}
attr {
name: "shared_name"
type: "string"
}
attr {
name: "output_types"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "output_shapes"
type: "list(shape)"
has_minimum: true
minimum: 1
}
is_stateful: true
}
op {
name: "ExperimentalMaxIntraOpParallelismDataset"
input_arg {
@@ -25852,109 +25772,6 @@ op {
minimum: 1
}
}
op {
name: "ExperimentalNumaMapAndBatchDataset"
input_arg {
name: "input_dataset"
type: DT_VARIANT
}
input_arg {
name: "other_arguments"
type_list_attr: "Targuments"
}
input_arg {
name: "batch_size"
type: DT_INT64
}
input_arg {
name: "num_parallel_calls"
type: DT_INT64
}
input_arg {
name: "drop_remainder"
type: DT_BOOL
}
output_arg {
name: "handle"
type: DT_VARIANT
}
attr {
name: "f"
type: "func"
}
attr {
name: "Targuments"
type: "list(type)"
has_minimum: true
}
attr {
name: "output_types"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "output_shapes"
type: "list(shape)"
has_minimum: true
minimum: 1
}
}
op {
name: "ExperimentalNumaMapAndBatchDataset"
input_arg {
name: "input_dataset"
type: DT_VARIANT
}
input_arg {
name: "other_arguments"
type_list_attr: "Targuments"
}
input_arg {
name: "batch_size"
type: DT_INT64
}
input_arg {
name: "num_parallel_calls"
type: DT_INT64
}
input_arg {
name: "drop_remainder"
type: DT_BOOL
}
output_arg {
name: "handle"
type: DT_VARIANT
}
attr {
name: "f"
type: "func"
}
attr {
name: "Targuments"
type: "list(type)"
has_minimum: true
}
attr {
name: "output_types"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "output_shapes"
type: "list(shape)"
has_minimum: true
minimum: 1
}
attr {
name: "preserve_cardinality"
type: "bool"
default_value {
b: false
}
}
}
op {
name: "ExperimentalParallelInterleaveDataset"
input_arg {

@@ -23524,18 +23524,6 @@ op {
minimum: 1
}
}
op {
name: "ExperimentalIdentityIndexedDataset"
input_arg {
name: "size"
type: DT_UINT64
}
output_arg {
name: "handle"
type: DT_VARIANT
}
is_stateful: true
}
op {
name: "ExperimentalIgnoreErrorsDataset"
input_arg {
@@ -23559,46 +23547,6 @@ op {
minimum: 1
}
}
op {
name: "ExperimentalIndexedDatasetGet"
input_arg {
name: "materialized"
type: DT_RESOURCE
}
input_arg {
name: "index"
type: DT_UINT64
}
output_arg {
name: "components"
type_list_attr: "output_types"
}
attr {
name: "output_types"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "output_shapes"
type: "list(shape)"
has_minimum: true
minimum: 1
}
is_stateful: true
}
op {
name: "ExperimentalIndexedDatasetMaterialize"
input_arg {
name: "dataset"
type: DT_VARIANT
}
input_arg {
name: "materialized"
type: DT_RESOURCE
}
is_stateful: true
}
op {
name: "ExperimentalIteratorGetDevice"
input_arg {
@@ -23870,34 +23818,6 @@ op {
}
is_stateful: true
}
op {
name: "ExperimentalMaterializedIndexDatasetHandle"
output_arg {
name: "handle"
type: DT_RESOURCE
}
attr {
name: "container"
type: "string"
}
attr {
name: "shared_name"
type: "string"
}
attr {
name: "output_types"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "output_shapes"
type: "list(shape)"
has_minimum: true
minimum: 1
}
is_stateful: true
}
op {
name: "ExperimentalMaxIntraOpParallelismDataset"
input_arg {
@@ -23948,109 +23868,6 @@ op {
minimum: 1
}
}
op {
name: "ExperimentalNumaMapAndBatchDataset"
input_arg {
name: "input_dataset"
type: DT_VARIANT
}
input_arg {
name: "other_arguments"
type_list_attr: "Targuments"
}
input_arg {
name: "batch_size"
type: DT_INT64
}
input_arg {
name: "num_parallel_calls"
type: DT_INT64
}
input_arg {
name: "drop_remainder"
type: DT_BOOL
}
output_arg {
name: "handle"
type: DT_VARIANT
}
attr {
name: "f"
type: "func"
}
attr {
name: "Targuments"
type: "list(type)"
has_minimum: true
}
attr {
name: "output_types"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "output_shapes"
type: "list(shape)"
has_minimum: true
minimum: 1
}
}
op {
name: "ExperimentalNumaMapAndBatchDataset"
input_arg {
name: "input_dataset"
type: DT_VARIANT
}
input_arg {
name: "other_arguments"
type_list_attr: "Targuments"
}
input_arg {
name: "batch_size"
type: DT_INT64
}
input_arg {
name: "num_parallel_calls"
type: DT_INT64
}
input_arg {
name: "drop_remainder"
type: DT_BOOL
}
output_arg {
name: "handle"
type: DT_VARIANT
}
attr {
name: "f"
type: "func"
}
attr {
name: "Targuments"
type: "list(type)"
has_minimum: true
}
attr {
name: "output_types"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "output_shapes"
type: "list(shape)"
has_minimum: true
minimum: 1
}
attr {
name: "preserve_cardinality"
type: "bool"
default_value {
b: false
}
}
}
op {
name: "ExperimentalParallelInterleaveDataset"
input_arg {

@@ -17,10 +17,40 @@ limitations under the License.
namespace tensorflow {
REGISTER_OP("StatsAggregatorSetSummaryWriter")
.Input("stats_aggregator: resource")
.Input("summary: resource")
.SetShapeFn(shape_inference::NoOutputs);
REGISTER_OP("AssertNextDataset")
.Input("input_dataset: variant")
.Input("transformations: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// transformations should be a vector.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ExperimentalAssertNextDataset")
.Input("input_dataset: variant")
.Input("transformations: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// transformations should be a vector.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("AutoShardDataset")
.Input("input_dataset: variant")
.Input("num_workers: int64")
.Input("index: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalAutoShardDataset")
.Input("input_dataset: variant")
@@ -31,6 +61,18 @@ REGISTER_OP("ExperimentalAutoShardDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("BytesProducedStatsDataset")
.Input("input_dataset: variant")
.Input("tag: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle tag_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ExperimentalBytesProducedStatsDataset")
.Input("input_dataset: variant")
.Input("tag: string")
@@ -57,6 +99,66 @@ REGISTER_OP("ChooseFastestBranchDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ChooseFastestDataset")
.Input("input_datasets: N * variant")
.Output("handle: variant")
.Attr("N: int >= 2")
.Attr("num_experiments: int")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalChooseFastestDataset")
.Input("input_datasets: N * variant")
.Output("handle: variant")
.Attr("N: int >= 2")
.Attr("num_experiments: int")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("CSVDataset")
.Input("filenames: string")
.Input("compression_type: string")
.Input("buffer_size: int64")
.Input("header: bool")
.Input("field_delim: string")
.Input("use_quote_delim: bool")
.Input("na_value: string")
.Input("select_cols: int64")
.Input("record_defaults: output_types")
.Output("handle: variant")
.Attr("output_types: list({float,double,int32,int64,string}) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// `filenames` must be a scalar or a vector.
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
// `compression_type`, `buffer_size`, `header`, `field_delim`,
// `use_quote_delim`, `na_value` must be scalars
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
// `select_cols` must be a vector
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
// `record_defaults` must be lists of scalars
for (size_t i = 8; i < c->num_inputs(); ++i) {
shape_inference::ShapeHandle v;
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
return errors::InvalidArgument(
"Shape of a default must be a length-0 or length-1 vector, or a "
"scalar.");
}
}
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ExperimentalCSVDataset")
.Input("filenames: string")
.Input("compression_type: string")
@@ -99,6 +201,11 @@ REGISTER_OP("ExperimentalCSVDataset")
return shape_inference::ScalarShape(c);
});
REGISTER_OP("DatasetCardinality")
.Input("input_dataset: variant")
.Output("cardinality: int64")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalDatasetCardinality")
.Input("input_dataset: variant")
.Output("cardinality: int64")
@@ -108,6 +215,13 @@ REGISTER_OP("ExperimentalDatasetCardinality")
// implement a mechanism to determine whether `dataset` has a side-effect
// and use it to decide whether to use a stateless or stateful version of this
// op.
REGISTER_OP("DatasetToTFRecord")
.Input("input_dataset: variant")
.Input("filename: string")
.Input("compression_type: string")
.SetIsStateful()
.SetShapeFn(shape_inference::NoOutputs);
REGISTER_OP("ExperimentalDatasetToTFRecord")
.Input("input_dataset: variant")
.Input("filename: string")
@@ -115,6 +229,22 @@ REGISTER_OP("ExperimentalDatasetToTFRecord")
.SetIsStateful()
.SetShapeFn(shape_inference::NoOutputs);
REGISTER_OP("DenseToSparseBatchDataset")
.Input("input_dataset: variant")
.Input("batch_size: int64")
.Input("row_shape: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// batch_size should be a scalar.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
// row_shape should be a 1-D vector.
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ExperimentalDenseToSparseBatchDataset")
.Input("input_dataset: variant")
.Input("batch_size: int64")
@@ -131,6 +261,15 @@ REGISTER_OP("ExperimentalDenseToSparseBatchDataset")
return shape_inference::ScalarShape(c);
});
REGISTER_OP("DirectedInterleaveDataset")
.Input("selector_input_dataset: variant")
.Input("data_input_datasets: N * variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("N: int >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalDirectedInterleaveDataset")
.Input("selector_input_dataset: variant")
.Input("data_input_datasets: N * variant")
@@ -140,6 +279,26 @@ REGISTER_OP("ExperimentalDirectedInterleaveDataset")
.Attr("N: int >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("GroupByReducerDataset")
.Input("input_dataset: variant")
.Input("key_func_other_arguments: Tkey_func_other_arguments")
.Input("init_func_other_arguments: Tinit_func_other_arguments")
.Input("reduce_func_other_arguments: Treduce_func_other_arguments")
.Input("finalize_func_other_arguments: Tfinalize_func_other_arguments")
.Output("handle: variant")
.Attr("key_func: func")
.Attr("init_func: func")
.Attr("reduce_func: func")
.Attr("finalize_func: func")
.Attr("Tkey_func_other_arguments: list(type) >= 0")
.Attr("Tinit_func_other_arguments: list(type) >= 0")
.Attr("Treduce_func_other_arguments: list(type) >= 0")
.Attr("Tfinalize_func_other_arguments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalGroupByReducerDataset")
.Input("input_dataset: variant")
.Input("key_func_other_arguments: Tkey_func_other_arguments")
@@ -160,6 +319,23 @@ REGISTER_OP("ExperimentalGroupByReducerDataset")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("GroupByWindowDataset")
.Input("input_dataset: variant")
.Input("key_func_other_arguments: Tkey_func_other_arguments")
.Input("reduce_func_other_arguments: Treduce_func_other_arguments")
.Input(
"window_size_func_other_arguments: Twindow_size_func_other_arguments")
.Output("handle: variant")
.Attr("key_func: func")
.Attr("reduce_func: func")
.Attr("window_size_func: func")
.Attr("Tkey_func_other_arguments: list(type) >= 0")
.Attr("Treduce_func_other_arguments: list(type) >= 0")
.Attr("Twindow_size_func_other_arguments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalGroupByWindowDataset")
.Input("input_dataset: variant")
.Input("key_func_other_arguments: Tkey_func_other_arguments")
@@ -177,6 +353,13 @@ REGISTER_OP("ExperimentalGroupByWindowDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("IgnoreErrorsDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalIgnoreErrorsDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
@@ -184,6 +367,28 @@ REGISTER_OP("ExperimentalIgnoreErrorsDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("IteratorGetDevice")
.Input("resource: resource")
.Output("device: string")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalIteratorGetDevice")
.Input("resource: resource")
.Output("device: string")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("LatencyStatsDataset")
.Input("input_dataset: variant")
.Input("tag: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle tag_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ExperimentalLatencyStatsDataset")
.Input("input_dataset: variant")
.Input("tag: string")
@@ -196,6 +401,51 @@ REGISTER_OP("ExperimentalLatencyStatsDataset")
return shape_inference::ScalarShape(c);
});
REGISTER_OP("LMDBDataset")
.Input("filenames: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalLMDBDataset")
.Input("filenames: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("MapAndBatchDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
.Input("batch_size: int64")
.Input("num_parallel_calls: int64")
.Input("drop_remainder: bool")
.Output("handle: variant")
.Attr("f: func")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("preserve_cardinality: bool = false")
.SetShapeFn([](shape_inference::InferenceContext* c) {
// Use indices from the end to retrieve the input shapes,
// to avoid guessing the length of "other_arguments".
// batch_size, num_parallel_calls, and drop_remainder are 0-D scalars.
shape_inference::ShapeHandle unused;
TF_RETURN_IF_ERROR(
c->WithRank(c->input(c->num_inputs() - 3), 0, &unused));
TF_RETURN_IF_ERROR(
c->WithRank(c->input(c->num_inputs() - 2), 0, &unused));
TF_RETURN_IF_ERROR(
c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ExperimentalMapAndBatchDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
@@ -223,14 +473,6 @@ REGISTER_OP("ExperimentalMapAndBatchDataset")
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ExperimentalRebatchDataset")
.Input("input_dataset: variant")
.Input("num_workers: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalMapDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
@@ -243,6 +485,18 @@ REGISTER_OP("ExperimentalMapDataset")
.Attr("preserve_cardinality: bool = false")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("MatchingFilesDataset")
.Input("patterns: string")
.Output("handle: variant")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// `patterns` must be a scalar or a vector.
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ExperimentalMatchingFilesDataset")
.Input("patterns: string")
.Output("handle: variant")
@@ -255,6 +509,29 @@ REGISTER_OP("ExperimentalMatchingFilesDataset")
return shape_inference::ScalarShape(c);
});
REGISTER_OP("MaxIntraOpParallelismDataset")
.Input("input_dataset: variant")
.Input("max_intra_op_parallelism: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalMaxIntraOpParallelismDataset")
.Input("input_dataset: variant")
.Input("max_intra_op_parallelism: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("NonSerializableDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalNonSerializableDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
@@ -262,6 +539,21 @@ REGISTER_OP("ExperimentalNonSerializableDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ParallelInterleaveDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
.Input("cycle_length: int64")
.Input("block_length: int64")
.Input("sloppy: bool")
.Input("buffer_output_elements: int64")
.Input("prefetch_input_elements: int64")
.Output("handle: variant")
.Attr("f: func")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalParallelInterleaveDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
@@ -277,6 +569,23 @@ REGISTER_OP("ExperimentalParallelInterleaveDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ParseExampleDataset")
.Input("input_dataset: variant")
.Input("num_parallel_calls: int64")
.Input("dense_defaults: Tdense")
.Output("handle: variant")
.Attr("sparse_keys: list(string) >= 0")
.Attr("dense_keys: list(string) >= 0")
.Attr("sparse_types: list({float,int64,string}) >= 0")
.Attr("Tdense: list({float,int64,string}) >= 0")
.Attr("dense_shapes: list(shape) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1") // Output components will be
// sorted by key (dense_keys and
// sparse_keys combined) here.
.Attr("sloppy: bool = false")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalParseExampleDataset")
.Input("input_dataset: variant")
.Input("num_parallel_calls: int64")
@@ -294,6 +603,22 @@ REGISTER_OP("ExperimentalParseExampleDataset")
.Attr("sloppy: bool = false")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("PrivateThreadPoolDataset")
.Input("input_dataset: variant")
.Input("num_threads: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalPrivateThreadPoolDataset")
.Input("input_dataset: variant")
.Input("num_threads: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalRandomDataset")
.Input("seed: int64")
.Input("seed2: int64")
@@ -310,6 +635,68 @@ REGISTER_OP("ExperimentalRandomDataset")
return shape_inference::ScalarShape(c);
});
REGISTER_OP("RandomDataset")
.Input("seed: int64")
.Input("seed2: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// seed and seed2 should be scalars.
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ExperimentalRebatchDataset")
.Input("input_dataset: variant")
.Input("num_workers: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("RebatchDataset")
.Input("input_dataset: variant")
.Input("num_workers: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("SamplingDataset")
.Input("input_dataset: variant")
.Input("rate: float32")
.Input("seed: int64")
.Input("seed2: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// rate, seed, and seed2 should be scalars.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ScanDataset")
.Input("input_dataset: variant")
.Input("initial_state: Tstate")
.Input("other_arguments: Targuments")
.Output("handle: variant")
.Attr("f: func")
.Attr("Tstate: list(type) >= 1")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("preserve_cardinality: bool = false")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalScanDataset")
.Input("input_dataset: variant")
.Input("initial_state: Tstate")
@ -323,6 +710,16 @@ REGISTER_OP("ExperimentalScanDataset")
.Attr("preserve_cardinality: bool = false")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("SetStatsAggregatorDataset")
.Input("input_dataset: variant")
.Input("stats_aggregator: resource")
.Input("tag: string")
.Input("counter_prefix: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalSetStatsAggregatorDataset")
.Input("input_dataset: variant")
.Input("stats_aggregator: resource")
@ -333,6 +730,20 @@ REGISTER_OP("ExperimentalSetStatsAggregatorDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("SleepDataset")
.Input("input_dataset: variant")
.Input("sleep_microseconds: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// Both inputs are scalar.
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 0, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ExperimentalSleepDataset")
.Input("input_dataset: variant")
.Input("sleep_microseconds: int64")
@ -347,6 +758,23 @@ REGISTER_OP("ExperimentalSleepDataset")
return shape_inference::ScalarShape(c);
});
REGISTER_OP("SlidingWindowDataset")
.Input("input_dataset: variant")
.Input("window_size: int64")
.Input("window_shift: int64")
.Input("window_stride: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// window_size, window_shift, and window_stride should be scalars.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ExperimentalSlidingWindowDataset")
.Input("input_dataset: variant")
.Input("window_size: int64")
@ -382,6 +810,24 @@ REGISTER_OP("SnapshotDataset")
return shape_inference::ScalarShape(c);
});
REGISTER_OP("SqlDataset")
.Input("driver_name: string")
.Input("data_source_name: string")
.Input("query: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// driver_name, data_source_name, and query should be scalars.
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ExperimentalSqlDataset")
.Input("driver_name: string")
.Input("data_source_name: string")
@ -400,6 +846,12 @@ REGISTER_OP("ExperimentalSqlDataset")
return shape_inference::ScalarShape(c);
});
REGISTER_OP("StatsAggregatorHandle")
.Output("handle: resource")
.SetShapeFn(shape_inference::ScalarShape)
.Attr("container: string = ''")
.Attr("shared_name: string = ''");
REGISTER_OP("ExperimentalStatsAggregatorHandle")
.Output("handle: resource")
.SetShapeFn(shape_inference::ScalarShape)
@ -412,11 +864,31 @@ REGISTER_OP("StatsAggregatorHandleV2")
.Attr("container: string = ''")
.Attr("shared_name: string = ''");
REGISTER_OP("StatsAggregatorSetSummaryWriter")
.Input("stats_aggregator: resource")
.Input("summary: resource")
.SetShapeFn(shape_inference::NoOutputs);
REGISTER_OP("StatsAggregatorSummary")
.Input("iterator: resource")
.Output("summary: string")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalStatsAggregatorSummary")
.Input("iterator: resource")
.Output("summary: string")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("TakeWhileDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
.Output("handle: variant")
.Attr("predicate: func")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalTakeWhileDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
@ -427,36 +899,9 @@ REGISTER_OP("ExperimentalTakeWhileDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalUnbatchDataset")
REGISTER_OP("ThreadPoolDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalUniqueDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalIteratorGetDevice")
.Input("resource: resource")
.Output("device: string")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalMaxIntraOpParallelismDataset")
.Input("input_dataset: variant")
.Input("max_intra_op_parallelism: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalPrivateThreadPoolDataset")
.Input("input_dataset: variant")
.Input("num_threads: int64")
.Input("thread_pool: resource")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
@ -470,6 +915,15 @@ REGISTER_OP("ExperimentalThreadPoolDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ThreadPoolHandle")
.Output("handle: resource")
.SetShapeFn(shape_inference::ScalarShape)
.Attr("num_threads: int")
.Attr("max_intra_op_parallelism: int = 1")
.Attr("display_name: string")
.Attr("container: string = ''")
.Attr("shared_name: string = ''");
REGISTER_OP("ExperimentalThreadPoolHandle")
.Output("handle: resource")
.SetShapeFn(shape_inference::ScalarShape)
@ -479,136 +933,32 @@ REGISTER_OP("ExperimentalThreadPoolHandle")
.Attr("container: string = ''")
.Attr("shared_name: string = ''");
REGISTER_OP("ExperimentalAssertNextDataset")
REGISTER_OP("UnbatchDataset")
.Input("input_dataset: variant")
.Input("transformations: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// transformations should be a vector.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ExperimentalNumaMapAndBatchDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
.Input("batch_size: int64")
.Input("num_parallel_calls: int64")
.Input("drop_remainder: bool")
.Output("handle: variant")
.Attr("f: func")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("preserve_cardinality: bool = false")
.SetShapeFn([](shape_inference::InferenceContext* c) {
// Use indices from the end to retrieve the input shapes,
// so as to avoid guessing the length of "other_arguments".
// batch_size, num_parallel_calls, and drop_remainder are 0-D scalars.
shape_inference::ShapeHandle unused;
TF_RETURN_IF_ERROR(
c->WithRank(c->input(c->num_inputs() - 3), 0, &unused));
TF_RETURN_IF_ERROR(
c->WithRank(c->input(c->num_inputs() - 2), 0, &unused));
TF_RETURN_IF_ERROR(
c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("ExperimentalLMDBDataset")
.Input("filenames: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalChooseFastestDataset")
.Input("input_datasets: N * variant")
.Output("handle: variant")
.Attr("N: int >= 2")
.Attr("num_experiments: int")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalIdentityIndexedDataset")
.Input("size: uint64")
.Output("handle: variant")
.SetIsStateful()
.SetShapeFn(
shape_inference::ScalarShape); // TODO(saeta): check input shapes.
REGISTER_OP("SamplingDataset")
REGISTER_OP("ExperimentalUnbatchDataset")
.Input("input_dataset: variant")
.Input("rate: float32")
.Input("seed: int64")
.Input("seed2: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// rate, seed, and seed2 should be scalars.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
return shape_inference::ScalarShape(c);
});
///////////////////////////////////////////////////////////////////////////////
// IndexedDataset Internals
///////////////////////////////////////////////////////////////////////////////
// Creates the handle.
REGISTER_OP("ExperimentalMaterializedIndexDatasetHandle")
.Output("handle: resource")
.Attr("container: string")
.Attr("shared_name: string")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
// Actually materializes the dataset into the materialized handle.
REGISTER_OP("ExperimentalIndexedDatasetMaterialize")
.Input("dataset: variant")
.Input("materialized: resource")
.SetShapeFn(shape_inference::NoOutputs);
namespace {
Status GetShapeFn(shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
std::vector<PartialTensorShape> output_shapes;
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
if (output_shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
"`output_shapes` must be the same length as `output_types` (",
output_shapes.size(), " vs. ", c->num_outputs(), ")");
}
for (size_t i = 0; i < output_shapes.size(); ++i) {
shape_inference::ShapeHandle output_shape_handle;
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
output_shapes[i], &output_shape_handle));
c->set_output(static_cast<int>(i), output_shape_handle);
}
return Status::OK();
}
} // namespace
REGISTER_OP("ExperimentalIndexedDatasetGet")
.Input("materialized: resource")
.Input("index: uint64")
.Output("components: output_types")
REGISTER_OP("UniqueDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(GetShapeFn);
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalUniqueDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
} // namespace tensorflow

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
@ -246,11 +247,18 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
dataset_ops.get_legacy_output_types(input_dataset),
tensor_shape.vector(None).concatenate(self._row_shape))
variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._batch_size,
row_shape=convert.partial_shape_to_tensor(self._row_shape),
**self._flat_structure)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._batch_size,
row_shape=convert.partial_shape_to_tensor(self._row_shape),
**self._flat_structure)
else:
variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._batch_size,
row_shape=convert.partial_shape_to_tensor(self._row_shape),
**self._flat_structure)
super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
variant_tensor)
@ -294,15 +302,26 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
lambda component_spec: component_spec._batch(None),
self._map_func.output_structure)
# pylint: enable=protected-access
variant_tensor = ged_ops.experimental_map_and_batch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
f=self._map_func.function,
batch_size=self._batch_size_t,
num_parallel_calls=self._num_parallel_calls_t,
drop_remainder=self._drop_remainder_t,
preserve_cardinality=True,
**self._flat_structure)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = ged_ops.map_and_batch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
f=self._map_func.function,
batch_size=self._batch_size_t,
num_parallel_calls=self._num_parallel_calls_t,
drop_remainder=self._drop_remainder_t,
preserve_cardinality=True,
**self._flat_structure)
else:
variant_tensor = ged_ops.experimental_map_and_batch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
f=self._map_func.function,
batch_size=self._batch_size_t,
num_parallel_calls=self._num_parallel_calls_t,
drop_remainder=self._drop_remainder_t,
preserve_cardinality=True,
**self._flat_structure)
super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)
def _functions(self):
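The gate above is the pattern this CL applies everywhere: once the forward-compatibility horizon of 2019-08-03 passes, the renamed op is emitted; until then the "Experimental"-prefixed op keeps newly saved graphs loadable by older binaries. A minimal sketch of the pattern, with hypothetical op names standing in for the generated ops:

    from tensorflow.python.compat import compat

    def _make_variant(input_variant, flat_structure, ged_ops):
      # Emit the renamed op once the horizon has passed; otherwise fall back
      # to the "Experimental"-prefixed op so binaries built before the rename
      # can still consume graphs produced here.
      if compat.forward_compatible(2019, 8, 3):
        return ged_ops.some_dataset(input_variant, **flat_structure)  # hypothetical op
      else:
        return ged_ops.experimental_some_dataset(input_variant, **flat_structure)  # hypothetical op

Tests can force either branch with the companion context manager, e.g. `with compat.forward_compatibility_horizon(2019, 8, 4):` selects the new-op path.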

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util.tf_export import tf_export
@ -47,4 +48,8 @@ def cardinality(dataset):
the cardinality is infinite or unknown, the operation returns the named
constants `INFINITE_CARDINALITY` and `UNKNOWN_CARDINALITY` respectively.
"""
return ged_ops.experimental_dataset_cardinality(dataset._variant_tensor) # pylint: disable=protected-access
if compat.forward_compatible(2019, 8, 3):
return ged_ops.dataset_cardinality(dataset._variant_tensor) # pylint: disable=protected-access
else:
return ged_ops.experimental_dataset_cardinality(dataset._variant_tensor) # pylint: disable=protected-access
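`cardinality` is surfaced publicly as `tf.data.experimental.cardinality`; a quick illustration of the three possible outcomes (assuming TF 1.14+):

    import tensorflow as tf

    print(tf.data.experimental.cardinality(tf.data.Dataset.range(10).batch(3)))
    # 4 (batches of 3, 3, 3, 1)
    print(tf.data.experimental.cardinality(
        tf.data.Dataset.range(10).filter(lambda x: x % 2 == 0)))
    # UNKNOWN_CARDINALITY (-2): a filter hides the element count
    print(tf.data.experimental.cardinality(tf.data.Dataset.range(10).repeat()))
    # INFINITE_CARDINALITY (-1)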

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
@ -47,11 +48,18 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
self._input_dataset = input_dataset
self._element_spec = input_dataset.element_spec
variant_tensor = ged_ops.experimental_auto_shard_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
num_workers=num_workers,
index=index,
**self._flat_structure)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = ged_ops.auto_shard_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
num_workers=num_workers,
index=index,
**self._flat_structure)
else:
variant_tensor = ged_ops.experimental_auto_shard_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
num_workers=num_workers,
index=index,
**self._flat_structure)
super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)
@property
@ -87,10 +95,16 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
self._element_spec = structure.convert_legacy_structure(
input_types, output_shapes, input_classes)
variant_tensor = ged_ops.experimental_rebatch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
num_workers=num_workers,
**self._flat_structure)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = ged_ops.rebatch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
num_workers=num_workers,
**self._flat_structure)
else:
variant_tensor = ged_ops.experimental_rebatch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
num_workers=num_workers,
**self._flat_structure)
super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
@property

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.tf_export import tf_export
@ -59,8 +60,14 @@ class _IgnoreErrorsDataset(dataset_ops.UnaryUnchangedStructureDataset):
def __init__(self, input_dataset):
"""See `Dataset.ignore_errors()` for details."""
self._input_dataset = input_dataset
variant_tensor = (
gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
**self._flat_structure))
if compat.forward_compatible(2019, 8, 3):
variant_tensor = (
gen_experimental_dataset_ops.ignore_errors_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
**self._flat_structure))
else:
variant_tensor = (
gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
**self._flat_structure))
super(_IgnoreErrorsDataset, self).__init__(input_dataset, variant_tensor)
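`_IgnoreErrorsDataset` backs `tf.data.experimental.ignore_errors()`, which silently drops elements whose processing raises an error. A small sketch:

    import tensorflow as tf

    ds = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
    # 1/0 -> Inf, which check_numerics rejects with InvalidArgumentError.
    ds = ds.map(lambda x: tf.debugging.check_numerics(1. / x, "inf"))
    ds = ds.apply(tf.data.experimental.ignore_errors())
    # Iteration yields 1.0, 0.5, 0.25; the element that raised is dropped.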

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
@ -253,17 +254,30 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
self._make_init_func(reducer.init_func)
self._make_reduce_func(reducer.reduce_func, input_dataset)
self._make_finalize_func(reducer.finalize_func)
variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._key_func.function.captured_inputs,
self._init_func.function.captured_inputs,
self._reduce_func.function.captured_inputs,
self._finalize_func.function.captured_inputs,
key_func=self._key_func.function,
init_func=self._init_func.function,
reduce_func=self._reduce_func.function,
finalize_func=self._finalize_func.function,
**self._flat_structure)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = ged_ops.group_by_reducer_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._key_func.function.captured_inputs,
self._init_func.function.captured_inputs,
self._reduce_func.function.captured_inputs,
self._finalize_func.function.captured_inputs,
key_func=self._key_func.function,
init_func=self._init_func.function,
reduce_func=self._reduce_func.function,
finalize_func=self._finalize_func.function,
**self._flat_structure)
else:
variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._key_func.function.captured_inputs,
self._init_func.function.captured_inputs,
self._reduce_func.function.captured_inputs,
self._finalize_func.function.captured_inputs,
key_func=self._key_func.function,
init_func=self._init_func.function,
reduce_func=self._reduce_func.function,
finalize_func=self._finalize_func.function,
**self._flat_structure)
super(_GroupByReducerDataset, self).__init__(input_dataset, variant_tensor)
def _make_key_func(self, key_func, input_dataset):
@ -375,15 +389,26 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
self._make_key_func(key_func, input_dataset)
self._make_reduce_func(reduce_func, input_dataset)
self._make_window_size_func(window_size_func)
variant_tensor = ged_ops.experimental_group_by_window_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._key_func.function.captured_inputs,
self._reduce_func.function.captured_inputs,
self._window_size_func.function.captured_inputs,
key_func=self._key_func.function,
reduce_func=self._reduce_func.function,
window_size_func=self._window_size_func.function,
**self._flat_structure)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = ged_ops.group_by_window_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._key_func.function.captured_inputs,
self._reduce_func.function.captured_inputs,
self._window_size_func.function.captured_inputs,
key_func=self._key_func.function,
reduce_func=self._reduce_func.function,
window_size_func=self._window_size_func.function,
**self._flat_structure)
else:
variant_tensor = ged_ops.experimental_group_by_window_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._key_func.function.captured_inputs,
self._reduce_func.function.captured_inputs,
self._window_size_func.function.captured_inputs,
key_func=self._key_func.function,
reduce_func=self._reduce_func.function,
window_size_func=self._window_size_func.function,
**self._flat_structure)
super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor)
def _make_window_size_func(self, window_size_func):
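These two classes back `tf.data.experimental.group_by_reducer` and `tf.data.experimental.group_by_window`. For example, grouping a range by parity:

    import tensorflow as tf

    ds = tf.data.Dataset.range(10).apply(
        tf.data.experimental.group_by_window(
            key_func=lambda x: x % 2,                        # 0 for even, 1 for odd
            reduce_func=lambda key, window: window.batch(5),
            window_size=5))
    # Yields the batches [0 2 4 6 8] and [1 3 5 7 9].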

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import random_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
@ -125,11 +126,18 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
# pylint: disable=protected-access
return (
gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
self._selector_input._variant_tensor,
[data_input._variant_tensor for data_input in self._data_inputs],
**self._flat_structure))
if compat.forward_compatible(2019, 8, 3):
return (
gen_experimental_dataset_ops.directed_interleave_dataset(
self._selector_input._variant_tensor,
[data_input._variant_tensor for data_input in self._data_inputs],
**self._flat_structure))
else:
return (
gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
self._selector_input._variant_tensor,
[data_input._variant_tensor for data_input in self._data_inputs],
**self._flat_structure))
# pylint: enable=protected-access
def _inputs(self):

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
@ -31,7 +32,11 @@ class MatchingFilesDataset(dataset_ops.DatasetSource):
def __init__(self, patterns):
self._patterns = ops.convert_to_tensor(
patterns, dtype=dtypes.string, name="patterns")
variant_tensor = ged_ops.experimental_matching_files_dataset(self._patterns)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = ged_ops.matching_files_dataset(self._patterns)
else:
variant_tensor = ged_ops.experimental_matching_files_dataset(
self._patterns)
super(MatchingFilesDataset, self).__init__(variant_tensor)
@property
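`MatchingFilesDataset` is the dataset-native way to expand a glob pattern; through the public API the same effect is usually reached via `tf.data.Dataset.list_files` (the pattern below is hypothetical):

    import tensorflow as tf

    # Deterministic listing of files matching a glob pattern.
    files = tf.data.Dataset.list_files("/tmp/data/*.tfrecord", shuffle=False)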

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -104,11 +105,18 @@ class _AssertNextDataset(dataset_ops.UnaryUnchangedStructureDataset):
raise ValueError("At least one transformation should be specified")
self._transformations = ops.convert_to_tensor(
transformations, dtype=dtypes.string, name="transformations")
variant_tensor = (
gen_experimental_dataset_ops.experimental_assert_next_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._transformations,
**self._flat_structure))
if compat.forward_compatible(2019, 8, 3):
variant_tensor = (
gen_experimental_dataset_ops.assert_next_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._transformations,
**self._flat_structure))
else:
variant_tensor = (
gen_experimental_dataset_ops.experimental_assert_next_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._transformations,
**self._flat_structure))
super(_AssertNextDataset, self).__init__(input_dataset, variant_tensor)
@ -118,10 +126,16 @@ class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset):
def __init__(self, input_dataset):
"""See `non_serializable()` for details."""
self._input_dataset = input_dataset
variant_tensor = (
gen_experimental_dataset_ops.experimental_non_serializable_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
**self._flat_structure))
if compat.forward_compatible(2019, 8, 3):
variant_tensor = (
gen_experimental_dataset_ops.non_serializable_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
**self._flat_structure))
else:
variant_tensor = (
gen_experimental_dataset_ops.experimental_non_serializable_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
**self._flat_structure))
super(_NonSerializableDataset, self).__init__(input_dataset, variant_tensor)
@ -157,11 +171,18 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2):
"""
self._datasets = list(datasets)
self._element_spec = self._datasets[0].element_spec
variant_tensor = (
gen_experimental_dataset_ops.experimental_choose_fastest_dataset(
[dataset._variant_tensor for dataset in self._datasets], # pylint: disable=protected-access
num_experiments=num_experiments,
**self._flat_structure))
if compat.forward_compatible(2019, 8, 3):
variant_tensor = (
gen_experimental_dataset_ops.choose_fastest_dataset(
[dataset._variant_tensor for dataset in self._datasets], # pylint: disable=protected-access
num_experiments=num_experiments,
**self._flat_structure))
else:
variant_tensor = (
gen_experimental_dataset_ops.experimental_choose_fastest_dataset(
[dataset._variant_tensor for dataset in self._datasets], # pylint: disable=protected-access
num_experiments=num_experiments,
**self._flat_structure))
super(_ChooseFastestDataset, self).__init__(variant_tensor)
def _inputs(self):
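`_AssertNextDataset`, `_NonSerializableDataset`, and `_ChooseFastestDataset` are testing helpers; `assert_next` fails at iteration time unless the named transformation immediately follows it in the optimized pipeline. A rough sketch of how the optimization tests use it, assuming map-and-batch fusion is enabled:

    from tensorflow.python.data.experimental.ops import optimization
    from tensorflow.python.data.ops import dataset_ops

    ds = dataset_ops.Dataset.range(100)
    ds = ds.apply(optimization.assert_next(["MapAndBatch"]))  # expect the fused op here
    ds = ds.map(lambda x: x * 2).batch(10)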

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
@ -79,16 +80,28 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
self._element_spec = structure.convert_legacy_structure(
output_types, output_shapes, output_classes)
variant_tensor = (
gen_experimental_dataset_ops.experimental_parse_example_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._num_parallel_calls,
self._dense_defaults,
self._sparse_keys,
self._dense_keys,
self._sparse_types,
self._dense_shapes,
**self._flat_structure))
if compat.forward_compatible(2019, 8, 3):
variant_tensor = (
gen_experimental_dataset_ops.parse_example_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._num_parallel_calls,
self._dense_defaults,
self._sparse_keys,
self._dense_keys,
self._sparse_types,
self._dense_shapes,
**self._flat_structure))
else:
variant_tensor = (
gen_experimental_dataset_ops.experimental_parse_example_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._num_parallel_calls,
self._dense_defaults,
self._sparse_keys,
self._dense_keys,
self._sparse_types,
self._dense_shapes,
**self._flat_structure))
super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
@property
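`_ParseExampleDataset` backs `tf.data.experimental.parse_example_dataset`, which fuses `tf.io.parse_example` into the input pipeline. A sketch (the file name and feature spec are illustrative):

    import tensorflow as tf

    features = {"x": tf.io.FixedLenFeature([], tf.int64, default_value=0),
                "y": tf.io.VarLenFeature(tf.string)}
    ds = tf.data.TFRecordDataset("examples.tfrecord")  # serialized tf.train.Example records
    ds = ds.apply(tf.data.experimental.parse_example_dataset(features, num_parallel_calls=4))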

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import functools
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import random_seed
from tensorflow.python.data.util import structure
@ -34,8 +35,12 @@ class RandomDatasetV2(dataset_ops.DatasetSource):
def __init__(self, seed=None):
"""A `Dataset` of pseudorandom values."""
self._seed, self._seed2 = random_seed.get_seed(seed)
variant_tensor = gen_experimental_dataset_ops.experimental_random_dataset(
seed=self._seed, seed2=self._seed2, **self._flat_structure)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = gen_experimental_dataset_ops.random_dataset(
seed=self._seed, seed2=self._seed2, **self._flat_structure)
else:
variant_tensor = gen_experimental_dataset_ops.experimental_random_dataset(
seed=self._seed, seed2=self._seed2, **self._flat_structure)
super(RandomDatasetV2, self).__init__(variant_tensor)
@property
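Publicly this is `tf.data.experimental.RandomDataset`: an infinite source of pseudorandom int64 scalars, made repeatable by seeding:

    import tensorflow as tf

    ds = tf.data.experimental.RandomDataset(seed=42).take(3)
    # Three pseudorandom tf.int64 scalars; the same seed reproduces the stream.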

View File

@ -23,6 +23,7 @@ import functools
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.experimental.ops import interleave_ops
@ -664,17 +665,30 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
)
self._element_spec = tuple(
structure.TensorStructure(d.dtype, []) for d in self._record_defaults)
variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
filenames=self._filenames,
record_defaults=self._record_defaults,
buffer_size=self._buffer_size,
header=self._header,
output_shapes=self._flat_shapes,
field_delim=self._field_delim,
use_quote_delim=self._use_quote_delim,
na_value=self._na_value,
select_cols=self._select_cols,
compression_type=self._compression_type)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = gen_experimental_dataset_ops.csv_dataset(
filenames=self._filenames,
record_defaults=self._record_defaults,
buffer_size=self._buffer_size,
header=self._header,
output_shapes=self._flat_shapes,
field_delim=self._field_delim,
use_quote_delim=self._use_quote_delim,
na_value=self._na_value,
select_cols=self._select_cols,
compression_type=self._compression_type)
else:
variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
filenames=self._filenames,
record_defaults=self._record_defaults,
buffer_size=self._buffer_size,
header=self._header,
output_shapes=self._flat_shapes,
field_delim=self._field_delim,
use_quote_delim=self._use_quote_delim,
na_value=self._na_value,
select_cols=self._select_cols,
compression_type=self._compression_type)
super(CsvDatasetV2, self).__init__(variant_tensor)
@property
@ -957,9 +971,14 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
query, dtype=dtypes.string, name="query")
self._element_spec = nest.map_structure(
lambda dtype: structure.TensorStructure(dtype, []), output_types)
variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
self._driver_name, self._data_source_name, self._query,
**self._flat_structure)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = gen_experimental_dataset_ops.sql_dataset(
self._driver_name, self._data_source_name, self._query,
**self._flat_structure)
else:
variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
self._driver_name, self._data_source_name, self._query,
**self._flat_structure)
super(SqlDatasetV2, self).__init__(variant_tensor)
@property
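Both readers are public: `CsvDataset` yields one tuple of typed tensors per CSV record, and `SqlDataset` yields one tuple per result row (only the sqlite driver is supported). A sketch with hypothetical file paths:

    import tensorflow as tf

    csv = tf.data.experimental.CsvDataset(
        "data.csv", record_defaults=[tf.float32, tf.int64], header=True)
    sql = tf.data.experimental.SqlDataset(
        "sqlite", "/tmp/people.sqlite3",
        "SELECT name, age FROM people", (tf.string, tf.int32))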

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import collections
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
@ -122,13 +123,22 @@ class _ScanDataset(dataset_ops.UnaryDataset):
self._scan_func = wrapped_func
self._scan_func.function.add_to_graph(ops.get_default_graph())
# pylint: disable=protected-access
variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset(
self._input_dataset._variant_tensor,
structure.to_tensor_list(self._state_structure, self._initial_state),
self._scan_func.function.captured_inputs,
f=self._scan_func.function,
preserve_cardinality=True,
**self._flat_structure)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = gen_experimental_dataset_ops.scan_dataset(
self._input_dataset._variant_tensor,
structure.to_tensor_list(self._state_structure, self._initial_state),
self._scan_func.function.captured_inputs,
f=self._scan_func.function,
preserve_cardinality=True,
**self._flat_structure)
else:
variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset(
self._input_dataset._variant_tensor,
structure.to_tensor_list(self._state_structure, self._initial_state),
self._scan_func.function.captured_inputs,
f=self._scan_func.function,
preserve_cardinality=True,
**self._flat_structure)
super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
def _functions(self):
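`_ScanDataset` backs `tf.data.experimental.scan`, a stateful map: the function threads state through the dataset and emits one output per element. A running sum, for instance:

    import tensorflow as tf

    ds = tf.data.Dataset.range(5).apply(
        tf.data.experimental.scan(
            tf.constant(0, dtype=tf.int64),              # initial state
            lambda state, x: (state + x, state + x)))    # (new_state, output)
    # Yields 0, 1, 3, 6, 10.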

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops
@ -27,10 +28,16 @@ class _SleepDataset(dataset_ops.UnaryUnchangedStructureDataset):
def __init__(self, input_dataset, sleep_microseconds):
self._input_dataset = input_dataset
self._sleep_microseconds = sleep_microseconds
variant_tensor = gen_experimental_dataset_ops.experimental_sleep_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._sleep_microseconds,
**self._flat_structure)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = gen_experimental_dataset_ops.sleep_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._sleep_microseconds,
**self._flat_structure)
else:
variant_tensor = gen_experimental_dataset_ops.experimental_sleep_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._sleep_microseconds,
**self._flat_structure)
super(_SleepDataset, self).__init__(input_dataset, variant_tensor)

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import tempfile
from tensorflow.python.compat import compat
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.util.tf_export import tf_export
@ -125,7 +126,10 @@ class StatsAggregatorV1(object):
def __init__(self):
"""Creates a `StatsAggregator`."""
self._resource = ged_ops.experimental_stats_aggregator_handle()
if compat.forward_compatible(2019, 8, 3):
self._resource = ged_ops.stats_aggregator_handle()
else:
self._resource = ged_ops.experimental_stats_aggregator_handle()
def get_summary(self):
"""Returns a string `tf.Tensor` that summarizes the aggregated statistics.
@ -137,7 +141,10 @@ class StatsAggregatorV1(object):
Returns:
A scalar string `tf.Tensor` that summarizes the aggregated statistics.
"""
return ged_ops.experimental_stats_aggregator_summary(self._resource)
if compat.forward_compatible(2019, 8, 3):
return ged_ops.stats_aggregator_summary(self._resource)
else:
return ged_ops.experimental_stats_aggregator_summary(self._resource)
# TODO(b/116314787): Change this to StatsAggregatorV2 when we have stable

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -65,10 +66,14 @@ def bytes_produced_stats(tag):
"""
def _apply_fn(dataset):
return _StatsDataset(
dataset,
gen_experimental_dataset_ops.experimental_bytes_produced_stats_dataset,
tag)
if compat.forward_compatible(2019, 8, 3):
return _StatsDataset(
dataset, gen_experimental_dataset_ops.bytes_produced_stats_dataset,
tag)
else:
return _StatsDataset(
dataset, gen_experimental_dataset_ops
.experimental_bytes_produced_stats_dataset, tag)
return _apply_fn
@ -90,9 +95,14 @@ def latency_stats(tag):
"""
def _apply_fn(dataset):
return _StatsDataset(
dataset,
gen_experimental_dataset_ops.experimental_latency_stats_dataset, tag)
if compat.forward_compatible(2019, 8, 3):
return _StatsDataset(
dataset,
gen_experimental_dataset_ops.latency_stats_dataset, tag)
else:
return _StatsDataset(
dataset,
gen_experimental_dataset_ops.experimental_latency_stats_dataset, tag)
return _apply_fn
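These transformations record pipeline statistics into a `StatsAggregator`, which at this API generation is attached through dataset options. A hedged sketch, assuming the TF 1.14 `StatsOptions` surface:

    import tensorflow as tf

    aggregator = tf.data.experimental.StatsAggregator()
    ds = tf.data.Dataset.range(1000).apply(
        tf.data.experimental.latency_stats("range_latency"))
    options = tf.data.Options()
    options.experimental_stats.aggregator = aggregator  # assumes StatsOptions.aggregator
    ds = ds.with_options(options)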

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure as structure_lib
from tensorflow.python.framework import dtypes
@ -41,11 +42,18 @@ class _TakeWhileDataset(dataset_ops.UnaryUnchangedStructureDataset):
raise ValueError("`predicate` must return a scalar boolean tensor.")
self._predicate = wrapped_func
var_tensor = gen_experimental_dataset_ops.experimental_take_while_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
other_arguments=self._predicate.function.captured_inputs,
predicate=self._predicate.function,
**self._flat_structure)
if compat.forward_compatible(2019, 8, 3):
var_tensor = gen_experimental_dataset_ops.take_while_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
other_arguments=self._predicate.function.captured_inputs,
predicate=self._predicate.function,
**self._flat_structure)
else:
var_tensor = gen_experimental_dataset_ops.experimental_take_while_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
other_arguments=self._predicate.function.captured_inputs,
predicate=self._predicate.function,
**self._flat_structure)
super(_TakeWhileDataset, self).__init__(input_dataset, var_tensor)
def _functions(self):
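`take_while` consumes elements until the predicate first evaluates false:

    import tensorflow as tf

    ds = tf.data.Dataset.range(10).apply(
        tf.data.experimental.take_while(lambda x: x < 5))
    # Yields 0, 1, 2, 3, 4, then stops.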

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import threading
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
@ -46,18 +47,31 @@ class PrivateThreadPool(object):
"""Creates a `PrivateThreadPool` with the given number of threads."""
if context.executing_eagerly():
shared_name = _generate_shared_name("privatethreadpool")
self._resource = ged_ops.experimental_thread_pool_handle(
num_threads=num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name,
shared_name=shared_name)
if compat.forward_compatible(2019, 8, 3):
self._resource = ged_ops.thread_pool_handle(
num_threads=num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name,
shared_name=shared_name)
else:
self._resource = ged_ops.experimental_thread_pool_handle(
num_threads=num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name,
shared_name=shared_name)
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
handle=self._resource, handle_device=context.context().device_name)
else:
self._resource = ged_ops.experimental_thread_pool_handle(
num_threads=num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name)
if compat.forward_compatible(2019, 8, 3):
self._resource = ged_ops.thread_pool_handle(
num_threads=num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name)
else:
self._resource = ged_ops.experimental_thread_pool_handle(
num_threads=num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name)
class _ThreadPoolDataset(dataset_ops.UnaryUnchangedStructureDataset):
@ -66,10 +80,16 @@ class _ThreadPoolDataset(dataset_ops.UnaryUnchangedStructureDataset):
def __init__(self, input_dataset, thread_pool):
self._input_dataset = input_dataset
self._thread_pool = thread_pool
variant_tensor = ged_ops.experimental_thread_pool_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._thread_pool._resource, # pylint: disable=protected-access
**self._flat_structure)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = ged_ops.thread_pool_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._thread_pool._resource, # pylint: disable=protected-access
**self._flat_structure)
else:
variant_tensor = ged_ops.experimental_thread_pool_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._thread_pool._resource, # pylint: disable=protected-access
**self._flat_structure)
super(_ThreadPoolDataset, self).__init__(input_dataset, variant_tensor)

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_experimental_dataset_ops
@ -59,7 +60,12 @@ class _UniqueDataset(dataset_ops.UnaryUnchangedStructureDataset):
raise TypeError(
"`tf.data.experimental.unique()` only supports inputs with a single "
"`tf.int32`, `tf.int64`, or `tf.string` component.")
variant_tensor = gen_experimental_dataset_ops.experimental_unique_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
**self._flat_structure)
if compat.forward_compatible(2019, 8, 3):
variant_tensor = gen_experimental_dataset_ops.unique_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
**self._flat_structure)
else:
variant_tensor = gen_experimental_dataset_ops.experimental_unique_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
**self._flat_structure)
super(_UniqueDataset, self).__init__(input_dataset, variant_tensor)
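As the error message above says, `tf.data.experimental.unique()` handles a single int32/int64/string component; duplicates are dropped in first-occurrence order:

    import tensorflow as tf

    ds = tf.data.Dataset.from_tensor_slices([1, 1, 2, 3, 3, 1])
    ds = ds.apply(tf.data.experimental.unique())
    # Yields 1, 2, 3.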

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import structure
@ -83,5 +84,9 @@ class TFRecordWriter(object):
"produces shape {0} and types {1}".format(
dataset_ops.get_legacy_output_shapes(dataset),
dataset_ops.get_legacy_output_types(dataset)))
return gen_experimental_dataset_ops.experimental_dataset_to_tf_record(
dataset._variant_tensor, self._filename, self._compression_type) # pylint: disable=protected-access
if compat.forward_compatible(2019, 8, 3):
return gen_experimental_dataset_ops.dataset_to_tf_record(
dataset._variant_tensor, self._filename, self._compression_type) # pylint: disable=protected-access
else:
return gen_experimental_dataset_ops.experimental_dataset_to_tf_record(
dataset._variant_tensor, self._filename, self._compression_type) # pylint: disable=protected-access
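`TFRecordWriter` writes a dataset of scalar strings to disk; in graph mode `write` returns an op to run. A sketch with a hypothetical output path:

    import tensorflow as tf

    ds = tf.data.Dataset.from_tensor_slices([b"a", b"b", b"c"])  # scalar strings
    writer = tf.data.experimental.TFRecordWriter("/tmp/out.tfrecord")
    write_op = writer.write(ds)  # under tf.compat.v1, run write_op in a session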

Some files were not shown because too many files have changed in this diff.