Create dataset kernels as we go i.e. in the __init__ method of the Dataset
class. This removes the _as_variant_tensor() method from the DatasetV2 class (the version going to be used in TF 2.0) and replaces it with a _variant_tensor property that returns the variant_tensor representing the dataset. Also the __init__() method of DatasetV2 now takes a variant_tensor input. For the DatasetV1 class (current API), we run the _as_variant_tensor() method in the __init__() method, so classes subclassing DatasetV1 should make their super() calls in the end. Another implication is for Estimator code. The estimator input_fn's are supposed to be self contained and can't have ops from other graphs (like default graphs) in them. Earlier on because we didn't add anything to the graph while creating the Dataset object, this wasn't an issue but now this is a problem and the dataset creation code now needs to move into the input_fns. A few other changes were required to make this happen 1. The make_one_shot_iterator code captures inputs by value and since now inputs to a dataset could be other datasets which are not capturable, we use the whitelisting mechanism in functions to recreate these ops. 2. The distribution strategies multi-worker code relied on dataset kernel re-creation on different devices while we created the iterator. In the new world, with the kernels already created, we now have to "clone" the dataset on different devices. 3. Auto sharding in distribution strategies is broken with this CL. For now, this CL disables it, but we can subsequently fix it using some of the cloning logic done for 2). 4. AsGraphDefInternal for functions that capture inputs that are datasets now need to be handled differently as DT_VARIANT tensors representing datasets are not serializable. PiperOrigin-RevId: 226115500
This commit is contained in:
parent
c2dfcd9cea
commit
9a63f5b843
@ -489,7 +489,7 @@ class BigtableTable(object):
|
||||
"len(dataset.output_types))")
|
||||
return gen_bigtable_ops.dataset_to_bigtable(
|
||||
self._resource,
|
||||
dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
dataset._variant_tensor, # pylint: disable=protected-access
|
||||
column_families,
|
||||
columns,
|
||||
timestamp)
|
||||
@ -582,13 +582,14 @@ class _BigtableKeyDataset(dataset_ops.DatasetSource):
|
||||
"""_BigtableKeyDataset is an abstract class representing the keys of a table.
|
||||
"""
|
||||
|
||||
def __init__(self, table):
|
||||
def __init__(self, table, variant_tensor):
|
||||
"""Constructs a _BigtableKeyDataset.
|
||||
|
||||
Args:
|
||||
table: a Bigtable class.
|
||||
variant_tensor: DT_VARIANT representation of the dataset.
|
||||
"""
|
||||
super(_BigtableKeyDataset, self).__init__()
|
||||
super(_BigtableKeyDataset, self).__init__(variant_tensor)
|
||||
self._table = table
|
||||
|
||||
@property
|
||||
@ -601,13 +602,11 @@ class _BigtablePrefixKeyDataset(_BigtableKeyDataset):
|
||||
"""
|
||||
|
||||
def __init__(self, table, prefix):
|
||||
super(_BigtablePrefixKeyDataset, self).__init__(table)
|
||||
self._prefix = prefix
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_bigtable_ops.bigtable_prefix_key_dataset(
|
||||
table=self._table._resource, # pylint: disable=protected-access
|
||||
variant_tensor = gen_bigtable_ops.bigtable_prefix_key_dataset(
|
||||
table=table._resource, # pylint: disable=protected-access
|
||||
prefix=self._prefix)
|
||||
super(_BigtablePrefixKeyDataset, self).__init__(table, variant_tensor)
|
||||
|
||||
|
||||
class _BigtableRangeKeyDataset(_BigtableKeyDataset):
|
||||
@ -615,15 +614,13 @@ class _BigtableRangeKeyDataset(_BigtableKeyDataset):
|
||||
"""
|
||||
|
||||
def __init__(self, table, start, end):
|
||||
super(_BigtableRangeKeyDataset, self).__init__(table)
|
||||
self._start = start
|
||||
self._end = end
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_bigtable_ops.bigtable_range_key_dataset(
|
||||
table=self._table._resource, # pylint: disable=protected-access
|
||||
variant_tensor = gen_bigtable_ops.bigtable_range_key_dataset(
|
||||
table=table._resource, # pylint: disable=protected-access
|
||||
start_key=self._start,
|
||||
end_key=self._end)
|
||||
super(_BigtableRangeKeyDataset, self).__init__(table, variant_tensor)
|
||||
|
||||
|
||||
class _BigtableSampleKeysDataset(_BigtableKeyDataset):
|
||||
@ -633,11 +630,9 @@ class _BigtableSampleKeysDataset(_BigtableKeyDataset):
|
||||
# TODO(saeta): Expose the data size offsets into the keys.
|
||||
|
||||
def __init__(self, table):
|
||||
super(_BigtableSampleKeysDataset, self).__init__(table)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_bigtable_ops.bigtable_sample_keys_dataset(
|
||||
table=self._table._resource) # pylint: disable=protected-access
|
||||
variant_tensor = gen_bigtable_ops.bigtable_sample_keys_dataset(
|
||||
table=table._resource) # pylint: disable=protected-access
|
||||
super(_BigtableSampleKeysDataset, self).__init__(table, variant_tensor)
|
||||
|
||||
|
||||
class _BigtableLookupDataset(dataset_ops.DatasetSource):
|
||||
@ -651,20 +646,18 @@ class _BigtableLookupDataset(dataset_ops.DatasetSource):
|
||||
self._normalized = normalized
|
||||
self._column_families = [i[0] for i in normalized]
|
||||
self._columns = [i[1] for i in normalized]
|
||||
variant_tensor = gen_bigtable_ops.bigtable_lookup_dataset(
|
||||
keys_dataset=self._dataset._variant_tensor, # pylint: disable=protected-access
|
||||
table=self._table._resource, # pylint: disable=protected-access
|
||||
column_families=self._column_families,
|
||||
columns=self._columns)
|
||||
super(_BigtableLookupDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return structure.NestedStructure(tuple(
|
||||
[structure.TensorStructure(dtypes.string, [])] * self._num_outputs))
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# pylint: disable=protected-access
|
||||
return gen_bigtable_ops.bigtable_lookup_dataset(
|
||||
keys_dataset=self._dataset._as_variant_tensor(),
|
||||
table=self._table._resource,
|
||||
column_families=self._column_families,
|
||||
columns=self._columns)
|
||||
|
||||
|
||||
class _BigtableScanDataset(dataset_ops.DatasetSource):
|
||||
"""_BigtableScanDataset represents a dataset that retrieves keys and values.
|
||||
@ -679,14 +672,7 @@ class _BigtableScanDataset(dataset_ops.DatasetSource):
|
||||
self._columns = [i[1] for i in normalized]
|
||||
self._probability = probability
|
||||
self._num_outputs = len(normalized) + 1 # 1 for row key
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return structure.NestedStructure(tuple(
|
||||
[structure.TensorStructure(dtypes.string, [])] * self._num_outputs))
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_bigtable_ops.bigtable_scan_dataset(
|
||||
variant_tensor = gen_bigtable_ops.bigtable_scan_dataset(
|
||||
table=self._table._resource, # pylint: disable=protected-access
|
||||
prefix=self._prefix,
|
||||
start_key=self._start,
|
||||
@ -694,6 +680,13 @@ class _BigtableScanDataset(dataset_ops.DatasetSource):
|
||||
column_families=self._column_families,
|
||||
columns=self._columns,
|
||||
probability=self._probability)
|
||||
super(_BigtableScanDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return structure.NestedStructure(
|
||||
tuple(
|
||||
[structure.TensorStructure(dtypes.string, [])] * self._num_outputs))
|
||||
|
||||
|
||||
class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource):
|
||||
@ -705,17 +698,15 @@ class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource):
|
||||
self._prefix = prefix
|
||||
self._start = start
|
||||
self._end = end
|
||||
variant_tensor = gen_bigtable_ops.bigtable_sample_key_pairs_dataset(
|
||||
table=self._table._resource, # pylint: disable=protected-access
|
||||
prefix=self._prefix,
|
||||
start_key=self._start,
|
||||
end_key=self._end)
|
||||
super(_BigtableSampleKeyPairsDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return structure.NestedStructure(
|
||||
(structure.TensorStructure(dtypes.string, []),
|
||||
structure.TensorStructure(dtypes.string, [])))
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# pylint: disable=protected-access
|
||||
return gen_bigtable_ops.bigtable_sample_key_pairs_dataset(
|
||||
table=self._table._resource,
|
||||
prefix=self._prefix,
|
||||
start_key=self._start,
|
||||
end_key=self._end)
|
||||
|
@ -389,13 +389,11 @@ class LMDBDataset(dataset_ops.DatasetSource):
|
||||
Args:
|
||||
filenames: A `tf.string` tensor containing one or more filenames.
|
||||
"""
|
||||
super(LMDBDataset, self).__init__()
|
||||
self._filenames = ops.convert_to_tensor(
|
||||
filenames, dtype=dtypes.string, name="filenames")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_experimental_dataset_ops.experimental_lmdb_dataset(
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_lmdb_dataset(
|
||||
self._filenames, **dataset_ops.flat_structure(self))
|
||||
super(LMDBDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
|
@ -30,7 +30,6 @@ class _SlideDataset(dataset_ops.UnaryDataset):
|
||||
|
||||
def __init__(self, input_dataset, window_size, window_shift, window_stride):
|
||||
"""See `sliding_window_batch` for details."""
|
||||
super(_SlideDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._window_size = ops.convert_to_tensor(
|
||||
window_size, dtype=dtypes.int64, name="window_stride")
|
||||
@ -43,14 +42,13 @@ class _SlideDataset(dataset_ops.UnaryDataset):
|
||||
input_dataset.output_types, input_dataset.output_shapes,
|
||||
input_dataset.output_classes)
|
||||
self._structure = input_structure._batch(None) # pylint: disable=protected-access
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return ged_ops.experimental_sliding_window_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = ged_ops.experimental_sliding_window_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
window_size=self._window_size,
|
||||
window_shift=self._window_shift,
|
||||
window_stride=self._window_stride,
|
||||
**dataset_ops.flat_structure(self))
|
||||
super(_SlideDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
|
@ -472,11 +472,11 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
|
||||
for r in range(len(devices))])
|
||||
|
||||
def _test_dataset(self, dataset_fn, worker_devices, devices,
|
||||
expected_values, auto_shard=True):
|
||||
expected_values):
|
||||
device_map = values.ReplicaDeviceMap(devices)
|
||||
input_workers = values.InputWorkers(device_map, worker_devices)
|
||||
multi_worker_dataset = values.MultiWorkerDataset(
|
||||
dataset_fn, input_workers, auto_shard=auto_shard)
|
||||
dataset_fn, input_workers)
|
||||
multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
|
||||
with self.cached_session() as sess:
|
||||
sess.run(multi_worker_iterator.initializer)
|
||||
@ -518,16 +518,9 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
|
||||
worker_devices, devices = self._cpu_devices()
|
||||
with context.graph_mode():
|
||||
dataset_fn = lambda: dataset_ops.Dataset.range(8)
|
||||
self._test_dataset(dataset_fn, worker_devices, devices,
|
||||
[[0, 1], [2, 3], [4, 5], [6, 7]])
|
||||
|
||||
def testDataDistributionNoAutoShard(self):
|
||||
worker_devices, devices = self._cpu_devices()
|
||||
with context.graph_mode():
|
||||
dataset_fn = lambda: dataset_ops.Dataset.range(4)
|
||||
self._test_dataset(dataset_fn, worker_devices, devices,
|
||||
[[0, 0], [1, 1], [2, 2], [3, 3]],
|
||||
auto_shard=False)
|
||||
self._test_dataset(
|
||||
dataset_fn, worker_devices, devices,
|
||||
[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
|
||||
|
||||
def testDataDistributionTwoDevicePerWorker(self):
|
||||
if context.num_gpus() < 1:
|
||||
@ -535,8 +528,9 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
|
||||
worker_devices, devices = self._cpu_and_one_gpu_devices()
|
||||
with context.graph_mode():
|
||||
dataset_fn = lambda: dataset_ops.Dataset.range(8)
|
||||
self._test_dataset(dataset_fn, worker_devices, devices,
|
||||
[[0, 2, 1, 3], [4, 6, 5, 7]])
|
||||
self._test_dataset(
|
||||
dataset_fn, worker_devices, devices,
|
||||
[[0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [6, 7, 6, 7]])
|
||||
|
||||
def testTupleDataset(self):
|
||||
worker_devices, devices = self._cpu_devices()
|
||||
@ -548,9 +542,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
|
||||
dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2)
|
||||
return dataset_ops.Dataset.zip((dataset1, dataset2))
|
||||
|
||||
expected_values = [
|
||||
[(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 8, 2)
|
||||
]
|
||||
expected_values = [[(i, i**2), (i, i**2)] for i in range(8)]
|
||||
self._test_dataset(dataset_fn, worker_devices, devices,
|
||||
expected_values)
|
||||
|
||||
@ -561,17 +553,19 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
|
||||
device_map = values.ReplicaDeviceMap(devices)
|
||||
input_workers = values.InputWorkers(device_map, worker_devices)
|
||||
multi_worker_dataset = values.MultiWorkerDataset(
|
||||
dataset_fn, input_workers, auto_shard=True)
|
||||
dataset_fn, input_workers)
|
||||
multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
|
||||
|
||||
sess.run(multi_worker_iterator.initializer)
|
||||
self._test_iterator(sess, multi_worker_iterator, devices,
|
||||
[[0, 1], [2, 3], [4, 5], [6, 7]])
|
||||
self._test_iterator(
|
||||
sess, multi_worker_iterator, devices,
|
||||
[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
|
||||
|
||||
# After re-initializing the iterator, should be able to iterate again.
|
||||
sess.run(multi_worker_iterator.initializer)
|
||||
self._test_iterator(sess, multi_worker_iterator, devices,
|
||||
[[0, 1], [2, 3], [4, 5], [6, 7]])
|
||||
self._test_iterator(
|
||||
sess, multi_worker_iterator, devices,
|
||||
[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
|
||||
|
||||
def testValueErrorForIterator(self):
|
||||
# Incompatiable arguments.
|
||||
|
@ -55,13 +55,11 @@ class SequenceFileDataset(dataset_ops.DatasetSource):
|
||||
Args:
|
||||
filenames: A `tf.string` tensor containing one or more filenames.
|
||||
"""
|
||||
super(SequenceFileDataset, self).__init__()
|
||||
self._filenames = ops.convert_to_tensor(
|
||||
filenames, dtype=dtypes.string, name="filenames")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.sequence_file_dataset(
|
||||
variant_tensor = gen_dataset_ops.sequence_file_dataset(
|
||||
self._filenames, self._element_structure._flat_types) # pylint: disable=protected-access
|
||||
super(SequenceFileDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.python.client import session
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import readers
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.lib.io import python_io
|
||||
from tensorflow.python.platform import test
|
||||
@ -55,6 +56,7 @@ class DatasetsTest(test.TestCase):
|
||||
session_config = config_pb2.ConfigProto(cluster_def=self._cluster_def)
|
||||
|
||||
self._sess = session.Session(self._worker.target, config=session_config)
|
||||
self._worker_device = '/job:' + worker_job.name
|
||||
|
||||
def testTextLineDataset(self):
|
||||
all_contents = []
|
||||
@ -70,7 +72,8 @@ class DatasetsTest(test.TestCase):
|
||||
dataset = datasets.StreamingFilesDataset(
|
||||
os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text')
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(dataset)
|
||||
with ops.device(self._worker_device):
|
||||
iterator = dataset_ops.make_initializable_iterator(dataset)
|
||||
self._sess.run(iterator.initializer)
|
||||
get_next = iterator.get_next()
|
||||
|
||||
@ -94,7 +97,8 @@ class DatasetsTest(test.TestCase):
|
||||
dataset = datasets.StreamingFilesDataset(
|
||||
os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord')
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(dataset)
|
||||
with ops.device(self._worker_device):
|
||||
iterator = dataset_ops.make_initializable_iterator(dataset)
|
||||
self._sess.run(iterator.initializer)
|
||||
get_next = iterator.get_next()
|
||||
|
||||
@ -121,7 +125,8 @@ class DatasetsTest(test.TestCase):
|
||||
|
||||
dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord')
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(dataset)
|
||||
with ops.device(self._worker_device):
|
||||
iterator = dataset_ops.make_initializable_iterator(dataset)
|
||||
self._sess.run(iterator.initializer)
|
||||
get_next = iterator.get_next()
|
||||
|
||||
@ -154,7 +159,8 @@ class DatasetsTest(test.TestCase):
|
||||
os.path.join(self.get_temp_dir(), 'fixed_length*'),
|
||||
filetype=FixedLengthFile)
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(dataset)
|
||||
with ops.device(self._worker_device):
|
||||
iterator = dataset_ops.make_initializable_iterator(dataset)
|
||||
self._sess.run(iterator.initializer)
|
||||
get_next = iterator.get_next()
|
||||
|
||||
@ -177,7 +183,8 @@ class DatasetsTest(test.TestCase):
|
||||
dataset = datasets.StreamingFilesDataset(
|
||||
dataset_ops.Dataset.range(10), filetype=gen_dataset)
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(dataset)
|
||||
with ops.device(self._worker_device):
|
||||
iterator = dataset_ops.make_initializable_iterator(dataset)
|
||||
self._sess.run(iterator.initializer)
|
||||
get_next = iterator.get_next()
|
||||
|
||||
|
@ -119,25 +119,25 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::vector<Node*> key_func_other_arguments_node;
|
||||
DataTypeVector key_func_other_arguments_types;
|
||||
TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
|
||||
b, captured_key_func_, &key_func_other_arguments_node,
|
||||
ctx, b, captured_key_func_, &key_func_other_arguments_node,
|
||||
&key_func_other_arguments_types));
|
||||
|
||||
std::vector<Node*> init_func_other_arguments_node;
|
||||
DataTypeVector init_func_other_arguments_types;
|
||||
TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
|
||||
b, captured_init_func_, &init_func_other_arguments_node,
|
||||
ctx, b, captured_init_func_, &init_func_other_arguments_node,
|
||||
&init_func_other_arguments_types));
|
||||
|
||||
std::vector<Node*> reduce_func_other_arguments_node;
|
||||
DataTypeVector reduce_func_other_arguments_types;
|
||||
TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
|
||||
b, captured_reduce_func_, &reduce_func_other_arguments_node,
|
||||
ctx, b, captured_reduce_func_, &reduce_func_other_arguments_node,
|
||||
&reduce_func_other_arguments_types));
|
||||
|
||||
std::vector<Node*> finalize_func_other_arguments_node;
|
||||
DataTypeVector finalize_func_other_arguments_types;
|
||||
TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
|
||||
b, captured_finalize_func_, &finalize_func_other_arguments_node,
|
||||
ctx, b, captured_finalize_func_, &finalize_func_other_arguments_node,
|
||||
&finalize_func_other_arguments_types));
|
||||
|
||||
AttrValue key_func;
|
||||
@ -406,7 +406,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
Status OtherArgumentsNodeAndType(
|
||||
DatasetGraphDefBuilder* b,
|
||||
SerializationContext* ctx, DatasetGraphDefBuilder* b,
|
||||
const std::unique_ptr<CapturedFunction>& captured_func,
|
||||
std::vector<Node*>* other_arguments_node,
|
||||
DataTypeVector* other_arguments_types) const {
|
||||
@ -414,7 +414,13 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
|
||||
other_arguments_types->reserve(captured_func->captured_inputs().size());
|
||||
for (const Tensor& t : captured_func->captured_inputs()) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
DatasetBase* input;
|
||||
Status s = GetDatasetFromVariantTensor(t, &input);
|
||||
if (s.ok()) {
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
}
|
||||
other_arguments_node->emplace_back(node);
|
||||
other_arguments_types->emplace_back(t.dtype());
|
||||
}
|
||||
|
@ -117,20 +117,21 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::vector<Node*> key_func_other_arguments_node;
|
||||
DataTypeVector key_func_other_arguments_types;
|
||||
TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
|
||||
b, captured_key_func_, &key_func_other_arguments_node,
|
||||
ctx, b, captured_key_func_, &key_func_other_arguments_node,
|
||||
&key_func_other_arguments_types));
|
||||
|
||||
std::vector<Node*> reduce_func_other_arguments_node;
|
||||
DataTypeVector reduce_func_other_arguments_types;
|
||||
TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
|
||||
b, captured_reduce_func_, &reduce_func_other_arguments_node,
|
||||
ctx, b, captured_reduce_func_, &reduce_func_other_arguments_node,
|
||||
&reduce_func_other_arguments_types));
|
||||
|
||||
std::vector<Node*> window_size_func_other_arguments_node;
|
||||
DataTypeVector window_size_func_other_arguments_types;
|
||||
TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
|
||||
b, captured_window_size_func_, &window_size_func_other_arguments_node,
|
||||
&window_size_func_other_arguments_types));
|
||||
TF_RETURN_IF_ERROR(
|
||||
OtherArgumentsNodeAndType(ctx, b, captured_window_size_func_,
|
||||
&window_size_func_other_arguments_node,
|
||||
&window_size_func_other_arguments_types));
|
||||
|
||||
AttrValue key_func;
|
||||
b->BuildAttrValue(key_func_, &key_func);
|
||||
@ -490,7 +491,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
};
|
||||
|
||||
Status OtherArgumentsNodeAndType(
|
||||
DatasetGraphDefBuilder* b,
|
||||
SerializationContext* ctx, DatasetGraphDefBuilder* b,
|
||||
const std::unique_ptr<CapturedFunction>& captured_func,
|
||||
std::vector<Node*>* other_arguments_node,
|
||||
DataTypeVector* other_arguments_types) const {
|
||||
@ -498,7 +499,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
other_arguments_types->reserve(captured_func->captured_inputs().size());
|
||||
for (const Tensor& t : captured_func->captured_inputs()) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
DatasetBase* input;
|
||||
Status s = GetDatasetFromVariantTensor(t, &input);
|
||||
if (s.ok()) {
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
}
|
||||
other_arguments_node->emplace_back(node);
|
||||
other_arguments_types->emplace_back(t.dtype());
|
||||
}
|
||||
|
@ -210,7 +210,13 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
other_arguments.reserve(captured_func_->captured_inputs().size());
|
||||
for (const Tensor& t : captured_func_->captured_inputs()) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
DatasetBase* input;
|
||||
Status s = GetDatasetFromVariantTensor(t, &input);
|
||||
if (s.ok()) {
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
}
|
||||
other_arguments.emplace_back(node);
|
||||
other_arguments_types.emplace_back(t.dtype());
|
||||
}
|
||||
|
@ -169,7 +169,13 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
other_arguments.reserve(captured_func_->captured_inputs().size());
|
||||
for (const Tensor& t : captured_func_->captured_inputs()) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
DatasetBase* input;
|
||||
Status s = GetDatasetFromVariantTensor(t, &input);
|
||||
if (s.ok()) {
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
}
|
||||
other_arguments.emplace_back(node);
|
||||
other_arguments_types.emplace_back(t.dtype());
|
||||
}
|
||||
|
@ -154,7 +154,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
other_arguments.reserve(captured_func_->captured_inputs().size());
|
||||
for (const Tensor& t : captured_func_->captured_inputs()) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
DatasetBase* input;
|
||||
Status s = GetDatasetFromVariantTensor(t, &input);
|
||||
if (s.ok()) {
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
}
|
||||
other_arguments.emplace_back(node);
|
||||
other_arguments_types.emplace_back(t.dtype());
|
||||
}
|
||||
|
@ -119,7 +119,13 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
|
||||
other_arguments_types.reserve(captured_func_->captured_inputs().size());
|
||||
for (const Tensor& t : captured_func_->captured_inputs()) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
DatasetBase* input;
|
||||
Status s = GetDatasetFromVariantTensor(t, &input);
|
||||
if (s.ok()) {
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
}
|
||||
other_arguments.emplace_back(node);
|
||||
other_arguments_types.emplace_back(t.dtype());
|
||||
}
|
||||
|
@ -137,7 +137,13 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
|
||||
other_arguments.reserve(captured_func_->captured_inputs().size());
|
||||
for (const Tensor& t : captured_func_->captured_inputs()) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
DatasetBase* input;
|
||||
Status s = GetDatasetFromVariantTensor(t, &input);
|
||||
if (s.ok()) {
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
}
|
||||
other_arguments.emplace_back(node);
|
||||
other_arguments_types.emplace_back(t.dtype());
|
||||
}
|
||||
|
@ -95,7 +95,13 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
|
||||
other_arguments.reserve(captured_func_->captured_inputs().size());
|
||||
for (const Tensor& t : captured_func_->captured_inputs()) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
DatasetBase* input;
|
||||
Status s = GetDatasetFromVariantTensor(t, &input);
|
||||
if (s.ok()) {
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
}
|
||||
other_arguments.emplace_back(node);
|
||||
other_arguments_types.emplace_back(t.dtype());
|
||||
}
|
||||
|
@ -121,7 +121,13 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
other_arguments.reserve(captured_func_->captured_inputs().size());
|
||||
for (const Tensor& t : captured_func_->captured_inputs()) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
DatasetBase* input;
|
||||
Status s = GetDatasetFromVariantTensor(t, &input);
|
||||
if (s.ok()) {
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
}
|
||||
other_arguments.emplace_back(node);
|
||||
other_arguments_types.emplace_back(t.dtype());
|
||||
}
|
||||
|
@ -149,7 +149,13 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
|
||||
other_arguments.reserve(captured_func_->captured_inputs().size());
|
||||
for (const Tensor& t : captured_func_->captured_inputs()) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
DatasetBase* input;
|
||||
Status s = GetDatasetFromVariantTensor(t, &input);
|
||||
if (s.ok()) {
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
}
|
||||
other_arguments.emplace_back(node);
|
||||
other_arguments_types.emplace_back(t.dtype());
|
||||
}
|
||||
|
@ -160,7 +160,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
other_arguments.reserve(captured_func_->captured_inputs().size());
|
||||
for (const Tensor& t : captured_func_->captured_inputs()) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
DatasetBase* input;
|
||||
Status s = GetDatasetFromVariantTensor(t, &input);
|
||||
if (s.ok()) {
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
}
|
||||
other_arguments.emplace_back(node);
|
||||
other_arguments_types.emplace_back(t.dtype());
|
||||
}
|
||||
|
@ -141,7 +141,13 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
|
||||
other_arguments.reserve(captured_func_->captured_inputs().size());
|
||||
for (const Tensor& t : captured_func_->captured_inputs()) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
DatasetBase* input;
|
||||
Status s = GetDatasetFromVariantTensor(t, &input);
|
||||
if (s.ok()) {
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
}
|
||||
other_arguments.emplace_back(node);
|
||||
other_arguments_types.emplace_back(t.dtype());
|
||||
}
|
||||
|
@ -89,14 +89,12 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(nxt())
|
||||
else:
|
||||
# Verify that OpError is produced as expected
|
||||
with self.assertRaisesOpError(expected_err_re):
|
||||
nxt = self.getNext(dataset)
|
||||
while True:
|
||||
try:
|
||||
self.evaluate(nxt())
|
||||
except errors.OutOfRangeError:
|
||||
break
|
||||
nxt = self.getNext(dataset)
|
||||
while True:
|
||||
try:
|
||||
self.evaluate(nxt())
|
||||
except errors.OutOfRangeError:
|
||||
break
|
||||
|
||||
def _test_dataset(
|
||||
self,
|
||||
@ -110,8 +108,14 @@ class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
# Convert str type because py3 tf strings are bytestrings
|
||||
filenames = self._setup_files(inputs, linebreak, compression_type)
|
||||
kwargs['compression_type'] = compression_type
|
||||
dataset = readers.CsvDataset(filenames, **kwargs)
|
||||
self._verify_output_or_err(dataset, expected_output, expected_err_re)
|
||||
if expected_err_re is not None:
|
||||
# Verify that OpError is produced as expected
|
||||
with self.assertRaisesOpError(expected_err_re):
|
||||
dataset = readers.CsvDataset(filenames, **kwargs)
|
||||
self._verify_output_or_err(dataset, expected_output, expected_err_re)
|
||||
else:
|
||||
dataset = readers.CsvDataset(filenames, **kwargs)
|
||||
self._verify_output_or_err(dataset, expected_output, expected_err_re)
|
||||
|
||||
def testCsvDataset_requiredFields(self):
|
||||
record_defaults = [[]] * 4
|
||||
|
@ -120,8 +120,8 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertDatasetProduces(dataset_fn(8, 0), expected_output=[])
|
||||
|
||||
# Empty batch should be an initialization time error.
|
||||
self.assertDatasetProduces(
|
||||
dataset_fn(0, 14), expected_error=(errors.InvalidArgumentError, ""))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.assertDatasetProduces(dataset_fn(0, 14), expected_output=[])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Even", False, False),
|
||||
|
@ -211,16 +211,15 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
"v", initializer=0, use_resource=False)
|
||||
assign_op = variable.assign_add(1)
|
||||
|
||||
unoptimized_dataset = dataset_fn(variable)
|
||||
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_optimization.noop_elimination = True
|
||||
options.experimental_optimization.map_and_batch_fusion = True
|
||||
optimized_dataset = unoptimized_dataset.with_options(options)
|
||||
|
||||
# Check that warning is logged.
|
||||
warnings.simplefilter("always")
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
unoptimized_dataset = dataset_fn(variable)
|
||||
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_optimization.noop_elimination = True
|
||||
options.experimental_optimization.map_and_batch_fusion = True
|
||||
optimized_dataset = unoptimized_dataset.with_options(options)
|
||||
optimized_it = optimized_dataset.make_initializable_iterator()
|
||||
|
||||
self.assertGreaterEqual(len(w), 1)
|
||||
|
@ -110,13 +110,13 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that an error is raised when `driver_name` is invalid.
|
||||
def testReadResultSetWithInvalidDriverName(self):
|
||||
dataset = self._createSqlDataset(
|
||||
driver_name="sqlfake",
|
||||
query="SELECT first_name, last_name, motto FROM students "
|
||||
"ORDER BY first_name DESC",
|
||||
output_types=(dtypes.string, dtypes.string, dtypes.string))
|
||||
self.assertDatasetProduces(
|
||||
dataset, expected_error=(errors.InvalidArgumentError, ""))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
dataset = self._createSqlDataset(
|
||||
driver_name="sqlfake",
|
||||
query="SELECT first_name, last_name, motto FROM students "
|
||||
"ORDER BY first_name DESC",
|
||||
output_types=(dtypes.string, dtypes.string, dtypes.string))
|
||||
self.assertDatasetProduces(dataset, expected_output=[])
|
||||
|
||||
# Test that an error is raised when a column name in `query` is nonexistent
|
||||
def testReadResultSetWithInvalidColumnName(self):
|
||||
|
@ -197,10 +197,13 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
def testInterleaveAutoTuneBufferUtilization(self, dataset_transformation):
|
||||
|
||||
def dataset_fn():
|
||||
dataset = dataset_ops.Dataset.range(10).map(
|
||||
lambda x: array_ops.tile([x], ops.convert_to_tensor([x])))
|
||||
|
||||
def interleave_fn(_):
|
||||
return dataset_ops.Dataset.range(
|
||||
10).map(lambda x: array_ops.tile([x], ops.convert_to_tensor([x])))
|
||||
|
||||
dataset = dataset_ops.Dataset.range(1).interleave(
|
||||
lambda _: dataset,
|
||||
interleave_fn,
|
||||
cycle_length=1,
|
||||
num_parallel_calls=optimization.AUTOTUNE)
|
||||
options = dataset_ops.Options()
|
||||
|
@ -352,7 +352,6 @@ class _UnbatchDataset(dataset_ops.UnaryDataset):
|
||||
|
||||
def __init__(self, input_dataset):
|
||||
"""See `unbatch()` for more details."""
|
||||
super(_UnbatchDataset, self).__init__(input_dataset)
|
||||
flat_shapes = nest.flatten(input_dataset.output_shapes)
|
||||
if any(s.ndims == 0 for s in flat_shapes):
|
||||
raise ValueError("Cannot unbatch an input with scalar components.")
|
||||
@ -370,10 +369,10 @@ class _UnbatchDataset(dataset_ops.UnaryDataset):
|
||||
nest.map_structure(lambda s: s[1:], input_dataset.output_shapes),
|
||||
input_dataset.output_classes)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return ged_ops.experimental_unbatch_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = ged_ops.experimental_unbatch_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
**dataset_ops.flat_structure(self))
|
||||
super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
@ -440,7 +439,6 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
|
||||
|
||||
def __init__(self, input_dataset, batch_size, row_shape):
|
||||
"""See `Dataset.dense_to_sparse_batch()` for more details."""
|
||||
super(_DenseToSparseBatchDataset, self).__init__(input_dataset)
|
||||
if not isinstance(input_dataset.output_types, dtypes.DType):
|
||||
raise TypeError("DenseToSparseDataset requires an input whose elements "
|
||||
"have a single component, whereas the input has %r." %
|
||||
@ -452,12 +450,13 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
|
||||
input_dataset.output_types,
|
||||
tensor_shape.vector(None).concatenate(self._row_shape))
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return ged_ops.experimental_dense_to_sparse_batch_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._batch_size,
|
||||
row_shape=convert.partial_shape_to_tensor(self._row_shape),
|
||||
**dataset_ops.flat_structure(self))
|
||||
super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
|
||||
variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
@ -499,7 +498,6 @@ class _RestructuredDataset(dataset_ops.UnaryDataset):
|
||||
ValueError: If either `output_types` or `output_shapes` is not compatible
|
||||
with the structure of `dataset`.
|
||||
"""
|
||||
super(_RestructuredDataset, self).__init__(dataset)
|
||||
self._input_dataset = dataset
|
||||
|
||||
if not allow_unsafe_cast:
|
||||
@ -539,9 +537,8 @@ class _RestructuredDataset(dataset_ops.UnaryDataset):
|
||||
|
||||
self._structure = structure.convert_legacy_structure(
|
||||
output_types, output_shapes, output_classes)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
|
||||
variant_tensor = self._input_dataset._variant_tensor # pylint: disable=protected-access
|
||||
super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
@ -554,8 +551,8 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
|
||||
def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
|
||||
drop_remainder):
|
||||
"""See `Dataset.map()` for details."""
|
||||
super(_MapAndBatchDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
|
||||
self._map_func = dataset_ops.StructuredFunctionWrapper(
|
||||
map_func, "tf.data.experimental.map_and_batch()", dataset=input_dataset)
|
||||
self._batch_size_t = ops.convert_to_tensor(
|
||||
@ -573,14 +570,8 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
|
||||
tensor_util.constant_value(self._batch_size_t))
|
||||
else:
|
||||
self._structure = self._map_func.output_structure._batch(None) # pylint: disable=protected-access
|
||||
|
||||
def _functions(self):
|
||||
return [self._map_func]
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# pylint: disable=protected-access
|
||||
return ged_ops.experimental_map_and_batch_dataset(
|
||||
self._input_dataset._as_variant_tensor(),
|
||||
variant_tensor = ged_ops.experimental_map_and_batch_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
f=self._map_func.function,
|
||||
batch_size=self._batch_size_t,
|
||||
@ -588,6 +579,10 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
|
||||
drop_remainder=self._drop_remainder_t,
|
||||
preserve_cardinality=True,
|
||||
**dataset_ops.flat_structure(self))
|
||||
super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
|
@ -47,4 +47,4 @@ def cardinality(dataset):
|
||||
the cardinality is infinite or unknown, the operation returns the named
|
||||
constant `INFINITE_CARDINALITY` and `UNKNOWN_CARDINALITY` respectively.
|
||||
"""
|
||||
return ged_ops.experimental_dataset_cardinality(dataset._as_variant_tensor()) # pylint: disable=protected-access
|
||||
return ged_ops.experimental_dataset_cardinality(dataset._variant_tensor) # pylint: disable=protected-access
|
||||
|
@ -57,10 +57,9 @@ class _IgnoreErrorsDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
|
||||
def __init__(self, input_dataset):
|
||||
"""See `Dataset.ignore_errors()` for details."""
|
||||
super(_IgnoreErrorsDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
**dataset_ops.flat_structure(self))
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
**dataset_ops.flat_structure(self)))
|
||||
super(_IgnoreErrorsDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
@ -64,5 +64,4 @@ def get_single_element(dataset):
|
||||
# pylint: disable=protected-access
|
||||
return dataset._element_structure._from_compatible_tensor_list(
|
||||
gen_dataset_ops.dataset_to_single_element(
|
||||
dataset._as_variant_tensor(),
|
||||
**dataset_ops.flat_structure(dataset)))
|
||||
dataset._variant_tensor, **dataset_ops.flat_structure(dataset)))
|
||||
|
@ -242,14 +242,23 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
|
||||
|
||||
def __init__(self, input_dataset, key_func, reducer):
|
||||
"""See `group_by_reducer()` for details."""
|
||||
super(_GroupByReducerDataset, self).__init__(input_dataset)
|
||||
|
||||
self._input_dataset = input_dataset
|
||||
|
||||
self._make_key_func(key_func, input_dataset)
|
||||
self._make_init_func(reducer.init_func)
|
||||
self._make_reduce_func(reducer.reduce_func, input_dataset)
|
||||
self._make_finalize_func(reducer.finalize_func)
|
||||
variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._key_func.function.captured_inputs,
|
||||
self._init_func.function.captured_inputs,
|
||||
self._reduce_func.function.captured_inputs,
|
||||
self._finalize_func.function.captured_inputs,
|
||||
key_func=self._key_func.function,
|
||||
init_func=self._init_func.function,
|
||||
reduce_func=self._reduce_func.function,
|
||||
finalize_func=self._finalize_func.function,
|
||||
**dataset_ops.flat_structure(self))
|
||||
super(_GroupByReducerDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _make_key_func(self, key_func, input_dataset):
|
||||
"""Make wrapping defun for key_func."""
|
||||
@ -347,19 +356,6 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
|
||||
self._key_func, self._init_func, self._reduce_func, self._finalize_func
|
||||
]
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return ged_ops.experimental_group_by_reducer_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
self._key_func.function.captured_inputs,
|
||||
self._init_func.function.captured_inputs,
|
||||
self._reduce_func.function.captured_inputs,
|
||||
self._finalize_func.function.captured_inputs,
|
||||
key_func=self._key_func.function,
|
||||
init_func=self._init_func.function,
|
||||
reduce_func=self._reduce_func.function,
|
||||
finalize_func=self._finalize_func.function,
|
||||
**dataset_ops.flat_structure(self))
|
||||
|
||||
def _transformation_name(self):
|
||||
return "tf.data.experimental.group_by_reducer()"
|
||||
|
||||
@ -369,13 +365,20 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
|
||||
|
||||
def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
|
||||
"""See `group_by_window()` for details."""
|
||||
super(_GroupByWindowDataset, self).__init__(input_dataset)
|
||||
|
||||
self._input_dataset = input_dataset
|
||||
|
||||
self._make_key_func(key_func, input_dataset)
|
||||
self._make_reduce_func(reduce_func, input_dataset)
|
||||
self._make_window_size_func(window_size_func)
|
||||
variant_tensor = ged_ops.experimental_group_by_window_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._key_func.function.captured_inputs,
|
||||
self._reduce_func.function.captured_inputs,
|
||||
self._window_size_func.function.captured_inputs,
|
||||
key_func=self._key_func.function,
|
||||
reduce_func=self._reduce_func.function,
|
||||
window_size_func=self._window_size_func.function,
|
||||
**dataset_ops.flat_structure(self))
|
||||
super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _make_window_size_func(self, window_size_func):
|
||||
"""Make wrapping defun for window_size_func."""
|
||||
@ -426,17 +429,6 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
|
||||
def _functions(self):
|
||||
return [self._key_func, self._reduce_func, self._window_size_func]
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return ged_ops.experimental_group_by_window_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
self._key_func.function.captured_inputs,
|
||||
self._reduce_func.function.captured_inputs,
|
||||
self._window_size_func.function.captured_inputs,
|
||||
key_func=self._key_func.function,
|
||||
reduce_func=self._reduce_func.function,
|
||||
window_size_func=self._window_size_func.function,
|
||||
**dataset_ops.flat_structure(self))
|
||||
|
||||
def _transformation_name(self):
|
||||
return "tf.data.experimental.group_by_window()"
|
||||
|
||||
|
@ -113,15 +113,15 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
|
||||
self._structure = structure.convert_legacy_structure(
|
||||
data_inputs[0].output_types, output_shapes,
|
||||
data_inputs[0].output_classes)
|
||||
super(_DirectedInterleaveDataset, self).__init__()
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# pylint: disable=protected-access
|
||||
return (
|
||||
gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
|
||||
self._selector_input._as_variant_tensor(), [
|
||||
data_input._as_variant_tensor()
|
||||
for data_input in self._data_inputs
|
||||
], **dataset_ops.flat_structure(self)))
|
||||
self._selector_input._variant_tensor,
|
||||
[data_input._variant_tensor for data_input in self._data_inputs],
|
||||
**dataset_ops.flat_structure(self)))
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def _inputs(self):
|
||||
|
@ -29,12 +29,10 @@ class MatchingFilesDataset(dataset_ops.DatasetSource):
|
||||
"""A `Dataset` that list the files according to the input patterns."""
|
||||
|
||||
def __init__(self, patterns):
|
||||
super(MatchingFilesDataset, self).__init__()
|
||||
self._patterns = ops.convert_to_tensor(
|
||||
patterns, dtype=dtypes.string, name="patterns")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return ged_ops.experimental_matching_files_dataset(self._patterns)
|
||||
variant_tensor = ged_ops.experimental_matching_files_dataset(self._patterns)
|
||||
super(MatchingFilesDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
|
@ -105,18 +105,17 @@ class _AssertNextDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
|
||||
def __init__(self, input_dataset, transformations):
|
||||
"""See `assert_next()` for details."""
|
||||
super(_AssertNextDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
if transformations is None:
|
||||
raise ValueError("At least one transformation should be specified")
|
||||
self._transformations = ops.convert_to_tensor(
|
||||
transformations, dtype=dtypes.string, name="transformations")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_experimental_dataset_ops.experimental_assert_next_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
self._transformations,
|
||||
**dataset_ops.flat_structure(self))
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.experimental_assert_next_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._transformations,
|
||||
**dataset_ops.flat_structure(self)))
|
||||
super(_AssertNextDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
@ -124,10 +123,9 @@ class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
|
||||
def __init__(self, input_dataset):
|
||||
"""See `non_serializable()` for details."""
|
||||
super(_NonSerializableDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_experimental_dataset_ops.experimental_non_serializable_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
**dataset_ops.flat_structure(self))
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.experimental_non_serializable_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
**dataset_ops.flat_structure(self)))
|
||||
super(_NonSerializableDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
@ -31,7 +31,6 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
|
||||
"""A `Dataset` that parses `example` dataset into a `dict` dataset."""
|
||||
|
||||
def __init__(self, input_dataset, features, num_parallel_calls):
|
||||
super(_ParseExampleDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
if not input_dataset._element_structure.is_compatible_with( # pylint: disable=protected-access
|
||||
structure.TensorStructure(dtypes.string, [None])):
|
||||
@ -81,16 +80,17 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
|
||||
self._structure = structure.convert_legacy_structure(
|
||||
output_types, output_shapes, output_classes)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_experimental_dataset_ops.experimental_parse_example_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
self._num_parallel_calls,
|
||||
self._dense_defaults,
|
||||
self._sparse_keys,
|
||||
self._dense_keys,
|
||||
self._sparse_types,
|
||||
self._dense_shapes,
|
||||
**dataset_ops.flat_structure(self))
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.experimental_parse_example_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._num_parallel_calls,
|
||||
self._dense_defaults,
|
||||
self._sparse_keys,
|
||||
self._dense_keys,
|
||||
self._sparse_types,
|
||||
self._dense_shapes,
|
||||
**dataset_ops.flat_structure(self)))
|
||||
super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
|
@ -93,7 +93,6 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
target_device: The name of the device to which elements would be copied.
|
||||
source_device: Device where input_dataset would be placed.
|
||||
"""
|
||||
super(_CopyToDeviceDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._target_device = target_device
|
||||
spec = framework_device.DeviceSpec().from_string(self._target_device)
|
||||
@ -101,6 +100,9 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
self._source_device_string = source_device
|
||||
self._source_device = ops.convert_to_tensor(source_device)
|
||||
|
||||
wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
|
||||
self._input_dataset._variant_tensor) # pylint: disable=protected-access
|
||||
|
||||
@function.defun()
|
||||
def _init_func():
|
||||
"""Creates an iterator for the input dataset.
|
||||
@ -108,8 +110,7 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
Returns:
|
||||
A `string` tensor that encapsulates the iterator created.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
ds_variant = self._input_dataset._as_variant_tensor()
|
||||
ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
|
||||
resource = gen_dataset_ops.anonymous_iterator(
|
||||
**dataset_ops.flat_structure(self._input_dataset))
|
||||
with ops.control_dependencies(
|
||||
@ -195,6 +196,17 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
self._finalize_func.add_to_graph(g)
|
||||
# pylint: enable=protected-scope
|
||||
|
||||
with ops.device(self._target_device):
|
||||
variant_tensor = gen_dataset_ops.generator_dataset(
|
||||
self._init_captured_args,
|
||||
self._next_captured_args,
|
||||
self._finalize_captured_args,
|
||||
init_func=self._init_func,
|
||||
next_func=self._next_func,
|
||||
finalize_func=self._finalize_func,
|
||||
**dataset_ops.flat_structure(self._input_dataset))
|
||||
super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
# The one_shot_iterator implementation needs a 0 arg _make_dataset function
|
||||
# that thereby captures all the inputs required to create the dataset. Since
|
||||
# there are strings that are inputs to the GeneratorDataset which can't be
|
||||
@ -208,24 +220,12 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
else:
|
||||
return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
with ops.device(self._target_device):
|
||||
return gen_dataset_ops.generator_dataset(
|
||||
self._init_captured_args,
|
||||
self._next_captured_args,
|
||||
self._finalize_captured_args,
|
||||
init_func=self._init_func,
|
||||
next_func=self._next_func,
|
||||
finalize_func=self._finalize_func,
|
||||
**dataset_ops.flat_structure(self._input_dataset))
|
||||
|
||||
|
||||
class _MapOnGpuDataset(dataset_ops.UnaryDataset):
|
||||
"""A `Dataset` that maps a function over elements in its using a GPU."""
|
||||
|
||||
def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
|
||||
"""See `Dataset.map()` for details."""
|
||||
super(_MapOnGpuDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._use_inter_op_parallelism = use_inter_op_parallelism
|
||||
|
||||
@ -234,18 +234,16 @@ class _MapOnGpuDataset(dataset_ops.UnaryDataset):
|
||||
self._transformation_name(),
|
||||
dataset=input_dataset,
|
||||
defun_kwargs={"experimental_ints_on_device": True})
|
||||
|
||||
def _functions(self):
|
||||
return [self._map_func]
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
|
||||
return ged_ops.experimental_map_dataset(
|
||||
input_t,
|
||||
variant_tensor = ged_ops.experimental_map_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
f=self._map_func.function,
|
||||
use_inter_op_parallelism=self._use_inter_op_parallelism,
|
||||
**dataset_ops.flat_structure(self))
|
||||
super(_MapOnGpuDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
|
@ -33,14 +33,10 @@ class RandomDatasetV2(dataset_ops.DatasetSource):
|
||||
|
||||
def __init__(self, seed=None):
|
||||
"""A `Dataset` of pseudorandom values."""
|
||||
super(RandomDatasetV2, self).__init__()
|
||||
self._seed, self._seed2 = random_seed.get_seed(seed)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_experimental_dataset_ops.experimental_random_dataset(
|
||||
seed=self._seed,
|
||||
seed2=self._seed2,
|
||||
**dataset_ops.flat_structure(self))
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_random_dataset(
|
||||
seed=self._seed, seed2=self._seed2, **dataset_ops.flat_structure(self))
|
||||
super(RandomDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
|
@ -622,7 +622,6 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
|
||||
the input data. If specified, only this subset of columns will be
|
||||
parsed. Defaults to parsing all columns.
|
||||
"""
|
||||
super(CsvDatasetV2, self).__init__()
|
||||
self._filenames = ops.convert_to_tensor(
|
||||
filenames, dtype=dtypes.string, name="filenames")
|
||||
self._compression_type = convert.optional_param_to_tensor(
|
||||
@ -655,10 +654,7 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
|
||||
self._structure = structure.NestedStructure(
|
||||
tuple(structure.TensorStructure(d.dtype, [])
|
||||
for d in self._record_defaults))
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# Constructs graph node for the dataset op.
|
||||
return gen_experimental_dataset_ops.experimental_csv_dataset(
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
|
||||
filenames=self._filenames,
|
||||
record_defaults=self._record_defaults,
|
||||
buffer_size=self._buffer_size,
|
||||
@ -668,8 +664,8 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
|
||||
use_quote_delim=self._use_quote_delim,
|
||||
na_value=self._na_value,
|
||||
select_cols=self._select_cols,
|
||||
compression_type=self._compression_type,
|
||||
)
|
||||
compression_type=self._compression_type)
|
||||
super(CsvDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
@ -944,7 +940,6 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
|
||||
output_types: A tuple of `tf.DType` objects representing the types of the
|
||||
columns returned by `query`.
|
||||
"""
|
||||
super(SqlDatasetV2, self).__init__()
|
||||
self._driver_name = ops.convert_to_tensor(
|
||||
driver_name, dtype=dtypes.string, name="driver_name")
|
||||
self._data_source_name = ops.convert_to_tensor(
|
||||
@ -954,11 +949,10 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
|
||||
self._structure = structure.NestedStructure(
|
||||
nest.map_structure(
|
||||
lambda dtype: structure.TensorStructure(dtype, []), output_types))
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_experimental_dataset_ops.experimental_sql_dataset(
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
|
||||
self._driver_name, self._data_source_name, self._query,
|
||||
nest.flatten(self.output_types), nest.flatten(self.output_shapes))
|
||||
super(SqlDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
|
@ -33,7 +33,6 @@ class _ScanDataset(dataset_ops.UnaryDataset):
|
||||
|
||||
def __init__(self, input_dataset, initial_state, scan_func):
|
||||
"""See `scan()` for details."""
|
||||
super(_ScanDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
|
||||
with ops.name_scope("initial_state"):
|
||||
@ -126,20 +125,18 @@ class _ScanDataset(dataset_ops.UnaryDataset):
|
||||
|
||||
self._scan_func = wrapped_func
|
||||
self._scan_func.function.add_to_graph(ops.get_default_graph())
|
||||
|
||||
def _functions(self):
|
||||
return [self._scan_func]
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# pylint: disable=protected-access
|
||||
input_t = self._input_dataset._as_variant_tensor()
|
||||
return gen_experimental_dataset_ops.experimental_scan_dataset(
|
||||
input_t,
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset(
|
||||
self._input_dataset._variant_tensor,
|
||||
self._state_structure._to_tensor_list(self._initial_state),
|
||||
self._scan_func.function.captured_inputs,
|
||||
f=self._scan_func.function,
|
||||
preserve_cardinality=True,
|
||||
**dataset_ops.flat_structure(self))
|
||||
super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
return [self._scan_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
|
@ -30,7 +30,6 @@ class _ShuffleAndRepeatDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
"""A `Dataset` that fuses `shuffle` and `repeat`."""
|
||||
|
||||
def __init__(self, input_dataset, buffer_size, count=None, seed=None):
|
||||
super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._buffer_size = ops.convert_to_tensor(
|
||||
buffer_size, dtype=dtypes.int64, name="buffer_size")
|
||||
@ -40,18 +39,15 @@ class _ShuffleAndRepeatDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
self._count = ops.convert_to_tensor(
|
||||
count, dtype=dtypes.int64, name="count")
|
||||
self._seed, self._seed2 = random_seed.get_seed(seed)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# pylint: disable=protected-access
|
||||
input_resource = self._input_dataset._as_variant_tensor()
|
||||
return gen_dataset_ops.shuffle_and_repeat_dataset(
|
||||
input_resource,
|
||||
variant_tensor = gen_dataset_ops.shuffle_and_repeat_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
buffer_size=self._buffer_size,
|
||||
count=self._count,
|
||||
seed=self._seed,
|
||||
seed2=self._seed2,
|
||||
**dataset_ops.flat_structure(self))
|
||||
# pylint: enable=protected-access
|
||||
super(_ShuffleAndRepeatDataset, self).__init__(input_dataset,
|
||||
variant_tensor)
|
||||
|
||||
|
||||
@tf_export("data.experimental.shuffle_and_repeat")
|
||||
|
@ -25,15 +25,13 @@ class _SleepDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
"""A `Dataset` that sleeps before producing each upstream element."""
|
||||
|
||||
def __init__(self, input_dataset, sleep_microseconds):
|
||||
super(_SleepDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._sleep_microseconds = sleep_microseconds
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_experimental_dataset_ops.experimental_sleep_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_sleep_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._sleep_microseconds,
|
||||
**dataset_ops.flat_structure(self))
|
||||
super(_SleepDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
def sleep(sleep_microseconds):
|
||||
|
@ -102,13 +102,11 @@ class _StatsDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
"""A `Dataset` that acts as an identity, and also records statistics."""
|
||||
|
||||
def __init__(self, input_dataset, op_function, tag):
|
||||
super(_StatsDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._op_function = op_function
|
||||
self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return self._op_function(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = self._op_function(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._tag,
|
||||
**dataset_ops.flat_structure(self))
|
||||
super(_StatsDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
@ -64,15 +64,13 @@ class _ThreadPoolDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
"""A `Dataset` that acts as an identity, and sets a custom threadpool."""
|
||||
|
||||
def __init__(self, input_dataset, thread_pool):
|
||||
super(_ThreadPoolDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._thread_pool = thread_pool
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return ged_ops.experimental_thread_pool_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = ged_ops.experimental_thread_pool_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._thread_pool._resource, # pylint: disable=protected-access
|
||||
**dataset_ops.flat_structure(self))
|
||||
super(_ThreadPoolDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
# TODO(b/73383364): Properly export in the `tf.data.experimental` API when
|
||||
|
@ -53,15 +53,13 @@ class _UniqueDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
|
||||
def __init__(self, input_dataset):
|
||||
"""See `unique()` for details."""
|
||||
super(_UniqueDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
|
||||
dtypes.string):
|
||||
raise TypeError(
|
||||
"`tf.data.experimental.unique()` only supports inputs with a single "
|
||||
"`tf.int32`, `tf.int64`, or `tf.string` component.")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_experimental_dataset_ops.experimental_unique_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = gen_experimental_dataset_ops.experimental_unique_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
**dataset_ops.flat_structure(self))
|
||||
super(_UniqueDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
@ -57,4 +57,4 @@ class TFRecordWriter(object):
|
||||
"produces shape {0} and types {1}".format(dataset.output_shapes,
|
||||
dataset.output_types))
|
||||
return gen_experimental_dataset_ops.experimental_dataset_to_tf_record(
|
||||
dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access
|
||||
dataset._variant_tensor, self._filename, self._compression_type) # pylint: disable=protected-access
|
||||
|
@ -91,9 +91,9 @@ class BatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
result = self.evaluate(get_next())
|
||||
|
||||
def testBatchDatasetInvalidBatchSize(self):
|
||||
dataset = (dataset_ops.Dataset.range(10).batch(0))
|
||||
self.assertDatasetProduces(
|
||||
dataset, expected_error=(errors.InvalidArgumentError, ''))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
dataset = (dataset_ops.Dataset.range(10).batch(0))
|
||||
self.evaluate(dataset._variant_tensor)
|
||||
|
||||
def testBatchSparse(self):
|
||||
|
||||
|
@ -139,8 +139,8 @@ class FileCacheTest(test_base.DatasetTestBase):
|
||||
self.evaluate(get_next1())
|
||||
|
||||
# Re-initialize
|
||||
get_next1 = self.getNext(cache_dataset1)
|
||||
get_next2 = self.getNext(cache_dataset2)
|
||||
get_next1 = self.getNext(cache_dataset1, requires_initialization=True)
|
||||
get_next2 = self.getNext(cache_dataset2, requires_initialization=True)
|
||||
|
||||
# Reading concurrently should succeed.
|
||||
elements_itr1 = []
|
||||
|
@ -272,12 +272,8 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def testSkipEagerSameGraphErrorOneShot(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
with ops.Graph().as_default():
|
||||
dataset = dataset.batch(2)
|
||||
with test.mock.patch.object(logging, "warning") as mock_log:
|
||||
_ = dataset.make_one_shot_iterator()
|
||||
self.assertRegexpMatches(
|
||||
str(mock_log.call_args), "Please ensure that all datasets in the "
|
||||
"pipeline are created in the same graph as the iterator.")
|
||||
with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
|
||||
dataset = dataset.batch(2)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSkipEagerSameGraphErrorOneShotSimple(self):
|
||||
@ -293,9 +289,8 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def testSkipEagerSameGraphErrorInitializable(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
with ops.Graph().as_default():
|
||||
dataset = dataset.batch(2)
|
||||
with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
|
||||
_ = dataset.make_initializable_iterator()
|
||||
dataset = dataset.batch(2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -36,9 +36,10 @@ class PrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters((-2), (-42))
|
||||
def testInvalidBufferSize(self, buffer_size):
|
||||
dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size)
|
||||
self.assertDatasetProduces(
|
||||
dataset, expected_error=(errors.InvalidArgumentError, "buffer_size"))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size)
|
||||
self.evaluate(dataset._variant_tensor)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -43,9 +43,9 @@ class RangeTest(test_base.DatasetTestBase):
|
||||
|
||||
def testZeroStep(self):
|
||||
start, stop, step = 2, 10, 0
|
||||
dataset = dataset_ops.Dataset.range(start, stop, step)
|
||||
self.assertDatasetProduces(
|
||||
dataset, expected_error=(errors.InvalidArgumentError, ""))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
dataset = dataset_ops.Dataset.range(start, stop, step)
|
||||
self.evaluate(dataset._variant_tensor)
|
||||
|
||||
def testNegativeStep(self):
|
||||
start, stop, step = 2, 10, -1
|
||||
|
@ -116,12 +116,11 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
("3", 14, 3, 3, 0),
|
||||
)
|
||||
def testWindowDatasetInvalid(self, count, size, shift, stride):
|
||||
dataset = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(
|
||||
count).window(
|
||||
size=size, shift=shift,
|
||||
stride=stride).flat_map(lambda x: x.batch(batch_size=size))
|
||||
self.assertDatasetProduces(
|
||||
dataset, expected_error=(errors.InvalidArgumentError, ""))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
ds = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count).window(
|
||||
size=size, shift=shift,
|
||||
stride=stride).flat_map(lambda x: x.batch(batch_size=size))
|
||||
self.evaluate(ds._variant_tensor)
|
||||
|
||||
def testWindowSparse(self):
|
||||
|
||||
|
@ -35,6 +35,7 @@ py_library(
|
||||
"//tensorflow/python/data/util:random_seed",
|
||||
"//tensorflow/python/data/util:sparse",
|
||||
"//tensorflow/python/data/util:structure",
|
||||
"//tensorflow/python/data/util:traverse",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
@ -38,6 +38,7 @@ from tensorflow.python.data.util import options as options_lib
|
||||
from tensorflow.python.data.util import random_seed
|
||||
from tensorflow.python.data.util import sparse
|
||||
from tensorflow.python.data.util import structure as structure_lib
|
||||
from tensorflow.python.data.util import traverse
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -75,9 +76,27 @@ class DatasetV2(object):
|
||||
plan" of transformations that act on those elements.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, variant_tensor):
|
||||
"""Creates a DatasetV2 object.
|
||||
|
||||
This is a difference between DatasetV1 and DatasetV2. DatasetV1 does not
|
||||
take anything in its constructor whereas in the DatasetV2, we expect
|
||||
subclasses to create a variant_tensor and pass it in to the super() call.
|
||||
|
||||
Args:
|
||||
variant_tensor: A DT_VARIANT tensor that represents the dataset.
|
||||
"""
|
||||
self._dataset_variant_tensor = variant_tensor
|
||||
self._graph_attr = ops.get_default_graph()
|
||||
|
||||
@property
|
||||
def _variant_tensor(self):
|
||||
return self._dataset_variant_tensor
|
||||
|
||||
@_variant_tensor.setter
|
||||
def _variant_tensor(self, _):
|
||||
raise ValueError("The _variant_tensor property is read-only")
|
||||
|
||||
def _as_serialized_graph(self):
|
||||
"""Produces serialized graph representation of the dataset.
|
||||
|
||||
@ -85,16 +104,7 @@ class DatasetV2(object):
|
||||
A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
|
||||
serialized graph.
|
||||
"""
|
||||
return gen_dataset_ops.dataset_to_graph(self._as_variant_tensor())
|
||||
|
||||
@abc.abstractmethod
|
||||
def _as_variant_tensor(self):
|
||||
"""Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.
|
||||
|
||||
Returns:
|
||||
A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset.
|
||||
"""
|
||||
raise NotImplementedError("Dataset._as_variant_tensor")
|
||||
return gen_dataset_ops.dataset_to_graph(self._variant_tensor)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _inputs(self):
|
||||
@ -1279,7 +1289,7 @@ class DatasetV2(object):
|
||||
# pylint: disable=protected-access
|
||||
return state_structure._from_compatible_tensor_list(
|
||||
gen_dataset_ops.reduce_dataset(
|
||||
self._as_variant_tensor(),
|
||||
self._variant_tensor,
|
||||
state_structure._to_tensor_list(initial_state),
|
||||
reduce_func.captured_inputs,
|
||||
f=reduce_func,
|
||||
@ -1314,8 +1324,31 @@ class DatasetV1(DatasetV2):
|
||||
plan" of transformations that act on those elements.
|
||||
"""
|
||||
|
||||
def __init__(self): # pylint: disable=useless-super-delegation
|
||||
super(DatasetV1, self).__init__()
|
||||
def __init__(self):
|
||||
try:
|
||||
variant_tensor = self._as_variant_tensor()
|
||||
except AttributeError as e:
|
||||
if "_as_variant_tensor" in str(e):
|
||||
raise AttributeError("Please use _variant_tensor instead of "
|
||||
"_as_variant_tensor() to obtain the variant "
|
||||
"associated with a dataset")
|
||||
raise AttributeError("A likely cause of this error is that the super "
|
||||
"call for this dataset is not the last line of the "
|
||||
"__init__ method. The base class causes the "
|
||||
"_as_variant_tensor call in its constructor and "
|
||||
"if that uses attributes defined in the __init__ "
|
||||
"method, those attrs need to be defined before the "
|
||||
"super call.")
|
||||
super(DatasetV1, self).__init__(variant_tensor)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _as_variant_tensor(self):
|
||||
"""Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.
|
||||
|
||||
Returns:
|
||||
A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset.
|
||||
"""
|
||||
raise NotImplementedError("Dataset._as_variant_tensor")
|
||||
|
||||
@deprecation.deprecated(
|
||||
None, "Use `for ... in dataset:` to iterate over a dataset. If using "
|
||||
@ -1335,11 +1368,19 @@ class DatasetV1(DatasetV2):
|
||||
return iterator_ops.EagerIterator(self)
|
||||
|
||||
_ensure_same_dataset_graph(self)
|
||||
# Now that we create datasets at python object creation time, the capture
|
||||
# by value _make_dataset() function would try to capture these variant
|
||||
# tensor dataset inputs, which are marked as stateful ops and would throw
|
||||
# an error if we try and capture them. We therefore traverse the graph
|
||||
# to find all these ops and whitelist them so that the capturing
|
||||
# logic instead of throwing an error recreates these ops which is what was
|
||||
# happening before.
|
||||
all_ds_ops = traverse.obtain_all_variant_tensor_ops(self)
|
||||
graph_level_seed, op_level_seed = core_random_seed.get_seed(None)
|
||||
|
||||
# NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
|
||||
# a 0-argument function.
|
||||
@function.Defun(capture_by_value=True)
|
||||
@function.Defun(capture_by_value=True, whitelisted_stateful_ops=all_ds_ops)
|
||||
def _make_dataset():
|
||||
"""Factory function for a dataset."""
|
||||
# NOTE(mrry): `Defun` does not capture the graph-level seed from the
|
||||
@ -1351,7 +1392,7 @@ class DatasetV1(DatasetV2):
|
||||
(graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1))
|
||||
|
||||
dataset = self._apply_options()
|
||||
return dataset._as_variant_tensor() # pylint: disable=protected-access
|
||||
return dataset._variant_tensor # pylint: disable=protected-access
|
||||
|
||||
try:
|
||||
_make_dataset.add_to_graph(ops.get_default_graph())
|
||||
@ -1416,7 +1457,7 @@ class DatasetV1(DatasetV2):
|
||||
container="", shared_name=shared_name, **flat_structure(self))
|
||||
with ops.colocate_with(iterator_resource):
|
||||
initializer = gen_dataset_ops.make_iterator(
|
||||
dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
dataset._variant_tensor, # pylint: disable=protected-access
|
||||
iterator_resource)
|
||||
return iterator_ops.Iterator(iterator_resource, initializer,
|
||||
dataset.output_types, dataset.output_shapes,
|
||||
@ -1621,11 +1662,11 @@ class DatasetV1Adapter(DatasetV1):
|
||||
"""Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API."""
|
||||
|
||||
def __init__(self, dataset):
|
||||
super(DatasetV1Adapter, self).__init__()
|
||||
self._dataset = dataset
|
||||
super(DatasetV1Adapter, self).__init__()
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return self._dataset._as_variant_tensor() # pylint: disable=protected-access
|
||||
return self._dataset._variant_tensor # pylint: disable=protected-access
|
||||
|
||||
def _has_captured_ref(self):
|
||||
return self._dataset._has_captured_ref() # pylint: disable=protected-access
|
||||
@ -1657,14 +1698,14 @@ def _ensure_same_dataset_graph(dataset):
|
||||
if current_graph != ds_graph:
|
||||
logging.warning("The graph (" + str(current_graph) + ") of the iterator "
|
||||
"is different from the graph (" + str(ds_graph) + ") "
|
||||
"the dataset: " + str(ds) + " was created in. "
|
||||
"If you are using the Estimator API, make sure that no "
|
||||
"part of the dataset returned by the `input_fn` function "
|
||||
"is defined outside the `input_fn` function."
|
||||
"Please ensure that all datasets in the pipeline are "
|
||||
"created in the same graph as the iterator. NOTE: This "
|
||||
"warning will become an error in future versions of "
|
||||
"TensorFlow.")
|
||||
"the dataset: " + str(ds._variant_tensor) + " was " # pylint: disable=protected-access
|
||||
"created in. If you are using the Estimator API, "
|
||||
"make sure that no part of the dataset returned by the "
|
||||
"`input_fn` function is defined outside the `input_fn` "
|
||||
"function. Please ensure that all datasets in the "
|
||||
"pipeline are created in the same graph as the iterator. "
|
||||
"NOTE: This warning will become an error in future "
|
||||
"versions of TensorFlow.")
|
||||
for input_ds in ds._inputs(): # pylint: disable=protected-access
|
||||
if input_ds not in visited:
|
||||
bfs_q.put(input_ds)
|
||||
@ -1820,9 +1861,9 @@ class DatasetSource(DatasetV2):
|
||||
class UnaryDataset(DatasetV2):
|
||||
"""Abstract class representing a dataset with one input."""
|
||||
|
||||
def __init__(self, input_dataset):
|
||||
super(UnaryDataset, self).__init__()
|
||||
def __init__(self, input_dataset, variant_tensor):
|
||||
self._input_dataset = input_dataset
|
||||
super(UnaryDataset, self).__init__(variant_tensor)
|
||||
|
||||
def _inputs(self):
|
||||
return [self._input_dataset]
|
||||
@ -1831,6 +1872,11 @@ class UnaryDataset(DatasetV2):
|
||||
class UnaryUnchangedStructureDataset(UnaryDataset):
|
||||
"""Represents a unary dataset with the same input and output structure."""
|
||||
|
||||
def __init__(self, input_dataset, variant_tensor):
|
||||
self._input_dataset = input_dataset
|
||||
super(UnaryUnchangedStructureDataset, self).__init__(
|
||||
input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._input_dataset._element_structure # pylint: disable=protected-access
|
||||
@ -1841,7 +1887,6 @@ class TensorDataset(DatasetSource):
|
||||
|
||||
def __init__(self, tensors):
|
||||
"""See `Dataset.from_tensors()` for details."""
|
||||
super(TensorDataset, self).__init__()
|
||||
with ops.name_scope("tensors"):
|
||||
tensors = nest.pack_sequence_as(tensors, [
|
||||
sparse_tensor_lib.SparseTensor.from_value(t)
|
||||
@ -1852,9 +1897,9 @@ class TensorDataset(DatasetSource):
|
||||
self._structure = structure_lib.Structure.from_value(tensors)
|
||||
self._tensors = self._structure._to_tensor_list(tensors) # pylint: disable=protected-access
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.tensor_dataset(
|
||||
variant_tensor = gen_dataset_ops.tensor_dataset(
|
||||
self._tensors, output_shapes=self._structure._flat_shapes) # pylint: disable=protected-access
|
||||
super(TensorDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
@ -1866,7 +1911,6 @@ class TensorSliceDataset(DatasetSource):
|
||||
|
||||
def __init__(self, tensors):
|
||||
"""See `Dataset.from_tensor_slices()` for details."""
|
||||
super(TensorSliceDataset, self).__init__()
|
||||
with ops.name_scope("tensors"):
|
||||
tensors = nest.pack_sequence_as(tensors, [
|
||||
sparse_tensor_lib.SparseTensor.from_value(t)
|
||||
@ -1887,9 +1931,9 @@ class TensorSliceDataset(DatasetSource):
|
||||
batch_dim.assert_is_compatible_with(tensor_shape.Dimension(
|
||||
tensor_shape.dimension_value(t.get_shape()[0])))
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.tensor_slice_dataset(
|
||||
variant_tensor = gen_dataset_ops.tensor_slice_dataset(
|
||||
self._tensors, output_shapes=self._structure._flat_shapes) # pylint: disable=protected-access
|
||||
super(TensorSliceDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
@ -1901,7 +1945,6 @@ class SparseTensorSliceDataset(DatasetSource):
|
||||
|
||||
def __init__(self, sparse_tensor):
|
||||
"""See `Dataset.from_sparse_tensor_slices()` for details."""
|
||||
super(SparseTensorSliceDataset, self).__init__()
|
||||
if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor):
|
||||
raise TypeError("`sparse_tensor` must be a `tf.SparseTensor` object.")
|
||||
self._sparse_tensor = sparse_tensor
|
||||
@ -1914,10 +1957,10 @@ class SparseTensorSliceDataset(DatasetSource):
|
||||
structure_lib.TensorStructure(self._sparse_tensor.dtype, [None]),
|
||||
structure_lib.TensorStructure(dtypes.int64, [rank])))
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.sparse_tensor_slice_dataset(
|
||||
variant_tensor = gen_dataset_ops.sparse_tensor_slice_dataset(
|
||||
self._sparse_tensor.indices, self._sparse_tensor.values,
|
||||
self._sparse_tensor.dense_shape)
|
||||
super(SparseTensorSliceDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
@ -1928,12 +1971,8 @@ class _VariantDataset(DatasetV2):
|
||||
"""A Dataset wrapper around a `tf.variant`-typed function argument."""
|
||||
|
||||
def __init__(self, dataset_variant, structure):
|
||||
super(_VariantDataset, self).__init__()
|
||||
self._dataset_variant = dataset_variant
|
||||
self._structure = structure
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return self._dataset_variant
|
||||
super(_VariantDataset, self).__init__(dataset_variant)
|
||||
|
||||
def _inputs(self):
|
||||
return []
|
||||
@ -1965,7 +2004,7 @@ class DatasetStructure(structure_lib.Structure):
|
||||
other._element_structure))
|
||||
|
||||
def _to_tensor_list(self, value):
|
||||
return [value._as_variant_tensor()] # pylint: disable=protected-access
|
||||
return [value._variant_tensor] # pylint: disable=protected-access
|
||||
|
||||
def _to_batched_tensor_list(self, value):
|
||||
raise NotImplementedError("Unbatching for `tf.data.Dataset` objects.")
|
||||
@ -2153,7 +2192,7 @@ def flat_structure(dataset):
|
||||
Most Dataset op constructors expect `output_shapes` and `output_types`
|
||||
arguments that represent the flattened structure of an element. This helper
|
||||
function generates these attrs as a keyword argument dictionary, allowing
|
||||
`Dataset._as_variant_tensor()` implementations to pass
|
||||
`Dataset._variant_tensor` implementations to pass
|
||||
`**flat_structure(self)` to the op constructor.
|
||||
|
||||
Args:
|
||||
@ -2189,7 +2228,6 @@ class _GeneratorDataset(DatasetSource):
|
||||
`init_func` immediately before a C++ iterator over this dataset is
|
||||
destroyed. The return value is ignored.
|
||||
"""
|
||||
super(_GeneratorDataset, self).__init__()
|
||||
self._init_args = init_args
|
||||
|
||||
self._init_structure = structure_lib.Structure.from_value(init_args)
|
||||
@ -2208,9 +2246,7 @@ class _GeneratorDataset(DatasetSource):
|
||||
finalize_func,
|
||||
self._transformation_name(),
|
||||
input_structure=self._init_func.output_structure)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.generator_dataset(
|
||||
variant_tensor = gen_dataset_ops.generator_dataset(
|
||||
self._init_structure._to_tensor_list(self._init_args) # pylint: disable=protected-access
|
||||
+ self._init_func.function.captured_inputs,
|
||||
self._next_func.function.captured_inputs,
|
||||
@ -2219,6 +2255,7 @@ class _GeneratorDataset(DatasetSource):
|
||||
next_func=self._next_func.function,
|
||||
finalize_func=self._finalize_func.function,
|
||||
**flat_structure(self))
|
||||
super(_GeneratorDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
@ -2233,7 +2270,6 @@ class ZipDataset(DatasetV2):
|
||||
|
||||
def __init__(self, datasets):
|
||||
"""See `Dataset.zip()` for details."""
|
||||
super(ZipDataset, self).__init__()
|
||||
for ds in nest.flatten(datasets):
|
||||
if not isinstance(ds, DatasetV2):
|
||||
if isinstance(ds, list):
|
||||
@ -2250,12 +2286,12 @@ class ZipDataset(DatasetV2):
|
||||
self._datasets,
|
||||
[ds._element_structure for ds in nest.flatten(self._datasets)])) # pylint: disable=protected-access
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# pylint: disable=protected-access
|
||||
return gen_dataset_ops.zip_dataset(
|
||||
[ds._as_variant_tensor() for ds in nest.flatten(self._datasets)],
|
||||
variant_tensor = gen_dataset_ops.zip_dataset(
|
||||
[ds._variant_tensor for ds in nest.flatten(self._datasets)],
|
||||
**flat_structure(self))
|
||||
# pylint: enable=protected-access
|
||||
super(ZipDataset, self).__init__(variant_tensor)
|
||||
|
||||
def _inputs(self):
|
||||
return nest.flatten(self._datasets)
|
||||
@ -2270,7 +2306,6 @@ class ConcatenateDataset(DatasetV2):
|
||||
|
||||
def __init__(self, input_dataset, dataset_to_concatenate):
|
||||
"""See `Dataset.concatenate()` for details."""
|
||||
super(ConcatenateDataset, self).__init__()
|
||||
self._input_dataset = input_dataset
|
||||
self._dataset_to_concatenate = dataset_to_concatenate
|
||||
|
||||
@ -2298,17 +2333,15 @@ class ConcatenateDataset(DatasetV2):
|
||||
output_types, output_shapes, output_classes)
|
||||
|
||||
self._input_datasets = [input_dataset, dataset_to_concatenate]
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# pylint: disable=protected-access
|
||||
return gen_dataset_ops.concatenate_dataset(
|
||||
self._input_dataset._as_variant_tensor(),
|
||||
self._dataset_to_concatenate._as_variant_tensor(),
|
||||
variant_tensor = gen_dataset_ops.concatenate_dataset(
|
||||
input_dataset._variant_tensor, dataset_to_concatenate._variant_tensor,
|
||||
**flat_structure(self))
|
||||
# pylint: enable=protected-access
|
||||
super(ConcatenateDataset, self).__init__(variant_tensor)
|
||||
|
||||
def _inputs(self):
|
||||
return [self._input_dataset, self._dataset_to_concatenate]
|
||||
return self._input_datasets
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
@ -2320,19 +2353,17 @@ class RepeatDataset(UnaryUnchangedStructureDataset):
|
||||
|
||||
def __init__(self, input_dataset, count):
|
||||
"""See `Dataset.repeat()` for details."""
|
||||
super(RepeatDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
if count is None:
|
||||
self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
|
||||
else:
|
||||
self._count = ops.convert_to_tensor(
|
||||
count, dtype=dtypes.int64, name="count")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.repeat_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = gen_dataset_ops.repeat_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
count=self._count,
|
||||
**flat_structure(self))
|
||||
super(RepeatDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
class RangeDataset(DatasetSource):
|
||||
@ -2340,8 +2371,13 @@ class RangeDataset(DatasetSource):
|
||||
|
||||
def __init__(self, *args):
|
||||
"""See `Dataset.range()` for details."""
|
||||
super(RangeDataset, self).__init__()
|
||||
self._parse_args(*args)
|
||||
variant_tensor = gen_dataset_ops.range_dataset(
|
||||
start=self._start,
|
||||
stop=self._stop,
|
||||
step=self._step,
|
||||
**flat_structure(self))
|
||||
super(RangeDataset, self).__init__(variant_tensor)
|
||||
|
||||
def _parse_args(self, *args):
|
||||
"""Parse arguments according to the same rules as the `range()` builtin."""
|
||||
@ -2363,13 +2399,6 @@ class RangeDataset(DatasetSource):
|
||||
def _build_tensor(self, int64_value, name):
|
||||
return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.range_dataset(
|
||||
start=self._start,
|
||||
stop=self._stop,
|
||||
step=self._step,
|
||||
**flat_structure(self))
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return structure_lib.TensorStructure(dtypes.int64, [])
|
||||
@ -2380,16 +2409,14 @@ class CacheDataset(UnaryUnchangedStructureDataset):
|
||||
|
||||
def __init__(self, input_dataset, filename):
|
||||
"""See `Dataset.cache()` for details."""
|
||||
super(CacheDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._filename = ops.convert_to_tensor(
|
||||
filename, dtype=dtypes.string, name="filename")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.cache_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = gen_dataset_ops.cache_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
filename=self._filename,
|
||||
**flat_structure(self))
|
||||
super(CacheDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
class ShuffleDataset(UnaryUnchangedStructureDataset):
|
||||
@ -2420,7 +2447,6 @@ class ShuffleDataset(UnaryUnchangedStructureDataset):
|
||||
Raises:
|
||||
ValueError: if invalid arguments are provided.
|
||||
"""
|
||||
super(ShuffleDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._buffer_size = ops.convert_to_tensor(
|
||||
buffer_size, dtype=dtypes.int64, name="buffer_size")
|
||||
@ -2430,15 +2456,14 @@ class ShuffleDataset(UnaryUnchangedStructureDataset):
|
||||
self._reshuffle_each_iteration = True
|
||||
else:
|
||||
self._reshuffle_each_iteration = reshuffle_each_iteration
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.shuffle_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = gen_dataset_ops.shuffle_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
buffer_size=self._buffer_size,
|
||||
seed=self._seed,
|
||||
seed2=self._seed2,
|
||||
reshuffle_each_iteration=self._reshuffle_each_iteration,
|
||||
**flat_structure(self))
|
||||
super(ShuffleDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
class TakeDataset(UnaryUnchangedStructureDataset):
|
||||
@ -2446,15 +2471,13 @@ class TakeDataset(UnaryUnchangedStructureDataset):
|
||||
|
||||
def __init__(self, input_dataset, count):
|
||||
"""See `Dataset.take()` for details."""
|
||||
super(TakeDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.take_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = gen_dataset_ops.take_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
count=self._count,
|
||||
**flat_structure(self))
|
||||
super(TakeDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
class SkipDataset(UnaryUnchangedStructureDataset):
|
||||
@ -2462,15 +2485,13 @@ class SkipDataset(UnaryUnchangedStructureDataset):
|
||||
|
||||
def __init__(self, input_dataset, count):
|
||||
"""See `Dataset.skip()` for details."""
|
||||
super(SkipDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.skip_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = gen_dataset_ops.skip_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
count=self._count,
|
||||
**flat_structure(self))
|
||||
super(SkipDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
class BatchDataset(UnaryDataset):
|
||||
@ -2478,7 +2499,6 @@ class BatchDataset(UnaryDataset):
|
||||
|
||||
def __init__(self, input_dataset, batch_size, drop_remainder):
|
||||
"""See `Dataset.batch()` for details."""
|
||||
super(BatchDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._batch_size = ops.convert_to_tensor(
|
||||
batch_size, dtype=dtypes.int64, name="batch_size")
|
||||
@ -2494,13 +2514,12 @@ class BatchDataset(UnaryDataset):
|
||||
tensor_util.constant_value(self._batch_size))
|
||||
else:
|
||||
self._structure = input_dataset._element_structure._batch(None)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.batch_dataset_v2(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = gen_dataset_ops.batch_dataset_v2(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
batch_size=self._batch_size,
|
||||
drop_remainder=self._drop_remainder,
|
||||
**flat_structure(self))
|
||||
super(BatchDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
@ -2622,7 +2641,7 @@ class PaddedBatchDataset(UnaryDataset):
|
||||
def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
|
||||
drop_remainder):
|
||||
"""See `Dataset.batch()` for details."""
|
||||
super(PaddedBatchDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
if sparse.any_sparse(input_dataset.output_classes):
|
||||
# TODO(b/63669786): support batching of sparse tensors
|
||||
raise TypeError(
|
||||
@ -2665,12 +2684,11 @@ class PaddedBatchDataset(UnaryDataset):
|
||||
self._input_dataset.output_types, output_shapes,
|
||||
self._input_dataset.output_classes)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# pylint: disable=protected-access
|
||||
# TODO(jsimsa): Switch to using v2 only any time after 6/30/2018.
|
||||
if smart_cond.smart_constant_value(self._drop_remainder) is False:
|
||||
return gen_dataset_ops.padded_batch_dataset(
|
||||
self._input_dataset._as_variant_tensor(),
|
||||
variant_tensor = gen_dataset_ops.padded_batch_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
batch_size=self._batch_size,
|
||||
padded_shapes=[
|
||||
ops.convert_to_tensor(s, dtype=dtypes.int64)
|
||||
@ -2679,8 +2697,8 @@ class PaddedBatchDataset(UnaryDataset):
|
||||
padding_values=nest.flatten(self._padding_values),
|
||||
output_shapes=self._structure._flat_shapes)
|
||||
else:
|
||||
return gen_dataset_ops.padded_batch_dataset_v2(
|
||||
self._input_dataset._as_variant_tensor(),
|
||||
variant_tensor = gen_dataset_ops.padded_batch_dataset_v2(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
batch_size=self._batch_size,
|
||||
padded_shapes=[
|
||||
ops.convert_to_tensor(s, dtype=dtypes.int64)
|
||||
@ -2689,6 +2707,7 @@ class PaddedBatchDataset(UnaryDataset):
|
||||
padding_values=nest.flatten(self._padding_values),
|
||||
drop_remainder=self._drop_remainder,
|
||||
output_shapes=self._structure._flat_shapes)
|
||||
super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
@ -2727,22 +2746,19 @@ class MapDataset(UnaryDataset):
|
||||
use_inter_op_parallelism=True,
|
||||
preserve_cardinality=False):
|
||||
"""See `Dataset.map()` for details."""
|
||||
super(MapDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._use_inter_op_parallelism = use_inter_op_parallelism
|
||||
self._preserve_cardinality = preserve_cardinality
|
||||
self._map_func = StructuredFunctionWrapper(
|
||||
map_func, self._transformation_name(), dataset=input_dataset)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
|
||||
return gen_dataset_ops.map_dataset(
|
||||
input_t,
|
||||
variant_tensor = gen_dataset_ops.map_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
f=self._map_func.function,
|
||||
use_inter_op_parallelism=self._use_inter_op_parallelism,
|
||||
preserve_cardinality=self._preserve_cardinality,
|
||||
**flat_structure(self))
|
||||
super(MapDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
return [self._map_func]
|
||||
@ -2755,7 +2771,7 @@ class MapDataset(UnaryDataset):
|
||||
return "Dataset.map()"
|
||||
|
||||
|
||||
class ParallelMapDataset(MapDataset):
|
||||
class ParallelMapDataset(UnaryDataset):
|
||||
"""A `Dataset` that maps a function over elements in its input in parallel."""
|
||||
|
||||
def __init__(self,
|
||||
@ -2765,23 +2781,32 @@ class ParallelMapDataset(MapDataset):
|
||||
use_inter_op_parallelism=True,
|
||||
preserve_cardinality=False):
|
||||
"""See `Dataset.map()` for details."""
|
||||
super(ParallelMapDataset, self).__init__(
|
||||
input_dataset, map_func, use_inter_op_parallelism, preserve_cardinality)
|
||||
|
||||
self._input_dataset = input_dataset
|
||||
self._use_inter_op_parallelism = use_inter_op_parallelism
|
||||
self._map_func = StructuredFunctionWrapper(
|
||||
map_func, self._transformation_name(), dataset=input_dataset)
|
||||
self._num_parallel_calls = ops.convert_to_tensor(
|
||||
num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# pylint: disable=protected-access
|
||||
input_t = self._input_dataset._as_variant_tensor()
|
||||
return gen_dataset_ops.parallel_map_dataset(
|
||||
input_t,
|
||||
self._preserve_cardinality = preserve_cardinality
|
||||
variant_tensor = gen_dataset_ops.parallel_map_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
f=self._map_func.function,
|
||||
num_parallel_calls=self._num_parallel_calls,
|
||||
use_inter_op_parallelism=self._use_inter_op_parallelism,
|
||||
preserve_cardinality=self._preserve_cardinality,
|
||||
**flat_structure(self))
|
||||
super(ParallelMapDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._map_func.output_structure
|
||||
|
||||
def _transformation_name(self):
|
||||
return "Dataset.map()"
|
||||
|
||||
|
||||
class FlatMapDataset(UnaryDataset):
|
||||
@ -2789,24 +2814,21 @@ class FlatMapDataset(UnaryDataset):
|
||||
|
||||
def __init__(self, input_dataset, map_func):
|
||||
"""See `Dataset.flat_map()` for details."""
|
||||
super(FlatMapDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
|
||||
self._map_func = StructuredFunctionWrapper(
|
||||
map_func, self._transformation_name(), dataset=input_dataset)
|
||||
if not isinstance(self._map_func.output_structure, DatasetStructure):
|
||||
raise TypeError("`map_func` must return a `Dataset` object.")
|
||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
||||
|
||||
def _functions(self):
|
||||
return [self._map_func]
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.flat_map_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = gen_dataset_ops.flat_map_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
f=self._map_func.function,
|
||||
**flat_structure(self))
|
||||
super(FlatMapDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
@ -2816,58 +2838,79 @@ class FlatMapDataset(UnaryDataset):
|
||||
return "Dataset.flat_map()"
|
||||
|
||||
|
||||
class InterleaveDataset(FlatMapDataset):
|
||||
class InterleaveDataset(UnaryDataset):
|
||||
"""A `Dataset` that maps a function over its input and interleaves the result.
|
||||
"""
|
||||
|
||||
def __init__(self, input_dataset, map_func, cycle_length, block_length):
|
||||
"""See `Dataset.interleave()` for details."""
|
||||
super(InterleaveDataset, self).__init__(input_dataset, map_func)
|
||||
self._input_dataset = input_dataset
|
||||
self._map_func = StructuredFunctionWrapper(
|
||||
map_func, self._transformation_name(), dataset=input_dataset)
|
||||
if not isinstance(self._map_func.output_structure, DatasetStructure):
|
||||
raise TypeError("`map_func` must return a `Dataset` object.")
|
||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
||||
self._cycle_length = ops.convert_to_tensor(
|
||||
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
||||
self._block_length = ops.convert_to_tensor(
|
||||
block_length, dtype=dtypes.int64, name="block_length")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# pylint: disable=protected-access
|
||||
return gen_dataset_ops.interleave_dataset(
|
||||
self._input_dataset._as_variant_tensor(),
|
||||
self._map_func.function.captured_inputs,
|
||||
variant_tensor = gen_dataset_ops.interleave_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs, # pylint: disable=protected-access
|
||||
self._cycle_length,
|
||||
self._block_length,
|
||||
f=self._map_func.function,
|
||||
**flat_structure(self))
|
||||
super(InterleaveDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
|
||||
def _transformation_name(self):
|
||||
return "Dataset.interleave()"
|
||||
|
||||
|
||||
class ParallelInterleaveDataset(FlatMapDataset):
|
||||
class ParallelInterleaveDataset(UnaryDataset):
|
||||
"""A `Dataset` that maps a function over its input and interleaves the result.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, input_dataset, map_func, cycle_length, block_length,
|
||||
num_parallel_calls):
|
||||
"""See `Dataset.interleave()` for details."""
|
||||
super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func)
|
||||
self._input_dataset = input_dataset
|
||||
self._map_func = StructuredFunctionWrapper(
|
||||
map_func, self._transformation_name(), dataset=input_dataset)
|
||||
if not isinstance(self._map_func.output_structure, DatasetStructure):
|
||||
raise TypeError("`map_func` must return a `Dataset` object.")
|
||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
||||
self._cycle_length = ops.convert_to_tensor(
|
||||
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
||||
self._block_length = ops.convert_to_tensor(
|
||||
block_length, dtype=dtypes.int64, name="block_length")
|
||||
self._num_parallel_calls = ops.convert_to_tensor(
|
||||
num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# pylint: disable=protected-access
|
||||
return gen_dataset_ops.parallel_interleave_dataset_v2(
|
||||
self._input_dataset._as_variant_tensor(),
|
||||
self._map_func.function.captured_inputs,
|
||||
variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v2(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs, # pylint: disable=protected-access
|
||||
self._cycle_length,
|
||||
self._block_length,
|
||||
self._num_parallel_calls,
|
||||
f=self._map_func.function,
|
||||
**flat_structure(self))
|
||||
super(ParallelInterleaveDataset, self).__init__(input_dataset,
|
||||
variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
|
||||
def _transformation_name(self):
|
||||
return "Dataset.interleave()"
|
||||
@ -2878,7 +2921,6 @@ class FilterDataset(UnaryUnchangedStructureDataset):
|
||||
|
||||
def __init__(self, input_dataset, predicate):
|
||||
"""See `Dataset.filter()` for details."""
|
||||
super(FilterDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
wrapped_func = StructuredFunctionWrapper(
|
||||
predicate, self._transformation_name(), dataset=input_dataset)
|
||||
@ -2886,16 +2928,15 @@ class FilterDataset(UnaryUnchangedStructureDataset):
|
||||
structure_lib.TensorStructure(dtypes.bool, [])):
|
||||
raise ValueError("`predicate` must return a scalar boolean tensor.")
|
||||
self._predicate = wrapped_func
|
||||
|
||||
def _functions(self):
|
||||
return [self._predicate]
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.filter_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = gen_dataset_ops.filter_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
other_arguments=self._predicate.function.captured_inputs,
|
||||
predicate=self._predicate.function,
|
||||
**flat_structure(self))
|
||||
super(FilterDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
return [self._predicate]
|
||||
|
||||
def _transformation_name(self):
|
||||
return "Dataset.filter()"
|
||||
@ -2906,18 +2947,16 @@ class PrefetchDataset(UnaryUnchangedStructureDataset):
|
||||
|
||||
def __init__(self, input_dataset, buffer_size):
|
||||
"""See `Dataset.prefetch()` for details."""
|
||||
super(PrefetchDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
if buffer_size is None:
|
||||
buffer_size = -1 # This is the sentinel for auto-tuning.
|
||||
self._buffer_size = ops.convert_to_tensor(
|
||||
buffer_size, dtype=dtypes.int64, name="buffer_size")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.prefetch_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = gen_dataset_ops.prefetch_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
buffer_size=self._buffer_size,
|
||||
**flat_structure(self))
|
||||
super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
class WindowDataset(UnaryDataset):
|
||||
@ -2925,7 +2964,6 @@ class WindowDataset(UnaryDataset):
|
||||
|
||||
def __init__(self, input_dataset, size, shift, stride, drop_remainder):
|
||||
"""See `window_dataset()` for more details."""
|
||||
super(WindowDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
|
||||
self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
|
||||
@ -2944,15 +2982,14 @@ class WindowDataset(UnaryDataset):
|
||||
nest.flatten(input_dataset.output_types))
|
||||
])
|
||||
self._structure = structure_lib.NestedStructure(nest_of_structures)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.window_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = gen_dataset_ops.window_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._size,
|
||||
self._shift,
|
||||
self._stride,
|
||||
self._drop_remainder,
|
||||
**flat_structure(self))
|
||||
super(WindowDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
@ -2963,16 +3000,14 @@ class _OptionsDataset(UnaryUnchangedStructureDataset):
|
||||
"""An identity `Dataset` that stores options."""
|
||||
|
||||
def __init__(self, input_dataset, options):
|
||||
super(_OptionsDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._options = input_dataset.options()
|
||||
if self._options:
|
||||
self._options = self._options.merge(options)
|
||||
else:
|
||||
self._options = options
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
|
||||
variant_tensor = input_dataset._variant_tensor # pylint: disable=protected-access
|
||||
super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def options(self):
|
||||
return self._options
|
||||
@ -2983,13 +3018,11 @@ class _ModelDataset(UnaryUnchangedStructureDataset):
|
||||
|
||||
def __init__(self, input_dataset):
|
||||
"""See `optimize()` for details."""
|
||||
super(_ModelDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.model_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = gen_dataset_ops.model_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
**flat_structure(self))
|
||||
super(_ModelDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
class _OptimizeDataset(UnaryUnchangedStructureDataset):
|
||||
@ -2997,68 +3030,63 @@ class _OptimizeDataset(UnaryUnchangedStructureDataset):
|
||||
|
||||
def __init__(self, input_dataset, optimizations):
|
||||
"""See `optimize()` for details."""
|
||||
super(_OptimizeDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
if optimizations is None:
|
||||
optimizations = []
|
||||
self._optimizations = ops.convert_to_tensor(
|
||||
optimizations, dtype=dtypes.string, name="optimizations")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.optimize_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = gen_dataset_ops.optimize_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._optimizations,
|
||||
**flat_structure(self))
|
||||
super(_OptimizeDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset):
|
||||
"""A `Dataset` that acts as an identity, and sets a stats aggregator."""
|
||||
|
||||
def __init__(self, input_dataset, aggregator, prefix, counter_prefix):
|
||||
super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._stats_aggregator = aggregator
|
||||
self._prefix = prefix
|
||||
self._counter_prefix = counter_prefix
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return ged_ops.experimental_set_stats_aggregator_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = ged_ops.experimental_set_stats_aggregator_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._stats_aggregator._resource, # pylint: disable=protected-access
|
||||
self._prefix,
|
||||
self._counter_prefix,
|
||||
**flat_structure(self))
|
||||
super(_SetStatsAggregatorDataset, self).__init__(input_dataset,
|
||||
variant_tensor)
|
||||
|
||||
|
||||
class _MaxIntraOpParallelismDataset(UnaryUnchangedStructureDataset):
|
||||
"""A `Dataset` that acts as an identity, overriding intra-op parallelism."""
|
||||
|
||||
def __init__(self, input_dataset, max_intra_op_parallelism):
|
||||
super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._max_intra_op_parallelism = ops.convert_to_tensor(
|
||||
max_intra_op_parallelism,
|
||||
dtype=dtypes.int64,
|
||||
name="max_intra_op_parallelism")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return ged_ops.experimental_max_intra_op_parallelism_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = ged_ops.experimental_max_intra_op_parallelism_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._max_intra_op_parallelism,
|
||||
**flat_structure(self))
|
||||
super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset,
|
||||
variant_tensor)
|
||||
|
||||
|
||||
class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
|
||||
"""A `Dataset` that acts as an identity, setting a private threadpool."""
|
||||
|
||||
def __init__(self, input_dataset, num_threads):
|
||||
super(_PrivateThreadPoolDataset, self).__init__(input_dataset)
|
||||
self._input_dataset = input_dataset
|
||||
self._num_threads = ops.convert_to_tensor(
|
||||
num_threads, dtype=dtypes.int64, name="num_threads")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return ged_ops.experimental_private_thread_pool_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
variant_tensor = ged_ops.experimental_private_thread_pool_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._num_threads,
|
||||
**flat_structure(self))
|
||||
super(_PrivateThreadPoolDataset, self).__init__(input_dataset,
|
||||
variant_tensor)
|
||||
|
@ -357,7 +357,7 @@ class Iterator(checkpointable.CheckpointableBase):
|
||||
(self.output_shapes, dataset.output_shapes))
|
||||
with ops.colocate_with(self._iterator_resource):
|
||||
return gen_dataset_ops.make_iterator(
|
||||
dataset._as_variant_tensor(), self._iterator_resource, name=name) # pylint: disable=protected-access
|
||||
dataset._variant_tensor, self._iterator_resource, name=name) # pylint: disable=protected-access
|
||||
|
||||
def get_next(self, name=None):
|
||||
"""Returns a nested structure of `tf.Tensor`s representing the next element.
|
||||
@ -524,7 +524,7 @@ class EagerIterator(checkpointable.CheckpointableBase):
|
||||
with ops.device("/cpu:0"):
|
||||
# pylint: disable=protected-access
|
||||
dataset = dataset._apply_options()
|
||||
ds_variant = dataset._as_variant_tensor()
|
||||
ds_variant = dataset._variant_tensor
|
||||
self._structure = structure_lib.convert_legacy_structure(
|
||||
dataset.output_types, dataset.output_shapes, dataset.output_classes)
|
||||
self._flat_output_types = self._structure._flat_types
|
||||
|
@ -30,12 +30,11 @@ from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
|
||||
|
||||
class _PerDeviceGenerator(dataset_ops.Dataset):
|
||||
class _PerDeviceGenerator(dataset_ops.DatasetV2):
|
||||
"""A `dummy` generator dataset."""
|
||||
|
||||
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
|
||||
source_device, target_device, element_structure):
|
||||
super(_PerDeviceGenerator, self).__init__()
|
||||
self._target_device = target_device
|
||||
self._structure = element_structure
|
||||
|
||||
@ -108,9 +107,8 @@ class _PerDeviceGenerator(dataset_ops.Dataset):
|
||||
)
|
||||
self._finalize_captured_args = self._finalize_func.captured_inputs
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
with ops.device(self._target_device):
|
||||
return gen_dataset_ops.generator_dataset(
|
||||
variant_tensor = gen_dataset_ops.generator_dataset(
|
||||
self._init_captured_args,
|
||||
self._next_captured_args,
|
||||
self._finalize_captured_args,
|
||||
@ -118,6 +116,7 @@ class _PerDeviceGenerator(dataset_ops.Dataset):
|
||||
next_func=self._next_func,
|
||||
finalize_func=self._finalize_func,
|
||||
**dataset_ops.flat_structure(self))
|
||||
super(_PerDeviceGenerator, self).__init__(variant_tensor)
|
||||
|
||||
def _inputs(self):
|
||||
# TODO(b/116506223): Determine which datasets should be used as inputs here.
|
||||
@ -177,7 +176,7 @@ class MultiDeviceIterator(object):
|
||||
# The incarnation ID is used to ensure consistency between the per-device
|
||||
# iterators and the multi-device iterator.
|
||||
self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
|
||||
self._dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
self._dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._multi_device_iterator_resource,
|
||||
max_buffer_size=max_buffer_size)
|
||||
|
||||
@ -200,7 +199,8 @@ class MultiDeviceIterator(object):
|
||||
options.experimental_optimization.apply_default_optimizations = False
|
||||
ds = ds.with_options(options)
|
||||
with ops.device(device):
|
||||
self._device_iterators.append(ds.make_initializable_iterator())
|
||||
self._device_iterators.append(
|
||||
dataset_ops.make_initializable_iterator(ds))
|
||||
|
||||
device_iterator_initializers = [
|
||||
iterator.initializer for iterator in self._device_iterators
|
||||
|
@ -49,7 +49,6 @@ class TextLineDatasetV2(dataset_ops.DatasetSource):
|
||||
to buffer. A value of 0 results in the default buffering values chosen
|
||||
based on the compression type.
|
||||
"""
|
||||
super(TextLineDatasetV2, self).__init__()
|
||||
self._filenames = ops.convert_to_tensor(
|
||||
filenames, dtype=dtypes.string, name="filenames")
|
||||
self._compression_type = convert.optional_param_to_tensor(
|
||||
@ -59,10 +58,9 @@ class TextLineDatasetV2(dataset_ops.DatasetSource):
|
||||
argument_dtype=dtypes.string)
|
||||
self._buffer_size = convert.optional_param_to_tensor(
|
||||
"buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.text_line_dataset(
|
||||
variant_tensor = gen_dataset_ops.text_line_dataset(
|
||||
self._filenames, self._compression_type, self._buffer_size)
|
||||
super(TextLineDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
@ -100,7 +98,6 @@ class _TFRecordDataset(dataset_ops.DatasetSource):
|
||||
buffer_size: (Optional.) A `tf.int64` scalar representing the number of
|
||||
bytes in the read buffer. 0 means no buffering.
|
||||
"""
|
||||
super(_TFRecordDataset, self).__init__()
|
||||
# Force the type to string even if filenames is an empty list.
|
||||
self._filenames = ops.convert_to_tensor(
|
||||
filenames, dtypes.string, name="filenames")
|
||||
@ -113,24 +110,32 @@ class _TFRecordDataset(dataset_ops.DatasetSource):
|
||||
"buffer_size",
|
||||
buffer_size,
|
||||
argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.tf_record_dataset(
|
||||
variant_tensor = gen_dataset_ops.tf_record_dataset(
|
||||
self._filenames, self._compression_type, self._buffer_size)
|
||||
super(_TFRecordDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return structure.TensorStructure(dtypes.string, [])
|
||||
|
||||
|
||||
class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
|
||||
class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
|
||||
"""A `Dataset` that maps a function over its input and flattens the result."""
|
||||
|
||||
def __init__(self, input_dataset, map_func, cycle_length, block_length,
|
||||
sloppy, buffer_output_elements, prefetch_input_elements):
|
||||
"""See `tf.data.experimental.parallel_interleave()` for details."""
|
||||
super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func,
|
||||
cycle_length, block_length)
|
||||
self._input_dataset = input_dataset
|
||||
self._map_func = dataset_ops.StructuredFunctionWrapper(
|
||||
map_func, self._transformation_name(), dataset=input_dataset)
|
||||
if not isinstance(self._map_func.output_structure,
|
||||
dataset_ops.DatasetStructure):
|
||||
raise TypeError("`map_func` must return a `Dataset` object.")
|
||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
||||
self._cycle_length = ops.convert_to_tensor(
|
||||
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
||||
self._block_length = ops.convert_to_tensor(
|
||||
block_length, dtype=dtypes.int64, name="block_length")
|
||||
self._sloppy = ops.convert_to_tensor(
|
||||
sloppy, dtype=dtypes.bool, name="sloppy")
|
||||
self._buffer_output_elements = convert.optional_param_to_tensor(
|
||||
@ -141,11 +146,8 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
|
||||
"prefetch_input_elements",
|
||||
prefetch_input_elements,
|
||||
argument_default=2 * cycle_length)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# pylint: disable=protected-access
|
||||
return ged_ops.experimental_parallel_interleave_dataset(
|
||||
self._input_dataset._as_variant_tensor(),
|
||||
variant_tensor = ged_ops.experimental_parallel_interleave_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
self._cycle_length,
|
||||
self._block_length,
|
||||
@ -154,7 +156,15 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
|
||||
self._prefetch_input_elements,
|
||||
f=self._map_func.function,
|
||||
**dataset_ops.flat_structure(self))
|
||||
# pylint: enable=protected-access
|
||||
super(ParallelInterleaveDataset, self).__init__(input_dataset,
|
||||
variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
|
||||
def _transformation_name(self):
|
||||
return "tf.data.experimental.parallel_interleave()"
|
||||
@ -186,7 +196,6 @@ class TFRecordDatasetV2(dataset_ops.DatasetV2):
|
||||
TypeError: If any argument does not have the expected type.
|
||||
ValueError: If any argument does not have the expected shape.
|
||||
"""
|
||||
super(TFRecordDatasetV2, self).__init__()
|
||||
if isinstance(filenames, dataset_ops.DatasetV2):
|
||||
if filenames.output_types != dtypes.string:
|
||||
raise TypeError(
|
||||
@ -215,6 +224,8 @@ class TFRecordDatasetV2(dataset_ops.DatasetV2):
|
||||
filenames, read_one_file, cycle_length=num_parallel_reads,
|
||||
block_length=1, sloppy=False, buffer_output_elements=None,
|
||||
prefetch_input_elements=None)
|
||||
variant_tensor = self._impl._variant_tensor # pylint: disable=protected-access
|
||||
super(TFRecordDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
def _clone(self,
|
||||
filenames=None,
|
||||
@ -226,9 +237,6 @@ class TFRecordDatasetV2(dataset_ops.DatasetV2):
|
||||
buffer_size or self._buffer_size,
|
||||
num_parallel_reads or self._num_parallel_reads)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return self._impl._as_variant_tensor() # pylint: disable=protected-access
|
||||
|
||||
def _inputs(self):
|
||||
return self._impl._inputs() # pylint: disable=protected-access
|
||||
|
||||
@ -295,7 +303,6 @@ class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource):
|
||||
compression_type: (Optional.) A `tf.string` scalar evaluating to one of
|
||||
`""` (no compression), `"ZLIB"`, or `"GZIP"`.
|
||||
"""
|
||||
super(FixedLengthRecordDatasetV2, self).__init__()
|
||||
self._filenames = ops.convert_to_tensor(
|
||||
filenames, dtype=dtypes.string, name="filenames")
|
||||
self._record_bytes = ops.convert_to_tensor(
|
||||
@ -312,17 +319,16 @@ class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource):
|
||||
compression_type,
|
||||
argument_default="",
|
||||
argument_dtype=dtypes.string)
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
if (self._compression_type is not None or
|
||||
compat.forward_compatible(2018, 11, 30)):
|
||||
return gen_dataset_ops.fixed_length_record_dataset_v2(
|
||||
variant_tensor = gen_dataset_ops.fixed_length_record_dataset_v2(
|
||||
self._filenames, self._header_bytes, self._record_bytes,
|
||||
self._footer_bytes, self._buffer_size, self._compression_type)
|
||||
else:
|
||||
return gen_dataset_ops.fixed_length_record_dataset(
|
||||
variant_tensor = gen_dataset_ops.fixed_length_record_dataset(
|
||||
self._filenames, self._header_bytes, self._record_bytes,
|
||||
self._footer_bytes, self._buffer_size)
|
||||
super(FixedLengthRecordDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
|
@ -163,3 +163,24 @@ py_test(
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "traverse",
|
||||
srcs = ["traverse.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "traverse_test",
|
||||
size = "small",
|
||||
srcs = ["traverse_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":traverse",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
],
|
||||
)
|
||||
|
56
tensorflow/python/data/util/traverse.py
Normal file
56
tensorflow/python/data/util/traverse.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Helpers to traverse the Dataset dependency structure."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from six.moves import queue as Queue # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
||||
|
||||
def obtain_all_variant_tensor_ops(dataset):
|
||||
"""Given an input dataset, finds all dataset ops used for construction.
|
||||
|
||||
A series of transformations would have created this dataset with each
|
||||
transformation including zero or more Dataset ops, each producing a dataset
|
||||
variant tensor. This method outputs all of them.
|
||||
|
||||
Args:
|
||||
dataset: Dataset to find variant tensors for.
|
||||
|
||||
Returns:
|
||||
A list of variant_tensor producing dataset ops used to construct this
|
||||
dataset.
|
||||
"""
|
||||
all_variant_tensor_ops = []
|
||||
bfs_q = Queue.Queue()
|
||||
bfs_q.put(dataset._variant_tensor.op) # pylint: disable=protected-access
|
||||
visited = []
|
||||
while not bfs_q.empty():
|
||||
op = bfs_q.get()
|
||||
visited.append(op)
|
||||
# We look for all ops that produce variant tensors as output. This is a bit
|
||||
# of overkill but the other dataset _inputs() traversal strategies can't
|
||||
# cover the case of function inputs that capture dataset variants.
|
||||
# TODO(b/120873778): Make this more efficient.
|
||||
if op.outputs[0].dtype == dtypes.variant:
|
||||
all_variant_tensor_ops.append(op)
|
||||
for i in op.inputs:
|
||||
input_op = i.op
|
||||
if input_op not in visited:
|
||||
bfs_q.put(input_op)
|
||||
return all_variant_tensor_ops
|
109
tensorflow/python/data/util/traverse_test.py
Normal file
109
tensorflow/python/data/util/traverse_test.py
Normal file
@ -0,0 +1,109 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for utilities for traversing the dataset construction graph."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import traverse
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class _TestDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
|
||||
def __init__(self, input_dataset):
|
||||
self._input_dataset = input_dataset
|
||||
temp_variant_tensor = gen_dataset_ops.prefetch_dataset(
|
||||
input_dataset._variant_tensor,
|
||||
buffer_size=1,
|
||||
**dataset_ops.flat_structure(self))
|
||||
variant_tensor = gen_dataset_ops.model_dataset(
|
||||
temp_variant_tensor, **dataset_ops.flat_structure(self))
|
||||
super(_TestDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
class TraverseTest(test.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testOnlySource(self):
|
||||
ds = dataset_ops.Dataset.range(10)
|
||||
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
|
||||
self.assertAllEqual(["RangeDataset"], [x.name for x in variant_tensor_ops])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSimplePipeline(self):
|
||||
ds = dataset_ops.Dataset.range(10).map(math_ops.square)
|
||||
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
|
||||
self.assertSetEqual(
|
||||
set(["MapDataset", "RangeDataset"]),
|
||||
set([x.name for x in variant_tensor_ops]))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testConcat(self):
|
||||
ds1 = dataset_ops.Dataset.range(10)
|
||||
ds2 = dataset_ops.Dataset.range(10)
|
||||
ds = ds1.concatenate(ds2)
|
||||
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
|
||||
self.assertSetEqual(
|
||||
set(["ConcatenateDataset", "RangeDataset", "RangeDataset_1"]),
|
||||
set([x.name for x in variant_tensor_ops]))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testZip(self):
|
||||
ds1 = dataset_ops.Dataset.range(10)
|
||||
ds2 = dataset_ops.Dataset.range(10)
|
||||
ds = dataset_ops.Dataset.zip((ds1, ds2))
|
||||
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
|
||||
self.assertSetEqual(
|
||||
set(["ZipDataset", "RangeDataset", "RangeDataset_1"]),
|
||||
set([x.name for x in variant_tensor_ops]))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMultipleVariantTensors(self):
|
||||
ds = dataset_ops.Dataset.range(10)
|
||||
ds = _TestDataset(ds)
|
||||
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
|
||||
self.assertSetEqual(
|
||||
set(["RangeDataset", "ModelDataset", "PrefetchDataset"]),
|
||||
set([x.name for x in variant_tensor_ops]))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFlatMap(self):
|
||||
ds1 = dataset_ops.Dataset.range(10).repeat(10)
|
||||
|
||||
def map_fn(ds):
|
||||
|
||||
def _map(x):
|
||||
return ds.batch(x)
|
||||
|
||||
return _map
|
||||
|
||||
ds2 = dataset_ops.Dataset.range(20).prefetch(1)
|
||||
ds2 = ds2.flat_map(map_fn(ds1))
|
||||
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds2)
|
||||
self.assertSetEqual(
|
||||
set([
|
||||
"FlatMapDataset", "PrefetchDataset", "RepeatDataset",
|
||||
"RangeDataset", "RangeDataset_1"
|
||||
]), set([x.name for x in variant_tensor_ops]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -270,6 +270,7 @@ cuda_py_test(
|
||||
":input_ops",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/ops:readers",
|
||||
"//tensorflow/python/data/util:traverse",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
|
@ -18,15 +18,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.experimental.ops import filter_for_shard_ops
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import readers
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.data.util import traverse
|
||||
from tensorflow.python.framework import op_def_registry
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import tf_logging
|
||||
|
||||
|
||||
# TODO(priyag): Any other reader datasets to consider here?
|
||||
_READER_DATASET_OPS = [
|
||||
"TextLineDataset", "TFRecordDataset", "FixedLengthRecordDataset",
|
||||
@ -53,100 +51,57 @@ def auto_shard_dataset(dataset, num_shards, index):
|
||||
determine a good way to shard the input dataset.
|
||||
"""
|
||||
|
||||
# TODO(priyag): Clone datasets instead of updating in place, similar to the
|
||||
# clone method for TFRecordDataset.
|
||||
def _auto_shard_impl(dataset, found_reader_op):
|
||||
"""Recursive implementation of auto sharding."""
|
||||
# TODO(rohanj): b/120673685 to track re-enabling auto sharding.
|
||||
tf_logging.warn("Autosharding is currently disabled. Please shard your input "
|
||||
"manually.")
|
||||
del num_shards, index
|
||||
return dataset
|
||||
|
||||
if not found_reader_op:
|
||||
# TODO(priyag): Make this check more robust by enforcing some common
|
||||
# property on reader datasets.
|
||||
if (isinstance(dataset, readers.TextLineDataset) or
|
||||
isinstance(dataset, readers.FixedLengthRecordDataset)):
|
||||
filenames_tensor = dataset._filenames
|
||||
num_files = array_ops.size(filenames_tensor)
|
||||
sharded_filenames_tensor = array_ops.gather(
|
||||
filenames_tensor, math_ops.range(index, num_files, num_shards))
|
||||
dataset._filenames = sharded_filenames_tensor
|
||||
return dataset
|
||||
elif isinstance(dataset, readers.TFRecordDataset):
|
||||
# `TFRecordDataset` needs to be handled separately than other readers
|
||||
# because it converts filenames to a dataset first. Also, we clone it
|
||||
# instead of updating in place because it has special logic in the
|
||||
# constructor. Eventually we will change all cases to clone datasets
|
||||
# instead of updating in-place.
|
||||
return dataset._clone(
|
||||
filenames=dataset._filenames.apply(
|
||||
filter_for_shard_ops.filter_for_shard(num_shards, index)))
|
||||
elif isinstance(dataset, dataset_ops.RangeDataset):
|
||||
return dataset.apply(
|
||||
filter_for_shard_ops.filter_for_shard(num_shards, index))
|
||||
elif hasattr(dataset, "_map_func"):
|
||||
# TODO(priyag): Make this check more robust by enforcing some common
|
||||
# property on all map/flatmap/interleave datasets.
|
||||
map_func_def = dataset._map_func.function.definition
|
||||
for node in map_func_def.node_def:
|
||||
if node.op in _READER_DATASET_OPS:
|
||||
found_reader_op = True
|
||||
break
|
||||
elif node.op == "FlatMapDataset":
|
||||
# TODO(priyag): Should this check for other map datasets? Should it
|
||||
# be recursive? It is too specific to implementation of
|
||||
# TFRecordDataset right now.
|
||||
nested_func_name = node.attr["f"].func.name
|
||||
nested_func = ops.get_default_graph()._functions[nested_func_name]
|
||||
for nested_node in nested_func.definition.node_def:
|
||||
if nested_node.op in _READER_DATASET_OPS:
|
||||
found_reader_op = True
|
||||
break
|
||||
if found_reader_op:
|
||||
break
|
||||
if found_reader_op:
|
||||
dataset._input_dataset = _auto_shard_impl(
|
||||
dataset._input_dataset, found_reader_op)
|
||||
return dataset
|
||||
|
||||
if isinstance(dataset, dataset_ops.DatasetV1Adapter):
|
||||
dataset._dataset = _auto_shard_impl(
|
||||
dataset._dataset, found_reader_op)
|
||||
return dataset
|
||||
def _clone_dataset(dataset):
|
||||
"""Returns a cloned version of `dataset`."""
|
||||
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(dataset)
|
||||
remap_dict = _clone_helper(dataset._variant_tensor.op, variant_tensor_ops)
|
||||
new_variant_tensor = remap_dict[dataset._variant_tensor.op].outputs[0]
|
||||
return dataset_ops._VariantDataset(new_variant_tensor,
|
||||
dataset._element_structure)
|
||||
|
||||
# TODO(priyag): Make _input_dataset(s) a common property of all datasets to
|
||||
# make this check more robust.
|
||||
if hasattr(dataset, "_input_dataset"):
|
||||
dataset._input_dataset = _auto_shard_impl(
|
||||
dataset._input_dataset, found_reader_op)
|
||||
if hasattr(dataset, "_dataset_to_concatenate"):
|
||||
# Special case for `ConcatentateDataset`. We want to shard all input
|
||||
# datasets.
|
||||
dataset._dataset_to_concatenate = _auto_shard_impl(
|
||||
dataset._dataset_to_concatenate, found_reader_op)
|
||||
return dataset
|
||||
|
||||
if hasattr(dataset, "_datasets"):
|
||||
# Special case for `ZipDataset`.
|
||||
dataset._datasets = nest.pack_sequence_as(dataset._datasets, [
|
||||
_auto_shard_impl(ds, found_reader_op)
|
||||
for ds in nest.flatten(dataset._datasets)
|
||||
])
|
||||
return dataset
|
||||
def _get_op_def(op):
|
||||
return op.op_def or op_def_registry.get_registered_ops()[op.type]
|
||||
|
||||
if not found_reader_op:
|
||||
tf_logging.warn(
|
||||
"Could not find a standard reader in the input pipeline"
|
||||
"(one of TextLineDataset, TFRecordDataset, FixedLengthRecordDataset)."
|
||||
"So auto-sharding is not done. Please verify correctness of "
|
||||
"auto-sharding for your input.")
|
||||
# TODO(yuefengz): maybe still shard it?
|
||||
return dataset
|
||||
|
||||
# TODO(priyag): What do we want to do if the number of filenames is
|
||||
# uneven in the number of shards? By default, this will just return as
|
||||
# many items it can before throwing OutOfRangeError.
|
||||
# TODO(priyag): This will shard the filenames before any shuffling of the
|
||||
# filename dataset. It might be desirable to shard after shuffling
|
||||
# filenames? If so, how do we achieve that?
|
||||
return dataset.apply(
|
||||
filter_for_shard_ops.filter_for_shard(num_shards, index))
|
||||
def _clone_helper(op_to_clone, variant_tensor_ops):
|
||||
"""Helper method that recursively clones `op_to_clone`.
|
||||
|
||||
return _auto_shard_impl(dataset=dataset, found_reader_op=False)
|
||||
Args:
|
||||
op_to_clone: The op we want to clone.
|
||||
variant_tensor_ops: A list of ops that we have to clone along the way.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping old_ops to new_ops created. Includes op_to_clone
|
||||
as a key.
|
||||
"""
|
||||
remap_dict = {}
|
||||
for input_tensor in op_to_clone.inputs:
|
||||
input_tensor_op = input_tensor.op
|
||||
if input_tensor_op in variant_tensor_ops:
|
||||
recursive_map = _clone_helper(input_tensor_op, variant_tensor_ops)
|
||||
remap_dict.update(recursive_map)
|
||||
inputs_list = []
|
||||
for input_tensor in op_to_clone.inputs:
|
||||
input_tensor_op = input_tensor.op
|
||||
if input_tensor_op in remap_dict:
|
||||
remapped_input = remap_dict[input_tensor_op].outputs[0]
|
||||
inputs_list.append(remapped_input)
|
||||
else:
|
||||
inputs_list.append(input_tensor_op.outputs[input_tensor.value_index])
|
||||
g = ops.get_default_graph()
|
||||
new_op = g.create_op(
|
||||
op_to_clone.type,
|
||||
inputs_list, [o.dtype for o in op_to_clone.outputs],
|
||||
name=op_to_clone.name,
|
||||
attrs=op_to_clone.node_def.attr,
|
||||
op_def=_get_op_def(op_to_clone))
|
||||
remap_dict[op_to_clone] = new_op
|
||||
return remap_dict
|
||||
|
@ -26,6 +26,8 @@ from tensorflow.python.distribute import input_ops
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.lib.io import python_io
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
@ -90,7 +92,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
def _verifySimpleShardingOutput(self, dataset, record_fn):
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
for f in range(self._shard_index, self._num_files, self._num_shards):
|
||||
for r in range(self._num_records):
|
||||
self.assertAllEqual(record_fn(r, f), self.evaluate(next_element))
|
||||
@ -98,7 +100,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testTFRecordDataset(self):
|
||||
def DISABLED_testTFRecordDataset(self):
|
||||
dataset = readers.TFRecordDataset(self._createTFRecordFiles())
|
||||
dataset = input_ops.auto_shard_dataset(
|
||||
dataset, self._num_shards, self._shard_index)
|
||||
@ -106,7 +108,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
self._verifySimpleShardingOutput(dataset, self._record)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFlatMap(self):
|
||||
def DISABLED_testFlatMap(self):
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||
self._createTFRecordFiles())
|
||||
dataset = dataset.flat_map(readers.TFRecordDataset)
|
||||
@ -116,7 +118,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
self._verifySimpleShardingOutput(dataset, self._record)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testInterleave(self):
|
||||
def DISABLED_testInterleave(self):
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||
self._createTFRecordFiles())
|
||||
dataset = dataset.interleave(
|
||||
@ -129,7 +131,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
self._verifySimpleShardingOutput(dataset, self._record)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testListfiles(self):
|
||||
def DISABLED_testListfiles(self):
|
||||
filenames = self._createTFRecordFiles()
|
||||
file_pattern = filenames[0].rsplit(os.sep, 1)[0] + "/tf_record.*.txt"
|
||||
dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
|
||||
@ -139,7 +141,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
actual, expected = [], []
|
||||
for f in range(self._shard_index, self._num_files, self._num_shards):
|
||||
for r in range(self._num_records):
|
||||
@ -150,7 +152,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
self.assertAllEqual(expected, actual)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testComplexPipeline(self):
|
||||
def DISABLED_testComplexPipeline(self):
|
||||
# Setup a complex input pipeline.
|
||||
batch_size = 2
|
||||
num_epochs = 5
|
||||
@ -172,7 +174,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
# Verify output.
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
actual = []
|
||||
num_iterations = (self._num_files * self._num_records * num_epochs) // (
|
||||
self._num_shards * batch_size)
|
||||
@ -190,7 +192,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
self.assertAllEqual(sorted(expected), sorted(actual))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testZip(self):
|
||||
def DISABLED_testZip(self):
|
||||
dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
|
||||
dataset2 = readers.TextLineDataset(self._createTextFiles())
|
||||
dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
|
||||
@ -201,7 +203,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
self._verifySimpleShardingOutput(dataset, record_fn)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testConcat(self):
|
||||
def DISABLED_testConcat(self):
|
||||
dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
|
||||
dataset2 = readers.TextLineDataset(self._createTextFiles())
|
||||
dataset = dataset1.concatenate(dataset2)
|
||||
@ -222,7 +224,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testTextLineReader(self):
|
||||
def DISABLED_testTextLineReader(self):
|
||||
dataset = readers.TextLineDataset(self._createTextFiles())
|
||||
dataset = input_ops.auto_shard_dataset(
|
||||
dataset, self._num_shards, self._shard_index)
|
||||
@ -230,7 +232,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
self._verifySimpleShardingOutput(dataset, self._text_line)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testTextLineReaderWithFlatMap(self):
|
||||
def DISABLED_testTextLineReaderWithFlatMap(self):
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles())
|
||||
dataset = dataset.flat_map(readers.TextLineDataset)
|
||||
dataset = input_ops.auto_shard_dataset(
|
||||
@ -239,7 +241,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
self._verifySimpleShardingOutput(dataset, self._text_line)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFixedLengthReader(self):
|
||||
def DISABLED_testFixedLengthReader(self):
|
||||
dataset = readers.FixedLengthRecordDataset(
|
||||
self._createFixedLengthRecordFiles(), self._record_bytes)
|
||||
dataset = input_ops.auto_shard_dataset(
|
||||
@ -248,7 +250,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFixedLengthReaderWithFlatMap(self):
|
||||
def DISABLED_testFixedLengthReaderWithFlatMap(self):
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||
self._createFixedLengthRecordFiles())
|
||||
dataset = dataset.flat_map(
|
||||
@ -258,5 +260,77 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
|
||||
self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
|
||||
|
||||
|
||||
# A dataset that creates two variant tensors.
|
||||
class _TestDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||
|
||||
def __init__(self, input_dataset):
|
||||
self._input_dataset = input_dataset
|
||||
temp_variant_tensor = gen_dataset_ops.prefetch_dataset(
|
||||
input_dataset._variant_tensor,
|
||||
buffer_size=1,
|
||||
**dataset_ops.flat_structure(self))
|
||||
variant_tensor = gen_dataset_ops.model_dataset(
|
||||
temp_variant_tensor, **dataset_ops.flat_structure(self))
|
||||
super(_TestDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
class CloneDatasetTest(test.TestCase):
|
||||
|
||||
def _assert_datasets_equal(self, ds1, ds2):
|
||||
# First lets assert the structure is the same.
|
||||
self.assertTrue(
|
||||
ds1._element_structure.is_compatible_with(ds2._element_structure))
|
||||
self.assertTrue(
|
||||
ds2._element_structure.is_compatible_with(ds1._element_structure))
|
||||
|
||||
# Now create iterators on both and assert they produce the same values.
|
||||
it1 = dataset_ops.make_initializable_iterator(ds1)
|
||||
it2 = dataset_ops.make_initializable_iterator(ds2)
|
||||
|
||||
get_next1 = it1.get_next()
|
||||
get_next2 = it2.get_next()
|
||||
|
||||
with self.cached_session():
|
||||
self.evaluate([it1.initializer, it2.initializer])
|
||||
val1, val2 = self.evaluate([get_next1, get_next2])
|
||||
self.assertEqual(val1, val2)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testOnlySource(self):
|
||||
ds = dataset_ops.Dataset.range(10)
|
||||
cloned_ds = input_ops._clone_dataset(ds)
|
||||
self._assert_datasets_equal(ds, cloned_ds)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSimplePipeline(self):
|
||||
ds = dataset_ops.Dataset.range(10).map(math_ops.square)
|
||||
cloned_ds = input_ops._clone_dataset(ds)
|
||||
self._assert_datasets_equal(ds, cloned_ds)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testConcat(self):
|
||||
ds1 = dataset_ops.Dataset.range(10)
|
||||
ds2 = dataset_ops.Dataset.range(10)
|
||||
ds = ds1.concatenate(ds2)
|
||||
cloned_ds = input_ops._clone_dataset(ds)
|
||||
self._assert_datasets_equal(ds, cloned_ds)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testZip(self):
|
||||
ds1 = dataset_ops.Dataset.range(10)
|
||||
ds2 = dataset_ops.Dataset.range(10)
|
||||
ds = dataset_ops.Dataset.zip((ds1, ds2))
|
||||
cloned_ds = input_ops._clone_dataset(ds)
|
||||
self._assert_datasets_equal(ds, cloned_ds)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMultipleVariantTensors(self):
|
||||
ds = dataset_ops.Dataset.range(10)
|
||||
ds = _TestDataset(ds)
|
||||
cloned_ds = input_ops._clone_dataset(ds)
|
||||
self._assert_datasets_equal(ds, cloned_ds)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -1600,9 +1600,9 @@ class MultiWorkerDataset(object):
|
||||
if len(dataset_fn) != input_workers.num_workers:
|
||||
raise ValueError("If `dataset_fn` is a list, it must have one entry "
|
||||
"per worker")
|
||||
if auto_shard:
|
||||
raise ValueError(
|
||||
"If `dataset_fn` is a list, `auto_shard` is not supported.")
|
||||
# TODO(rohanj): b/120673685 to track re-enabling auto sharding.
|
||||
if auto_shard:
|
||||
raise ValueError("Currently autosharding is not supported.")
|
||||
self._input_workers = input_workers
|
||||
self._datasets = []
|
||||
# TODO(yuefengz, priyag): support different set of jobs for input
|
||||
@ -1613,9 +1613,6 @@ class MultiWorkerDataset(object):
|
||||
worker_input = dataset_fn[i]()
|
||||
else:
|
||||
worker_input = dataset_fn()
|
||||
if auto_shard:
|
||||
worker_input = input_ops.auto_shard_dataset(
|
||||
worker_input, input_workers.num_workers, i)
|
||||
dataset = PerReplicaDataset(worker_input, input_workers, i,
|
||||
prefetch_on_device=prefetch_on_device)
|
||||
self._datasets.append((worker, dataset))
|
||||
@ -1805,7 +1802,11 @@ class DatasetIterator(InputIteratorImpl):
|
||||
for i, worker in enumerate(input_workers.worker_devices):
|
||||
with ops.device(worker):
|
||||
worker_devices = input_workers.compute_devices_for_worker(i)
|
||||
iterator = _SingleWorkerDatasetIterator(dataset, worker, worker_devices)
|
||||
cloned_dataset = dataset
|
||||
if not context.executing_eagerly():
|
||||
cloned_dataset = input_ops._clone_dataset(dataset) # pylint: disable=protected-access
|
||||
iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
|
||||
worker_devices)
|
||||
iterators.append(iterator)
|
||||
|
||||
super(DatasetIterator, self).__init__(input_workers, iterators)
|
||||
|
@ -16,7 +16,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'variant_tensor\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
|
Loading…
x
Reference in New Issue
Block a user