Create dataset kernels as we go, i.e. in the __init__() method of the Dataset class.

This removes the _as_variant_tensor() method from the DatasetV2 class (the version that will be used in TF 2.0) and replaces it with a _variant_tensor property that returns the variant tensor representing the dataset. The __init__() method of DatasetV2 now takes a variant_tensor argument.
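
To make the new pattern concrete, here is a condensed, illustrative version of the MatchingFilesDataset change that appears further down in this diff (the class name is made up): source datasets now create their kernel in __init__ and pass the resulting variant tensor to the base constructor.

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops


class _ExampleMatchingFilesDataset(dataset_ops.DatasetSource):
  """Illustrative V2-style source dataset: the kernel is built in __init__."""

  def __init__(self, patterns):
    self._patterns = ops.convert_to_tensor(
        patterns, dtype=dtypes.string, name="patterns")
    # The dataset op is created eagerly here and handed to the base class;
    # there is no _as_variant_tensor() override any more.
    variant_tensor = ged_ops.experimental_matching_files_dataset(
        self._patterns)
    super(_ExampleMatchingFilesDataset, self).__init__(variant_tensor)

  @property
  def _element_structure(self):
    return structure.TensorStructure(dtypes.string, [])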

For the DatasetV1 class (the current API), _as_variant_tensor() is now invoked from within __init__(), so classes subclassing DatasetV1 should make their super().__init__() call at the end of their own __init__().
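
For contrast, a minimal sketch (not taken from the diff, using the same imports as the sketch above) of the equivalent V1-style class: it keeps _as_variant_tensor(), but since the base __init__() now calls that method, the super() call has to come last.

class _LegacyMatchingFilesDataset(dataset_ops.DatasetV1):
  """Illustrative V1-style dataset: super().__init__() runs the kernel op."""

  def __init__(self, patterns):
    self._patterns = ops.convert_to_tensor(
        patterns, dtype=dtypes.string, name="patterns")
    # DatasetV1.__init__() invokes _as_variant_tensor(), so it must be the
    # last statement, after every attribute that method reads has been set.
    super(_LegacyMatchingFilesDataset, self).__init__()

  def _as_variant_tensor(self):
    return ged_ops.experimental_matching_files_dataset(self._patterns)

  def _inputs(self):
    return []

  @property
  def _element_structure(self):
    return structure.TensorStructure(dtypes.string, [])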

Another implication is for Estimator code. Estimator input_fns are supposed to be self-contained and cannot contain ops from other graphs (such as the default graph). Previously this was not an issue because creating a Dataset object added nothing to the graph; now that construction creates ops, dataset creation code needs to move into the input_fns.
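
A minimal sketch of the Estimator implication (the file pattern is a placeholder and `estimator` is assumed to already exist): the dataset has to be constructed inside the input_fn so that its ops are created in the input_fn's graph.

import tensorflow as tf

def train_input_fn():
  # All dataset ops are created here, in the graph the Estimator builds for
  # its input pipeline, rather than at setup time outside the input_fn.
  files = tf.data.Dataset.list_files("/path/to/train-*.tfrecord")
  dataset = files.flat_map(tf.data.TFRecordDataset)
  return dataset.batch(32)

# estimator.train(input_fn=train_input_fn)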

A few other changes were required to make this happen:
1. The make_one_shot_iterator code captures inputs by value, and since inputs to a dataset can now be other datasets, which are not capturable by value, we use the whitelisting mechanism in functions to recreate these ops (see the sketch after this list).
2. The distribution strategies multi-worker code relied on re-creating the dataset kernels on different devices when the iterator was created. With the kernels now created up front, we instead have to "clone" the dataset onto the different devices.
3. Auto-sharding in distribution strategies is broken by this CL. For now, this CL disables it; we can subsequently fix it using some of the cloning logic from 2).
4. AsGraphDefInternal for functions that capture dataset inputs now needs to handle them differently, because DT_VARIANT tensors representing datasets are not serializable.
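
To make items 1 and 4 concrete, here is a sketch (not part of the diff) of the situation they deal with: a dataset captured by another dataset's function, which with eagerly created kernels shows up as a captured DT_VARIANT input to that function.

import tensorflow as tf

# `inner` is closed over by the flat_map function, so its DT_VARIANT tensor
# becomes a captured input of that function. make_one_shot_iterator's
# capture-by-value path and the kernels' AsGraphDefInternal both have to
# special-case such captures, as the C++ changes below do via
# GetDatasetFromVariantTensor/AddInputDataset.
inner = tf.data.Dataset.range(5)
outer = tf.data.Dataset.range(3).flat_map(lambda _: inner)
iterator = outer.make_one_shot_iterator()  # TF 1.x graph-mode API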

PiperOrigin-RevId: 226115500
Rohan Jain 2018-12-18 22:03:38 -08:00 committed by TensorFlower Gardener
parent c2dfcd9cea
commit 9a63f5b843
61 changed files with 939 additions and 673 deletions


@ -489,7 +489,7 @@ class BigtableTable(object):
"len(dataset.output_types))")
return gen_bigtable_ops.dataset_to_bigtable(
self._resource,
dataset._as_variant_tensor(), # pylint: disable=protected-access
dataset._variant_tensor, # pylint: disable=protected-access
column_families,
columns,
timestamp)
@ -582,13 +582,14 @@ class _BigtableKeyDataset(dataset_ops.DatasetSource):
"""_BigtableKeyDataset is an abstract class representing the keys of a table.
"""
def __init__(self, table):
def __init__(self, table, variant_tensor):
"""Constructs a _BigtableKeyDataset.
Args:
table: a Bigtable class.
variant_tensor: DT_VARIANT representation of the dataset.
"""
super(_BigtableKeyDataset, self).__init__()
super(_BigtableKeyDataset, self).__init__(variant_tensor)
self._table = table
@property
@ -601,13 +602,11 @@ class _BigtablePrefixKeyDataset(_BigtableKeyDataset):
"""
def __init__(self, table, prefix):
super(_BigtablePrefixKeyDataset, self).__init__(table)
self._prefix = prefix
def _as_variant_tensor(self):
return gen_bigtable_ops.bigtable_prefix_key_dataset(
table=self._table._resource, # pylint: disable=protected-access
variant_tensor = gen_bigtable_ops.bigtable_prefix_key_dataset(
table=table._resource, # pylint: disable=protected-access
prefix=self._prefix)
super(_BigtablePrefixKeyDataset, self).__init__(table, variant_tensor)
class _BigtableRangeKeyDataset(_BigtableKeyDataset):
@ -615,15 +614,13 @@ class _BigtableRangeKeyDataset(_BigtableKeyDataset):
"""
def __init__(self, table, start, end):
super(_BigtableRangeKeyDataset, self).__init__(table)
self._start = start
self._end = end
def _as_variant_tensor(self):
return gen_bigtable_ops.bigtable_range_key_dataset(
table=self._table._resource, # pylint: disable=protected-access
variant_tensor = gen_bigtable_ops.bigtable_range_key_dataset(
table=table._resource, # pylint: disable=protected-access
start_key=self._start,
end_key=self._end)
super(_BigtableRangeKeyDataset, self).__init__(table, variant_tensor)
class _BigtableSampleKeysDataset(_BigtableKeyDataset):
@ -633,11 +630,9 @@ class _BigtableSampleKeysDataset(_BigtableKeyDataset):
# TODO(saeta): Expose the data size offsets into the keys.
def __init__(self, table):
super(_BigtableSampleKeysDataset, self).__init__(table)
def _as_variant_tensor(self):
return gen_bigtable_ops.bigtable_sample_keys_dataset(
table=self._table._resource) # pylint: disable=protected-access
variant_tensor = gen_bigtable_ops.bigtable_sample_keys_dataset(
table=table._resource) # pylint: disable=protected-access
super(_BigtableSampleKeysDataset, self).__init__(table, variant_tensor)
class _BigtableLookupDataset(dataset_ops.DatasetSource):
@ -651,20 +646,18 @@ class _BigtableLookupDataset(dataset_ops.DatasetSource):
self._normalized = normalized
self._column_families = [i[0] for i in normalized]
self._columns = [i[1] for i in normalized]
variant_tensor = gen_bigtable_ops.bigtable_lookup_dataset(
keys_dataset=self._dataset._variant_tensor, # pylint: disable=protected-access
table=self._table._resource, # pylint: disable=protected-access
column_families=self._column_families,
columns=self._columns)
super(_BigtableLookupDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
return structure.NestedStructure(tuple(
[structure.TensorStructure(dtypes.string, [])] * self._num_outputs))
def _as_variant_tensor(self):
# pylint: disable=protected-access
return gen_bigtable_ops.bigtable_lookup_dataset(
keys_dataset=self._dataset._as_variant_tensor(),
table=self._table._resource,
column_families=self._column_families,
columns=self._columns)
class _BigtableScanDataset(dataset_ops.DatasetSource):
"""_BigtableScanDataset represents a dataset that retrieves keys and values.
@ -679,14 +672,7 @@ class _BigtableScanDataset(dataset_ops.DatasetSource):
self._columns = [i[1] for i in normalized]
self._probability = probability
self._num_outputs = len(normalized) + 1 # 1 for row key
@property
def _element_structure(self):
return structure.NestedStructure(tuple(
[structure.TensorStructure(dtypes.string, [])] * self._num_outputs))
def _as_variant_tensor(self):
return gen_bigtable_ops.bigtable_scan_dataset(
variant_tensor = gen_bigtable_ops.bigtable_scan_dataset(
table=self._table._resource, # pylint: disable=protected-access
prefix=self._prefix,
start_key=self._start,
@ -694,6 +680,13 @@ class _BigtableScanDataset(dataset_ops.DatasetSource):
column_families=self._column_families,
columns=self._columns,
probability=self._probability)
super(_BigtableScanDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
return structure.NestedStructure(
tuple(
[structure.TensorStructure(dtypes.string, [])] * self._num_outputs))
class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource):
@ -705,17 +698,15 @@ class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource):
self._prefix = prefix
self._start = start
self._end = end
variant_tensor = gen_bigtable_ops.bigtable_sample_key_pairs_dataset(
table=self._table._resource, # pylint: disable=protected-access
prefix=self._prefix,
start_key=self._start,
end_key=self._end)
super(_BigtableSampleKeyPairsDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
return structure.NestedStructure(
(structure.TensorStructure(dtypes.string, []),
structure.TensorStructure(dtypes.string, [])))
def _as_variant_tensor(self):
# pylint: disable=protected-access
return gen_bigtable_ops.bigtable_sample_key_pairs_dataset(
table=self._table._resource,
prefix=self._prefix,
start_key=self._start,
end_key=self._end)


@ -389,13 +389,11 @@ class LMDBDataset(dataset_ops.DatasetSource):
Args:
filenames: A `tf.string` tensor containing one or more filenames.
"""
super(LMDBDataset, self).__init__()
self._filenames = ops.convert_to_tensor(
filenames, dtype=dtypes.string, name="filenames")
def _as_variant_tensor(self):
return gen_experimental_dataset_ops.experimental_lmdb_dataset(
variant_tensor = gen_experimental_dataset_ops.experimental_lmdb_dataset(
self._filenames, **dataset_ops.flat_structure(self))
super(LMDBDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):


@ -30,7 +30,6 @@ class _SlideDataset(dataset_ops.UnaryDataset):
def __init__(self, input_dataset, window_size, window_shift, window_stride):
"""See `sliding_window_batch` for details."""
super(_SlideDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._window_size = ops.convert_to_tensor(
window_size, dtype=dtypes.int64, name="window_stride")
@ -43,14 +42,13 @@ class _SlideDataset(dataset_ops.UnaryDataset):
input_dataset.output_types, input_dataset.output_shapes,
input_dataset.output_classes)
self._structure = input_structure._batch(None) # pylint: disable=protected-access
def _as_variant_tensor(self):
return ged_ops.experimental_sliding_window_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = ged_ops.experimental_sliding_window_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
window_size=self._window_size,
window_shift=self._window_shift,
window_stride=self._window_stride,
**dataset_ops.flat_structure(self))
super(_SlideDataset, self).__init__(input_dataset, variant_tensor)
@property
def _element_structure(self):


@ -472,11 +472,11 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
for r in range(len(devices))])
def _test_dataset(self, dataset_fn, worker_devices, devices,
expected_values, auto_shard=True):
expected_values):
device_map = values.ReplicaDeviceMap(devices)
input_workers = values.InputWorkers(device_map, worker_devices)
multi_worker_dataset = values.MultiWorkerDataset(
dataset_fn, input_workers, auto_shard=auto_shard)
dataset_fn, input_workers)
multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
with self.cached_session() as sess:
sess.run(multi_worker_iterator.initializer)
@ -518,16 +518,9 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
worker_devices, devices = self._cpu_devices()
with context.graph_mode():
dataset_fn = lambda: dataset_ops.Dataset.range(8)
self._test_dataset(dataset_fn, worker_devices, devices,
[[0, 1], [2, 3], [4, 5], [6, 7]])
def testDataDistributionNoAutoShard(self):
worker_devices, devices = self._cpu_devices()
with context.graph_mode():
dataset_fn = lambda: dataset_ops.Dataset.range(4)
self._test_dataset(dataset_fn, worker_devices, devices,
[[0, 0], [1, 1], [2, 2], [3, 3]],
auto_shard=False)
self._test_dataset(
dataset_fn, worker_devices, devices,
[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
def testDataDistributionTwoDevicePerWorker(self):
if context.num_gpus() < 1:
@ -535,8 +528,9 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
worker_devices, devices = self._cpu_and_one_gpu_devices()
with context.graph_mode():
dataset_fn = lambda: dataset_ops.Dataset.range(8)
self._test_dataset(dataset_fn, worker_devices, devices,
[[0, 2, 1, 3], [4, 6, 5, 7]])
self._test_dataset(
dataset_fn, worker_devices, devices,
[[0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [6, 7, 6, 7]])
def testTupleDataset(self):
worker_devices, devices = self._cpu_devices()
@ -548,9 +542,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2)
return dataset_ops.Dataset.zip((dataset1, dataset2))
expected_values = [
[(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 8, 2)
]
expected_values = [[(i, i**2), (i, i**2)] for i in range(8)]
self._test_dataset(dataset_fn, worker_devices, devices,
expected_values)
@ -561,17 +553,19 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
device_map = values.ReplicaDeviceMap(devices)
input_workers = values.InputWorkers(device_map, worker_devices)
multi_worker_dataset = values.MultiWorkerDataset(
dataset_fn, input_workers, auto_shard=True)
dataset_fn, input_workers)
multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
sess.run(multi_worker_iterator.initializer)
self._test_iterator(sess, multi_worker_iterator, devices,
[[0, 1], [2, 3], [4, 5], [6, 7]])
self._test_iterator(
sess, multi_worker_iterator, devices,
[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
# After re-initializing the iterator, should be able to iterate again.
sess.run(multi_worker_iterator.initializer)
self._test_iterator(sess, multi_worker_iterator, devices,
[[0, 1], [2, 3], [4, 5], [6, 7]])
self._test_iterator(
sess, multi_worker_iterator, devices,
[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]])
def testValueErrorForIterator(self):
# Incompatiable arguments.


@ -55,13 +55,11 @@ class SequenceFileDataset(dataset_ops.DatasetSource):
Args:
filenames: A `tf.string` tensor containing one or more filenames.
"""
super(SequenceFileDataset, self).__init__()
self._filenames = ops.convert_to_tensor(
filenames, dtype=dtypes.string, name="filenames")
def _as_variant_tensor(self):
return gen_dataset_ops.sequence_file_dataset(
variant_tensor = gen_dataset_ops.sequence_file_dataset(
self._filenames, self._element_structure._flat_types) # pylint: disable=protected-access
super(SequenceFileDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):


@ -27,6 +27,7 @@ from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.platform import test
@ -55,6 +56,7 @@ class DatasetsTest(test.TestCase):
session_config = config_pb2.ConfigProto(cluster_def=self._cluster_def)
self._sess = session.Session(self._worker.target, config=session_config)
self._worker_device = '/job:' + worker_job.name
def testTextLineDataset(self):
all_contents = []
@ -70,7 +72,8 @@ class DatasetsTest(test.TestCase):
dataset = datasets.StreamingFilesDataset(
os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text')
iterator = dataset_ops.make_initializable_iterator(dataset)
with ops.device(self._worker_device):
iterator = dataset_ops.make_initializable_iterator(dataset)
self._sess.run(iterator.initializer)
get_next = iterator.get_next()
@ -94,7 +97,8 @@ class DatasetsTest(test.TestCase):
dataset = datasets.StreamingFilesDataset(
os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord')
iterator = dataset_ops.make_initializable_iterator(dataset)
with ops.device(self._worker_device):
iterator = dataset_ops.make_initializable_iterator(dataset)
self._sess.run(iterator.initializer)
get_next = iterator.get_next()
@ -121,7 +125,8 @@ class DatasetsTest(test.TestCase):
dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord')
iterator = dataset_ops.make_initializable_iterator(dataset)
with ops.device(self._worker_device):
iterator = dataset_ops.make_initializable_iterator(dataset)
self._sess.run(iterator.initializer)
get_next = iterator.get_next()
@ -154,7 +159,8 @@ class DatasetsTest(test.TestCase):
os.path.join(self.get_temp_dir(), 'fixed_length*'),
filetype=FixedLengthFile)
iterator = dataset_ops.make_initializable_iterator(dataset)
with ops.device(self._worker_device):
iterator = dataset_ops.make_initializable_iterator(dataset)
self._sess.run(iterator.initializer)
get_next = iterator.get_next()
@ -177,7 +183,8 @@ class DatasetsTest(test.TestCase):
dataset = datasets.StreamingFilesDataset(
dataset_ops.Dataset.range(10), filetype=gen_dataset)
iterator = dataset_ops.make_initializable_iterator(dataset)
with ops.device(self._worker_device):
iterator = dataset_ops.make_initializable_iterator(dataset)
self._sess.run(iterator.initializer)
get_next = iterator.get_next()


@ -119,25 +119,25 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
std::vector<Node*> key_func_other_arguments_node;
DataTypeVector key_func_other_arguments_types;
TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
b, captured_key_func_, &key_func_other_arguments_node,
ctx, b, captured_key_func_, &key_func_other_arguments_node,
&key_func_other_arguments_types));
std::vector<Node*> init_func_other_arguments_node;
DataTypeVector init_func_other_arguments_types;
TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
b, captured_init_func_, &init_func_other_arguments_node,
ctx, b, captured_init_func_, &init_func_other_arguments_node,
&init_func_other_arguments_types));
std::vector<Node*> reduce_func_other_arguments_node;
DataTypeVector reduce_func_other_arguments_types;
TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
b, captured_reduce_func_, &reduce_func_other_arguments_node,
ctx, b, captured_reduce_func_, &reduce_func_other_arguments_node,
&reduce_func_other_arguments_types));
std::vector<Node*> finalize_func_other_arguments_node;
DataTypeVector finalize_func_other_arguments_types;
TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
b, captured_finalize_func_, &finalize_func_other_arguments_node,
ctx, b, captured_finalize_func_, &finalize_func_other_arguments_node,
&finalize_func_other_arguments_types));
AttrValue key_func;
@ -406,7 +406,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
}
Status OtherArgumentsNodeAndType(
DatasetGraphDefBuilder* b,
SerializationContext* ctx, DatasetGraphDefBuilder* b,
const std::unique_ptr<CapturedFunction>& captured_func,
std::vector<Node*>* other_arguments_node,
DataTypeVector* other_arguments_types) const {
@ -414,7 +414,13 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
other_arguments_types->reserve(captured_func->captured_inputs().size());
for (const Tensor& t : captured_func->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
DatasetBase* input;
Status s = GetDatasetFromVariantTensor(t, &input);
if (s.ok()) {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
} else {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
}
other_arguments_node->emplace_back(node);
other_arguments_types->emplace_back(t.dtype());
}


@ -117,20 +117,21 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
std::vector<Node*> key_func_other_arguments_node;
DataTypeVector key_func_other_arguments_types;
TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
b, captured_key_func_, &key_func_other_arguments_node,
ctx, b, captured_key_func_, &key_func_other_arguments_node,
&key_func_other_arguments_types));
std::vector<Node*> reduce_func_other_arguments_node;
DataTypeVector reduce_func_other_arguments_types;
TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
b, captured_reduce_func_, &reduce_func_other_arguments_node,
ctx, b, captured_reduce_func_, &reduce_func_other_arguments_node,
&reduce_func_other_arguments_types));
std::vector<Node*> window_size_func_other_arguments_node;
DataTypeVector window_size_func_other_arguments_types;
TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
b, captured_window_size_func_, &window_size_func_other_arguments_node,
&window_size_func_other_arguments_types));
TF_RETURN_IF_ERROR(
OtherArgumentsNodeAndType(ctx, b, captured_window_size_func_,
&window_size_func_other_arguments_node,
&window_size_func_other_arguments_types));
AttrValue key_func;
b->BuildAttrValue(key_func_, &key_func);
@ -490,7 +491,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
};
Status OtherArgumentsNodeAndType(
DatasetGraphDefBuilder* b,
SerializationContext* ctx, DatasetGraphDefBuilder* b,
const std::unique_ptr<CapturedFunction>& captured_func,
std::vector<Node*>* other_arguments_node,
DataTypeVector* other_arguments_types) const {
@ -498,7 +499,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
other_arguments_types->reserve(captured_func->captured_inputs().size());
for (const Tensor& t : captured_func->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
DatasetBase* input;
Status s = GetDatasetFromVariantTensor(t, &input);
if (s.ok()) {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
} else {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
}
other_arguments_node->emplace_back(node);
other_arguments_types->emplace_back(t.dtype());
}


@ -210,7 +210,13 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
DatasetBase* input;
Status s = GetDatasetFromVariantTensor(t, &input);
if (s.ok()) {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
} else {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
}
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}


@ -169,7 +169,13 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
DatasetBase* input;
Status s = GetDatasetFromVariantTensor(t, &input);
if (s.ok()) {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
} else {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
}
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}


@ -154,7 +154,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
DatasetBase* input;
Status s = GetDatasetFromVariantTensor(t, &input);
if (s.ok()) {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
} else {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
}
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}


@ -119,7 +119,13 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
other_arguments_types.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
DatasetBase* input;
Status s = GetDatasetFromVariantTensor(t, &input);
if (s.ok()) {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
} else {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
}
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}


@ -137,7 +137,13 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
DatasetBase* input;
Status s = GetDatasetFromVariantTensor(t, &input);
if (s.ok()) {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
} else {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
}
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}


@ -95,7 +95,13 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
DatasetBase* input;
Status s = GetDatasetFromVariantTensor(t, &input);
if (s.ok()) {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
} else {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
}
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}


@ -121,7 +121,13 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
DatasetBase* input;
Status s = GetDatasetFromVariantTensor(t, &input);
if (s.ok()) {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
} else {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
}
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}


@ -149,7 +149,13 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
DatasetBase* input;
Status s = GetDatasetFromVariantTensor(t, &input);
if (s.ok()) {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
} else {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
}
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}


@ -160,7 +160,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
DatasetBase* input;
Status s = GetDatasetFromVariantTensor(t, &input);
if (s.ok()) {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
} else {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
}
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}


@ -141,7 +141,13 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
DatasetBase* input;
Status s = GetDatasetFromVariantTensor(t, &input);
if (s.ok()) {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
} else {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
}
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}


@ -89,14 +89,12 @@ class CsvDatasetTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(nxt())
else:
# Verify that OpError is produced as expected
with self.assertRaisesOpError(expected_err_re):
nxt = self.getNext(dataset)
while True:
try:
self.evaluate(nxt())
except errors.OutOfRangeError:
break
nxt = self.getNext(dataset)
while True:
try:
self.evaluate(nxt())
except errors.OutOfRangeError:
break
def _test_dataset(
self,
@ -110,8 +108,14 @@ class CsvDatasetTest(test_base.DatasetTestBase):
# Convert str type because py3 tf strings are bytestrings
filenames = self._setup_files(inputs, linebreak, compression_type)
kwargs['compression_type'] = compression_type
dataset = readers.CsvDataset(filenames, **kwargs)
self._verify_output_or_err(dataset, expected_output, expected_err_re)
if expected_err_re is not None:
# Verify that OpError is produced as expected
with self.assertRaisesOpError(expected_err_re):
dataset = readers.CsvDataset(filenames, **kwargs)
self._verify_output_or_err(dataset, expected_output, expected_err_re)
else:
dataset = readers.CsvDataset(filenames, **kwargs)
self._verify_output_or_err(dataset, expected_output, expected_err_re)
def testCsvDataset_requiredFields(self):
record_defaults = [[]] * 4


@ -120,8 +120,8 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(dataset_fn(8, 0), expected_output=[])
# Empty batch should be an initialization time error.
self.assertDatasetProduces(
dataset_fn(0, 14), expected_error=(errors.InvalidArgumentError, ""))
with self.assertRaises(errors.InvalidArgumentError):
self.assertDatasetProduces(dataset_fn(0, 14), expected_output=[])
@parameterized.named_parameters(
("Even", False, False),


@ -211,16 +211,15 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
"v", initializer=0, use_resource=False)
assign_op = variable.assign_add(1)
unoptimized_dataset = dataset_fn(variable)
options = dataset_ops.Options()
options.experimental_optimization.noop_elimination = True
options.experimental_optimization.map_and_batch_fusion = True
optimized_dataset = unoptimized_dataset.with_options(options)
# Check that warning is logged.
warnings.simplefilter("always")
with warnings.catch_warnings(record=True) as w:
unoptimized_dataset = dataset_fn(variable)
options = dataset_ops.Options()
options.experimental_optimization.noop_elimination = True
options.experimental_optimization.map_and_batch_fusion = True
optimized_dataset = unoptimized_dataset.with_options(options)
optimized_it = optimized_dataset.make_initializable_iterator()
self.assertGreaterEqual(len(w), 1)


@ -110,13 +110,13 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that an error is raised when `driver_name` is invalid.
def testReadResultSetWithInvalidDriverName(self):
dataset = self._createSqlDataset(
driver_name="sqlfake",
query="SELECT first_name, last_name, motto FROM students "
"ORDER BY first_name DESC",
output_types=(dtypes.string, dtypes.string, dtypes.string))
self.assertDatasetProduces(
dataset, expected_error=(errors.InvalidArgumentError, ""))
with self.assertRaises(errors.InvalidArgumentError):
dataset = self._createSqlDataset(
driver_name="sqlfake",
query="SELECT first_name, last_name, motto FROM students "
"ORDER BY first_name DESC",
output_types=(dtypes.string, dtypes.string, dtypes.string))
self.assertDatasetProduces(dataset, expected_output=[])
# Test that an error is raised when a column name in `query` is nonexistent
def testReadResultSetWithInvalidColumnName(self):


@ -197,10 +197,13 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
def testInterleaveAutoTuneBufferUtilization(self, dataset_transformation):
def dataset_fn():
dataset = dataset_ops.Dataset.range(10).map(
lambda x: array_ops.tile([x], ops.convert_to_tensor([x])))
def interleave_fn(_):
return dataset_ops.Dataset.range(
10).map(lambda x: array_ops.tile([x], ops.convert_to_tensor([x])))
dataset = dataset_ops.Dataset.range(1).interleave(
lambda _: dataset,
interleave_fn,
cycle_length=1,
num_parallel_calls=optimization.AUTOTUNE)
options = dataset_ops.Options()


@ -352,7 +352,6 @@ class _UnbatchDataset(dataset_ops.UnaryDataset):
def __init__(self, input_dataset):
"""See `unbatch()` for more details."""
super(_UnbatchDataset, self).__init__(input_dataset)
flat_shapes = nest.flatten(input_dataset.output_shapes)
if any(s.ndims == 0 for s in flat_shapes):
raise ValueError("Cannot unbatch an input with scalar components.")
@ -370,10 +369,10 @@ class _UnbatchDataset(dataset_ops.UnaryDataset):
nest.map_structure(lambda s: s[1:], input_dataset.output_shapes),
input_dataset.output_classes)
def _as_variant_tensor(self):
return ged_ops.experimental_unbatch_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = ged_ops.experimental_unbatch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
@property
def _element_structure(self):
@ -440,7 +439,6 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
def __init__(self, input_dataset, batch_size, row_shape):
"""See `Dataset.dense_to_sparse_batch()` for more details."""
super(_DenseToSparseBatchDataset, self).__init__(input_dataset)
if not isinstance(input_dataset.output_types, dtypes.DType):
raise TypeError("DenseToSparseDataset requires an input whose elements "
"have a single component, whereas the input has %r." %
@ -452,12 +450,13 @@ class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
input_dataset.output_types,
tensor_shape.vector(None).concatenate(self._row_shape))
def _as_variant_tensor(self):
return ged_ops.experimental_dense_to_sparse_batch_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._batch_size,
row_shape=convert.partial_shape_to_tensor(self._row_shape),
**dataset_ops.flat_structure(self))
super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
variant_tensor)
@property
def _element_structure(self):
@ -499,7 +498,6 @@ class _RestructuredDataset(dataset_ops.UnaryDataset):
ValueError: If either `output_types` or `output_shapes` is not compatible
with the structure of `dataset`.
"""
super(_RestructuredDataset, self).__init__(dataset)
self._input_dataset = dataset
if not allow_unsafe_cast:
@ -539,9 +537,8 @@ class _RestructuredDataset(dataset_ops.UnaryDataset):
self._structure = structure.convert_legacy_structure(
output_types, output_shapes, output_classes)
def _as_variant_tensor(self):
return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
variant_tensor = self._input_dataset._variant_tensor # pylint: disable=protected-access
super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
@property
def _element_structure(self):
@ -554,8 +551,8 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
drop_remainder):
"""See `Dataset.map()` for details."""
super(_MapAndBatchDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._map_func = dataset_ops.StructuredFunctionWrapper(
map_func, "tf.data.experimental.map_and_batch()", dataset=input_dataset)
self._batch_size_t = ops.convert_to_tensor(
@ -573,14 +570,8 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
tensor_util.constant_value(self._batch_size_t))
else:
self._structure = self._map_func.output_structure._batch(None) # pylint: disable=protected-access
def _functions(self):
return [self._map_func]
def _as_variant_tensor(self):
# pylint: disable=protected-access
return ged_ops.experimental_map_and_batch_dataset(
self._input_dataset._as_variant_tensor(),
variant_tensor = ged_ops.experimental_map_and_batch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
f=self._map_func.function,
batch_size=self._batch_size_t,
@ -588,6 +579,10 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
drop_remainder=self._drop_remainder_t,
preserve_cardinality=True,
**dataset_ops.flat_structure(self))
super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)
def _functions(self):
return [self._map_func]
@property
def _element_structure(self):


@ -47,4 +47,4 @@ def cardinality(dataset):
the cardinality is infinite or unknown, the operation returns the named
constant `INFINITE_CARDINALITY` and `UNKNOWN_CARDINALITY` respectively.
"""
return ged_ops.experimental_dataset_cardinality(dataset._as_variant_tensor()) # pylint: disable=protected-access
return ged_ops.experimental_dataset_cardinality(dataset._variant_tensor) # pylint: disable=protected-access


@ -57,10 +57,9 @@ class _IgnoreErrorsDataset(dataset_ops.UnaryUnchangedStructureDataset):
def __init__(self, input_dataset):
"""See `Dataset.ignore_errors()` for details."""
super(_IgnoreErrorsDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
def _as_variant_tensor(self):
return gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
variant_tensor = (
gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
**dataset_ops.flat_structure(self)))
super(_IgnoreErrorsDataset, self).__init__(input_dataset, variant_tensor)


@ -64,5 +64,4 @@ def get_single_element(dataset):
# pylint: disable=protected-access
return dataset._element_structure._from_compatible_tensor_list(
gen_dataset_ops.dataset_to_single_element(
dataset._as_variant_tensor(),
**dataset_ops.flat_structure(dataset)))
dataset._variant_tensor, **dataset_ops.flat_structure(dataset)))


@ -242,14 +242,23 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
def __init__(self, input_dataset, key_func, reducer):
"""See `group_by_reducer()` for details."""
super(_GroupByReducerDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._make_key_func(key_func, input_dataset)
self._make_init_func(reducer.init_func)
self._make_reduce_func(reducer.reduce_func, input_dataset)
self._make_finalize_func(reducer.finalize_func)
variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._key_func.function.captured_inputs,
self._init_func.function.captured_inputs,
self._reduce_func.function.captured_inputs,
self._finalize_func.function.captured_inputs,
key_func=self._key_func.function,
init_func=self._init_func.function,
reduce_func=self._reduce_func.function,
finalize_func=self._finalize_func.function,
**dataset_ops.flat_structure(self))
super(_GroupByReducerDataset, self).__init__(input_dataset, variant_tensor)
def _make_key_func(self, key_func, input_dataset):
"""Make wrapping defun for key_func."""
@ -347,19 +356,6 @@ class _GroupByReducerDataset(dataset_ops.UnaryDataset):
self._key_func, self._init_func, self._reduce_func, self._finalize_func
]
def _as_variant_tensor(self):
return ged_ops.experimental_group_by_reducer_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._key_func.function.captured_inputs,
self._init_func.function.captured_inputs,
self._reduce_func.function.captured_inputs,
self._finalize_func.function.captured_inputs,
key_func=self._key_func.function,
init_func=self._init_func.function,
reduce_func=self._reduce_func.function,
finalize_func=self._finalize_func.function,
**dataset_ops.flat_structure(self))
def _transformation_name(self):
return "tf.data.experimental.group_by_reducer()"
@ -369,13 +365,20 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
"""See `group_by_window()` for details."""
super(_GroupByWindowDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._make_key_func(key_func, input_dataset)
self._make_reduce_func(reduce_func, input_dataset)
self._make_window_size_func(window_size_func)
variant_tensor = ged_ops.experimental_group_by_window_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._key_func.function.captured_inputs,
self._reduce_func.function.captured_inputs,
self._window_size_func.function.captured_inputs,
key_func=self._key_func.function,
reduce_func=self._reduce_func.function,
window_size_func=self._window_size_func.function,
**dataset_ops.flat_structure(self))
super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor)
def _make_window_size_func(self, window_size_func):
"""Make wrapping defun for window_size_func."""
@ -426,17 +429,6 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
def _functions(self):
return [self._key_func, self._reduce_func, self._window_size_func]
def _as_variant_tensor(self):
return ged_ops.experimental_group_by_window_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._key_func.function.captured_inputs,
self._reduce_func.function.captured_inputs,
self._window_size_func.function.captured_inputs,
key_func=self._key_func.function,
reduce_func=self._reduce_func.function,
window_size_func=self._window_size_func.function,
**dataset_ops.flat_structure(self))
def _transformation_name(self):
return "tf.data.experimental.group_by_window()"


@ -113,15 +113,15 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
self._structure = structure.convert_legacy_structure(
data_inputs[0].output_types, output_shapes,
data_inputs[0].output_classes)
super(_DirectedInterleaveDataset, self).__init__()
def _as_variant_tensor(self):
# pylint: disable=protected-access
return (
gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
self._selector_input._as_variant_tensor(), [
data_input._as_variant_tensor()
for data_input in self._data_inputs
], **dataset_ops.flat_structure(self)))
self._selector_input._variant_tensor,
[data_input._variant_tensor for data_input in self._data_inputs],
**dataset_ops.flat_structure(self)))
# pylint: enable=protected-access
def _inputs(self):


@ -29,12 +29,10 @@ class MatchingFilesDataset(dataset_ops.DatasetSource):
"""A `Dataset` that list the files according to the input patterns."""
def __init__(self, patterns):
super(MatchingFilesDataset, self).__init__()
self._patterns = ops.convert_to_tensor(
patterns, dtype=dtypes.string, name="patterns")
def _as_variant_tensor(self):
return ged_ops.experimental_matching_files_dataset(self._patterns)
variant_tensor = ged_ops.experimental_matching_files_dataset(self._patterns)
super(MatchingFilesDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):


@ -105,18 +105,17 @@ class _AssertNextDataset(dataset_ops.UnaryUnchangedStructureDataset):
def __init__(self, input_dataset, transformations):
"""See `assert_next()` for details."""
super(_AssertNextDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if transformations is None:
raise ValueError("At least one transformation should be specified")
self._transformations = ops.convert_to_tensor(
transformations, dtype=dtypes.string, name="transformations")
def _as_variant_tensor(self):
return gen_experimental_dataset_ops.experimental_assert_next_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._transformations,
**dataset_ops.flat_structure(self))
variant_tensor = (
gen_experimental_dataset_ops.experimental_assert_next_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._transformations,
**dataset_ops.flat_structure(self)))
super(_AssertNextDataset, self).__init__(input_dataset, variant_tensor)
class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset):
@ -124,10 +123,9 @@ class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset):
def __init__(self, input_dataset):
"""See `non_serializable()` for details."""
super(_NonSerializableDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
def _as_variant_tensor(self):
return gen_experimental_dataset_ops.experimental_non_serializable_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
variant_tensor = (
gen_experimental_dataset_ops.experimental_non_serializable_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
**dataset_ops.flat_structure(self)))
super(_NonSerializableDataset, self).__init__(input_dataset, variant_tensor)


@ -31,7 +31,6 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that parses `example` dataset into a `dict` dataset."""
def __init__(self, input_dataset, features, num_parallel_calls):
super(_ParseExampleDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if not input_dataset._element_structure.is_compatible_with( # pylint: disable=protected-access
structure.TensorStructure(dtypes.string, [None])):
@ -81,16 +80,17 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
self._structure = structure.convert_legacy_structure(
output_types, output_shapes, output_classes)
def _as_variant_tensor(self):
return gen_experimental_dataset_ops.experimental_parse_example_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._num_parallel_calls,
self._dense_defaults,
self._sparse_keys,
self._dense_keys,
self._sparse_types,
self._dense_shapes,
**dataset_ops.flat_structure(self))
variant_tensor = (
gen_experimental_dataset_ops.experimental_parse_example_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._num_parallel_calls,
self._dense_defaults,
self._sparse_keys,
self._dense_keys,
self._sparse_types,
self._dense_shapes,
**dataset_ops.flat_structure(self)))
super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
@property
def _element_structure(self):


@ -93,7 +93,6 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
target_device: The name of the device to which elements would be copied.
source_device: Device where input_dataset would be placed.
"""
super(_CopyToDeviceDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._target_device = target_device
spec = framework_device.DeviceSpec().from_string(self._target_device)
@ -101,6 +100,9 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
self._source_device_string = source_device
self._source_device = ops.convert_to_tensor(source_device)
wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
self._input_dataset._variant_tensor) # pylint: disable=protected-access
@function.defun()
def _init_func():
"""Creates an iterator for the input dataset.
@ -108,8 +110,7 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
Returns:
A `string` tensor that encapsulates the iterator created.
"""
# pylint: disable=protected-access
ds_variant = self._input_dataset._as_variant_tensor()
ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
resource = gen_dataset_ops.anonymous_iterator(
**dataset_ops.flat_structure(self._input_dataset))
with ops.control_dependencies(
@ -195,6 +196,17 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
self._finalize_func.add_to_graph(g)
# pylint: enable=protected-scope
with ops.device(self._target_device):
variant_tensor = gen_dataset_ops.generator_dataset(
self._init_captured_args,
self._next_captured_args,
self._finalize_captured_args,
init_func=self._init_func,
next_func=self._next_func,
finalize_func=self._finalize_func,
**dataset_ops.flat_structure(self._input_dataset))
super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)
# The one_shot_iterator implementation needs a 0 arg _make_dataset function
# that thereby captures all the inputs required to create the dataset. Since
# there are strings that are inputs to the GeneratorDataset which can't be
@ -208,24 +220,12 @@ class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
else:
return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
def _as_variant_tensor(self):
with ops.device(self._target_device):
return gen_dataset_ops.generator_dataset(
self._init_captured_args,
self._next_captured_args,
self._finalize_captured_args,
init_func=self._init_func,
next_func=self._next_func,
finalize_func=self._finalize_func,
**dataset_ops.flat_structure(self._input_dataset))
class _MapOnGpuDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that maps a function over elements in its using a GPU."""
def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
"""See `Dataset.map()` for details."""
super(_MapOnGpuDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._use_inter_op_parallelism = use_inter_op_parallelism
@ -234,18 +234,16 @@ class _MapOnGpuDataset(dataset_ops.UnaryDataset):
self._transformation_name(),
dataset=input_dataset,
defun_kwargs={"experimental_ints_on_device": True})
def _functions(self):
return [self._map_func]
def _as_variant_tensor(self):
input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
return ged_ops.experimental_map_dataset(
input_t,
variant_tensor = ged_ops.experimental_map_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
f=self._map_func.function,
use_inter_op_parallelism=self._use_inter_op_parallelism,
**dataset_ops.flat_structure(self))
super(_MapOnGpuDataset, self).__init__(input_dataset, variant_tensor)
def _functions(self):
return [self._map_func]
@property
def _element_structure(self):


@ -33,14 +33,10 @@ class RandomDatasetV2(dataset_ops.DatasetSource):
def __init__(self, seed=None):
"""A `Dataset` of pseudorandom values."""
super(RandomDatasetV2, self).__init__()
self._seed, self._seed2 = random_seed.get_seed(seed)
def _as_variant_tensor(self):
return gen_experimental_dataset_ops.experimental_random_dataset(
seed=self._seed,
seed2=self._seed2,
**dataset_ops.flat_structure(self))
variant_tensor = gen_experimental_dataset_ops.experimental_random_dataset(
seed=self._seed, seed2=self._seed2, **dataset_ops.flat_structure(self))
super(RandomDatasetV2, self).__init__(variant_tensor)
@property
def _element_structure(self):


@ -622,7 +622,6 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
the input data. If specified, only this subset of columns will be
parsed. Defaults to parsing all columns.
"""
super(CsvDatasetV2, self).__init__()
self._filenames = ops.convert_to_tensor(
filenames, dtype=dtypes.string, name="filenames")
self._compression_type = convert.optional_param_to_tensor(
@ -655,10 +654,7 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
self._structure = structure.NestedStructure(
tuple(structure.TensorStructure(d.dtype, [])
for d in self._record_defaults))
def _as_variant_tensor(self):
# Constructs graph node for the dataset op.
return gen_experimental_dataset_ops.experimental_csv_dataset(
variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
filenames=self._filenames,
record_defaults=self._record_defaults,
buffer_size=self._buffer_size,
@ -668,8 +664,8 @@ class CsvDatasetV2(dataset_ops.DatasetSource):
use_quote_delim=self._use_quote_delim,
na_value=self._na_value,
select_cols=self._select_cols,
compression_type=self._compression_type,
)
compression_type=self._compression_type)
super(CsvDatasetV2, self).__init__(variant_tensor)
@property
def _element_structure(self):
@ -944,7 +940,6 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
output_types: A tuple of `tf.DType` objects representing the types of the
columns returned by `query`.
"""
super(SqlDatasetV2, self).__init__()
self._driver_name = ops.convert_to_tensor(
driver_name, dtype=dtypes.string, name="driver_name")
self._data_source_name = ops.convert_to_tensor(
@ -954,11 +949,10 @@ class SqlDatasetV2(dataset_ops.DatasetSource):
self._structure = structure.NestedStructure(
nest.map_structure(
lambda dtype: structure.TensorStructure(dtype, []), output_types))
def _as_variant_tensor(self):
return gen_experimental_dataset_ops.experimental_sql_dataset(
variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
self._driver_name, self._data_source_name, self._query,
nest.flatten(self.output_types), nest.flatten(self.output_shapes))
super(SqlDatasetV2, self).__init__(variant_tensor)
@property
def _element_structure(self):


@ -33,7 +33,6 @@ class _ScanDataset(dataset_ops.UnaryDataset):
def __init__(self, input_dataset, initial_state, scan_func):
"""See `scan()` for details."""
super(_ScanDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
with ops.name_scope("initial_state"):
@ -126,20 +125,18 @@ class _ScanDataset(dataset_ops.UnaryDataset):
self._scan_func = wrapped_func
self._scan_func.function.add_to_graph(ops.get_default_graph())
def _functions(self):
return [self._scan_func]
def _as_variant_tensor(self):
# pylint: disable=protected-access
input_t = self._input_dataset._as_variant_tensor()
return gen_experimental_dataset_ops.experimental_scan_dataset(
input_t,
variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset(
self._input_dataset._variant_tensor,
self._state_structure._to_tensor_list(self._initial_state),
self._scan_func.function.captured_inputs,
f=self._scan_func.function,
preserve_cardinality=True,
**dataset_ops.flat_structure(self))
super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
def _functions(self):
return [self._scan_func]
@property
def _element_structure(self):


@ -30,7 +30,6 @@ class _ShuffleAndRepeatDataset(dataset_ops.UnaryUnchangedStructureDataset):
"""A `Dataset` that fuses `shuffle` and `repeat`."""
def __init__(self, input_dataset, buffer_size, count=None, seed=None):
super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
@ -40,18 +39,15 @@ class _ShuffleAndRepeatDataset(dataset_ops.UnaryUnchangedStructureDataset):
self._count = ops.convert_to_tensor(
count, dtype=dtypes.int64, name="count")
self._seed, self._seed2 = random_seed.get_seed(seed)
def _as_variant_tensor(self):
# pylint: disable=protected-access
input_resource = self._input_dataset._as_variant_tensor()
return gen_dataset_ops.shuffle_and_repeat_dataset(
input_resource,
variant_tensor = gen_dataset_ops.shuffle_and_repeat_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
buffer_size=self._buffer_size,
count=self._count,
seed=self._seed,
seed2=self._seed2,
**dataset_ops.flat_structure(self))
# pylint: enable=protected-access
super(_ShuffleAndRepeatDataset, self).__init__(input_dataset,
variant_tensor)
@tf_export("data.experimental.shuffle_and_repeat")


@ -25,15 +25,13 @@ class _SleepDataset(dataset_ops.UnaryUnchangedStructureDataset):
"""A `Dataset` that sleeps before producing each upstream element."""
def __init__(self, input_dataset, sleep_microseconds):
super(_SleepDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._sleep_microseconds = sleep_microseconds
def _as_variant_tensor(self):
return gen_experimental_dataset_ops.experimental_sleep_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = gen_experimental_dataset_ops.experimental_sleep_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._sleep_microseconds,
**dataset_ops.flat_structure(self))
super(_SleepDataset, self).__init__(input_dataset, variant_tensor)
def sleep(sleep_microseconds):


@ -102,13 +102,11 @@ class _StatsDataset(dataset_ops.UnaryUnchangedStructureDataset):
"""A `Dataset` that acts as an identity, and also records statistics."""
def __init__(self, input_dataset, op_function, tag):
super(_StatsDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._op_function = op_function
self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string)
def _as_variant_tensor(self):
return self._op_function(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = self._op_function(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._tag,
**dataset_ops.flat_structure(self))
super(_StatsDataset, self).__init__(input_dataset, variant_tensor)


@ -64,15 +64,13 @@ class _ThreadPoolDataset(dataset_ops.UnaryUnchangedStructureDataset):
"""A `Dataset` that acts as an identity, and sets a custom threadpool."""
def __init__(self, input_dataset, thread_pool):
super(_ThreadPoolDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._thread_pool = thread_pool
def _as_variant_tensor(self):
return ged_ops.experimental_thread_pool_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = ged_ops.experimental_thread_pool_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._thread_pool._resource, # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
super(_ThreadPoolDataset, self).__init__(input_dataset, variant_tensor)
# TODO(b/73383364): Properly export in the `tf.data.experimental` API when


@ -53,15 +53,13 @@ class _UniqueDataset(dataset_ops.UnaryUnchangedStructureDataset):
def __init__(self, input_dataset):
"""See `unique()` for details."""
super(_UniqueDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
dtypes.string):
raise TypeError(
"`tf.data.experimental.unique()` only supports inputs with a single "
"`tf.int32`, `tf.int64`, or `tf.string` component.")
def _as_variant_tensor(self):
return gen_experimental_dataset_ops.experimental_unique_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = gen_experimental_dataset_ops.experimental_unique_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
super(_UniqueDataset, self).__init__(input_dataset, variant_tensor)


@ -57,4 +57,4 @@ class TFRecordWriter(object):
"produces shape {0} and types {1}".format(dataset.output_shapes,
dataset.output_types))
return gen_experimental_dataset_ops.experimental_dataset_to_tf_record(
dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access
dataset._variant_tensor, self._filename, self._compression_type) # pylint: disable=protected-access

View File

@ -91,9 +91,9 @@ class BatchTest(test_base.DatasetTestBase, parameterized.TestCase):
result = self.evaluate(get_next())
def testBatchDatasetInvalidBatchSize(self):
dataset = (dataset_ops.Dataset.range(10).batch(0))
self.assertDatasetProduces(
dataset, expected_error=(errors.InvalidArgumentError, ''))
with self.assertRaises(errors.InvalidArgumentError):
dataset = (dataset_ops.Dataset.range(10).batch(0))
self.evaluate(dataset._variant_tensor)
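Because the batch kernel is now created when `.batch(0)` is called, the invalid batch size surfaces as soon as the dataset's variant tensor is evaluated rather than when an iterator is consumed; that is the pattern the rewritten test above exercises. A minimal standalone sketch of the same idea, assuming TF 1.x graph mode (the session boilerplate is illustrative and not part of this change):

import tensorflow as tf

dataset = tf.data.Dataset.range(10).batch(0)  # the batch kernel is built here now
with tf.Session() as sess:
  try:
    # Running the variant tensor executes the dataset op and reports the error.
    sess.run(dataset._variant_tensor)  # pylint: disable=protected-access
  except tf.errors.InvalidArgumentError as err:
    print("invalid batch size:", err)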
def testBatchSparse(self):

View File

@ -139,8 +139,8 @@ class FileCacheTest(test_base.DatasetTestBase):
self.evaluate(get_next1())
# Re-initialize
get_next1 = self.getNext(cache_dataset1)
get_next2 = self.getNext(cache_dataset2)
get_next1 = self.getNext(cache_dataset1, requires_initialization=True)
get_next2 = self.getNext(cache_dataset2, requires_initialization=True)
# Reading concurrently should succeed.
elements_itr1 = []

View File

@ -272,12 +272,8 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def testSkipEagerSameGraphErrorOneShot(self):
dataset = dataset_ops.Dataset.range(10)
with ops.Graph().as_default():
dataset = dataset.batch(2)
with test.mock.patch.object(logging, "warning") as mock_log:
_ = dataset.make_one_shot_iterator()
self.assertRegexpMatches(
str(mock_log.call_args), "Please ensure that all datasets in the "
"pipeline are created in the same graph as the iterator.")
with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
dataset = dataset.batch(2)
@test_util.run_deprecated_v1
def testSkipEagerSameGraphErrorOneShotSimple(self):
@ -293,9 +289,8 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def testSkipEagerSameGraphErrorInitializable(self):
dataset = dataset_ops.Dataset.range(10)
with ops.Graph().as_default():
dataset = dataset.batch(2)
with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
_ = dataset.make_initializable_iterator()
dataset = dataset.batch(2)
if __name__ == "__main__":

View File

@ -36,9 +36,10 @@ class PrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.parameters((-2), (-42))
def testInvalidBufferSize(self, buffer_size):
dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size)
self.assertDatasetProduces(
dataset, expected_error=(errors.InvalidArgumentError, "buffer_size"))
with self.assertRaises(errors.InvalidArgumentError):
dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size)
self.evaluate(dataset._variant_tensor)
if __name__ == "__main__":
test.main()

View File

@ -43,9 +43,9 @@ class RangeTest(test_base.DatasetTestBase):
def testZeroStep(self):
start, stop, step = 2, 10, 0
dataset = dataset_ops.Dataset.range(start, stop, step)
self.assertDatasetProduces(
dataset, expected_error=(errors.InvalidArgumentError, ""))
with self.assertRaises(errors.InvalidArgumentError):
dataset = dataset_ops.Dataset.range(start, stop, step)
self.evaluate(dataset._variant_tensor)
def testNegativeStep(self):
start, stop, step = 2, 10, -1

View File

@ -116,12 +116,11 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
("3", 14, 3, 3, 0),
)
def testWindowDatasetInvalid(self, count, size, shift, stride):
dataset = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(
count).window(
size=size, shift=shift,
stride=stride).flat_map(lambda x: x.batch(batch_size=size))
self.assertDatasetProduces(
dataset, expected_error=(errors.InvalidArgumentError, ""))
with self.assertRaises(errors.InvalidArgumentError):
ds = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count).window(
size=size, shift=shift,
stride=stride).flat_map(lambda x: x.batch(batch_size=size))
self.evaluate(ds._variant_tensor)
def testWindowSparse(self):

View File

@ -35,6 +35,7 @@ py_library(
"//tensorflow/python/data/util:random_seed",
"//tensorflow/python/data/util:sparse",
"//tensorflow/python/data/util:structure",
"//tensorflow/python/data/util:traverse",
"//third_party/py/numpy",
],
)

View File

@ -38,6 +38,7 @@ from tensorflow.python.data.util import options as options_lib
from tensorflow.python.data.util import random_seed
from tensorflow.python.data.util import sparse
from tensorflow.python.data.util import structure as structure_lib
from tensorflow.python.data.util import traverse
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -75,9 +76,27 @@ class DatasetV2(object):
plan" of transformations that act on those elements.
"""
def __init__(self):
def __init__(self, variant_tensor):
"""Creates a DatasetV2 object.
This is a difference between DatasetV1 and DatasetV2: DatasetV1 does not
take anything in its constructor, whereas DatasetV2 expects subclasses to
create a variant_tensor and pass it to the super() call.
Args:
variant_tensor: A DT_VARIANT tensor that represents the dataset.
"""
self._dataset_variant_tensor = variant_tensor
self._graph_attr = ops.get_default_graph()
@property
def _variant_tensor(self):
return self._dataset_variant_tensor
@_variant_tensor.setter
def _variant_tensor(self, _):
raise ValueError("The _variant_tensor property is read-only")
def _as_serialized_graph(self):
"""Produces serialized graph representation of the dataset.
@ -85,16 +104,7 @@ class DatasetV2(object):
A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
serialized graph.
"""
return gen_dataset_ops.dataset_to_graph(self._as_variant_tensor())
@abc.abstractmethod
def _as_variant_tensor(self):
"""Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.
Returns:
A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset.
"""
raise NotImplementedError("Dataset._as_variant_tensor")
return gen_dataset_ops.dataset_to_graph(self._variant_tensor)
@abc.abstractmethod
def _inputs(self):
@ -1279,7 +1289,7 @@ class DatasetV2(object):
# pylint: disable=protected-access
return state_structure._from_compatible_tensor_list(
gen_dataset_ops.reduce_dataset(
self._as_variant_tensor(),
self._variant_tensor,
state_structure._to_tensor_list(initial_state),
reduce_func.captured_inputs,
f=reduce_func,
@ -1314,8 +1324,31 @@ class DatasetV1(DatasetV2):
plan" of transformations that act on those elements.
"""
def __init__(self): # pylint: disable=useless-super-delegation
super(DatasetV1, self).__init__()
def __init__(self):
try:
variant_tensor = self._as_variant_tensor()
except AttributeError as e:
if "_as_variant_tensor" in str(e):
raise AttributeError("Please use _variant_tensor instead of "
"_as_variant_tensor() to obtain the variant "
"associated with a dataset")
raise AttributeError("A likely cause of this error is that the super "
"call for this dataset is not the last line of the "
"__init__ method. The base class causes the "
"_as_variant_tensor call in its constructor and "
"if that uses attributes defined in the __init__ "
"method, those attrs need to be defined before the "
"super call.")
super(DatasetV1, self).__init__(variant_tensor)
@abc.abstractmethod
def _as_variant_tensor(self):
"""Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.
Returns:
A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset.
"""
raise NotImplementedError("Dataset._as_variant_tensor")
@deprecation.deprecated(
None, "Use `for ... in dataset:` to iterate over a dataset. If using "
@ -1335,11 +1368,19 @@ class DatasetV1(DatasetV2):
return iterator_ops.EagerIterator(self)
_ensure_same_dataset_graph(self)
# Now that we create datasets at Python object creation time, the
# capture-by-value _make_dataset() function would try to capture these
# variant-tensor dataset inputs, which are marked as stateful ops and would
# raise an error if we tried to capture them. We therefore traverse the
# graph to find all such ops and whitelist them, so that the capturing
# logic recreates these ops (as it did before) instead of raising an error.
all_ds_ops = traverse.obtain_all_variant_tensor_ops(self)
graph_level_seed, op_level_seed = core_random_seed.get_seed(None)
# NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
# a 0-argument function.
@function.Defun(capture_by_value=True)
@function.Defun(capture_by_value=True, whitelisted_stateful_ops=all_ds_ops)
def _make_dataset():
"""Factory function for a dataset."""
# NOTE(mrry): `Defun` does not capture the graph-level seed from the
@ -1351,7 +1392,7 @@ class DatasetV1(DatasetV2):
(graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1))
dataset = self._apply_options()
return dataset._as_variant_tensor() # pylint: disable=protected-access
return dataset._variant_tensor # pylint: disable=protected-access
try:
_make_dataset.add_to_graph(ops.get_default_graph())
@ -1416,7 +1457,7 @@ class DatasetV1(DatasetV2):
container="", shared_name=shared_name, **flat_structure(self))
with ops.colocate_with(iterator_resource):
initializer = gen_dataset_ops.make_iterator(
dataset._as_variant_tensor(), # pylint: disable=protected-access
dataset._variant_tensor, # pylint: disable=protected-access
iterator_resource)
return iterator_ops.Iterator(iterator_resource, initializer,
dataset.output_types, dataset.output_shapes,
@ -1621,11 +1662,11 @@ class DatasetV1Adapter(DatasetV1):
"""Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API."""
def __init__(self, dataset):
super(DatasetV1Adapter, self).__init__()
self._dataset = dataset
super(DatasetV1Adapter, self).__init__()
def _as_variant_tensor(self):
return self._dataset._as_variant_tensor() # pylint: disable=protected-access
return self._dataset._variant_tensor # pylint: disable=protected-access
def _has_captured_ref(self):
return self._dataset._has_captured_ref() # pylint: disable=protected-access
@ -1657,14 +1698,14 @@ def _ensure_same_dataset_graph(dataset):
if current_graph != ds_graph:
logging.warning("The graph (" + str(current_graph) + ") of the iterator "
"is different from the graph (" + str(ds_graph) + ") "
"the dataset: " + str(ds) + " was created in. "
"If you are using the Estimator API, make sure that no "
"part of the dataset returned by the `input_fn` function "
"is defined outside the `input_fn` function."
"Please ensure that all datasets in the pipeline are "
"created in the same graph as the iterator. NOTE: This "
"warning will become an error in future versions of "
"TensorFlow.")
"the dataset: " + str(ds._variant_tensor) + " was " # pylint: disable=protected-access
"created in. If you are using the Estimator API, "
"make sure that no part of the dataset returned by the "
"`input_fn` function is defined outside the `input_fn` "
"function. Please ensure that all datasets in the "
"pipeline are created in the same graph as the iterator. "
"NOTE: This warning will become an error in future "
"versions of TensorFlow.")
for input_ds in ds._inputs(): # pylint: disable=protected-access
if input_ds not in visited:
bfs_q.put(input_ds)
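As the warning text above says, with kernels now created at dataset construction time, a pipeline used with the Estimator API has to be built entirely inside the `input_fn`; a dataset created in another graph (for example the default graph) can no longer be stitched in when the iterator is made. A hedged sketch of the self-contained pattern, assuming TF 1.x (`my_model_fn` and the pipeline contents are placeholders):

import tensorflow as tf

def input_fn():
  # The whole pipeline lives inside input_fn, so every dataset kernel is
  # created in the graph the Estimator builds for this call.
  dataset = tf.data.Dataset.range(100)  # placeholder for a real data source
  dataset = dataset.map(lambda x: ({"x": tf.cast(x, tf.float32)}, x % 2))
  return dataset.repeat().batch(32)

estimator = tf.estimator.Estimator(model_fn=my_model_fn)  # my_model_fn assumed
estimator.train(input_fn=input_fn, steps=10)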
@ -1820,9 +1861,9 @@ class DatasetSource(DatasetV2):
class UnaryDataset(DatasetV2):
"""Abstract class representing a dataset with one input."""
def __init__(self, input_dataset):
super(UnaryDataset, self).__init__()
def __init__(self, input_dataset, variant_tensor):
self._input_dataset = input_dataset
super(UnaryDataset, self).__init__(variant_tensor)
def _inputs(self):
return [self._input_dataset]
@ -1831,6 +1872,11 @@ class UnaryDataset(DatasetV2):
class UnaryUnchangedStructureDataset(UnaryDataset):
"""Represents a unary dataset with the same input and output structure."""
def __init__(self, input_dataset, variant_tensor):
self._input_dataset = input_dataset
super(UnaryUnchangedStructureDataset, self).__init__(
input_dataset, variant_tensor)
@property
def _element_structure(self):
return self._input_dataset._element_structure # pylint: disable=protected-access
@ -1841,7 +1887,6 @@ class TensorDataset(DatasetSource):
def __init__(self, tensors):
"""See `Dataset.from_tensors()` for details."""
super(TensorDataset, self).__init__()
with ops.name_scope("tensors"):
tensors = nest.pack_sequence_as(tensors, [
sparse_tensor_lib.SparseTensor.from_value(t)
@ -1852,9 +1897,9 @@ class TensorDataset(DatasetSource):
self._structure = structure_lib.Structure.from_value(tensors)
self._tensors = self._structure._to_tensor_list(tensors) # pylint: disable=protected-access
def _as_variant_tensor(self):
return gen_dataset_ops.tensor_dataset(
variant_tensor = gen_dataset_ops.tensor_dataset(
self._tensors, output_shapes=self._structure._flat_shapes) # pylint: disable=protected-access
super(TensorDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
@ -1866,7 +1911,6 @@ class TensorSliceDataset(DatasetSource):
def __init__(self, tensors):
"""See `Dataset.from_tensor_slices()` for details."""
super(TensorSliceDataset, self).__init__()
with ops.name_scope("tensors"):
tensors = nest.pack_sequence_as(tensors, [
sparse_tensor_lib.SparseTensor.from_value(t)
@ -1887,9 +1931,9 @@ class TensorSliceDataset(DatasetSource):
batch_dim.assert_is_compatible_with(tensor_shape.Dimension(
tensor_shape.dimension_value(t.get_shape()[0])))
def _as_variant_tensor(self):
return gen_dataset_ops.tensor_slice_dataset(
variant_tensor = gen_dataset_ops.tensor_slice_dataset(
self._tensors, output_shapes=self._structure._flat_shapes) # pylint: disable=protected-access
super(TensorSliceDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
@ -1901,7 +1945,6 @@ class SparseTensorSliceDataset(DatasetSource):
def __init__(self, sparse_tensor):
"""See `Dataset.from_sparse_tensor_slices()` for details."""
super(SparseTensorSliceDataset, self).__init__()
if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor):
raise TypeError("`sparse_tensor` must be a `tf.SparseTensor` object.")
self._sparse_tensor = sparse_tensor
@ -1914,10 +1957,10 @@ class SparseTensorSliceDataset(DatasetSource):
structure_lib.TensorStructure(self._sparse_tensor.dtype, [None]),
structure_lib.TensorStructure(dtypes.int64, [rank])))
def _as_variant_tensor(self):
return gen_dataset_ops.sparse_tensor_slice_dataset(
variant_tensor = gen_dataset_ops.sparse_tensor_slice_dataset(
self._sparse_tensor.indices, self._sparse_tensor.values,
self._sparse_tensor.dense_shape)
super(SparseTensorSliceDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
@ -1928,12 +1971,8 @@ class _VariantDataset(DatasetV2):
"""A Dataset wrapper around a `tf.variant`-typed function argument."""
def __init__(self, dataset_variant, structure):
super(_VariantDataset, self).__init__()
self._dataset_variant = dataset_variant
self._structure = structure
def _as_variant_tensor(self):
return self._dataset_variant
super(_VariantDataset, self).__init__(dataset_variant)
def _inputs(self):
return []
@ -1965,7 +2004,7 @@ class DatasetStructure(structure_lib.Structure):
other._element_structure))
def _to_tensor_list(self, value):
return [value._as_variant_tensor()] # pylint: disable=protected-access
return [value._variant_tensor] # pylint: disable=protected-access
def _to_batched_tensor_list(self, value):
raise NotImplementedError("Unbatching for `tf.data.Dataset` objects.")
@ -2153,7 +2192,7 @@ def flat_structure(dataset):
Most Dataset op constructors expect `output_shapes` and `output_types`
arguments that represent the flattened structure of an element. This helper
function generates these attrs as a keyword argument dictionary, allowing
`Dataset._as_variant_tensor()` implementations to pass
`Dataset._variant_tensor` implementations to pass
`**flat_structure(self)` to the op constructor.
Args:
@ -2189,7 +2228,6 @@ class _GeneratorDataset(DatasetSource):
`init_func` immediately before a C++ iterator over this dataset is
destroyed. The return value is ignored.
"""
super(_GeneratorDataset, self).__init__()
self._init_args = init_args
self._init_structure = structure_lib.Structure.from_value(init_args)
@ -2208,9 +2246,7 @@ class _GeneratorDataset(DatasetSource):
finalize_func,
self._transformation_name(),
input_structure=self._init_func.output_structure)
def _as_variant_tensor(self):
return gen_dataset_ops.generator_dataset(
variant_tensor = gen_dataset_ops.generator_dataset(
self._init_structure._to_tensor_list(self._init_args) # pylint: disable=protected-access
+ self._init_func.function.captured_inputs,
self._next_func.function.captured_inputs,
@ -2219,6 +2255,7 @@ class _GeneratorDataset(DatasetSource):
next_func=self._next_func.function,
finalize_func=self._finalize_func.function,
**flat_structure(self))
super(_GeneratorDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
@ -2233,7 +2270,6 @@ class ZipDataset(DatasetV2):
def __init__(self, datasets):
"""See `Dataset.zip()` for details."""
super(ZipDataset, self).__init__()
for ds in nest.flatten(datasets):
if not isinstance(ds, DatasetV2):
if isinstance(ds, list):
@ -2250,12 +2286,12 @@ class ZipDataset(DatasetV2):
self._datasets,
[ds._element_structure for ds in nest.flatten(self._datasets)])) # pylint: disable=protected-access
def _as_variant_tensor(self):
# pylint: disable=protected-access
return gen_dataset_ops.zip_dataset(
[ds._as_variant_tensor() for ds in nest.flatten(self._datasets)],
variant_tensor = gen_dataset_ops.zip_dataset(
[ds._variant_tensor for ds in nest.flatten(self._datasets)],
**flat_structure(self))
# pylint: enable=protected-access
super(ZipDataset, self).__init__(variant_tensor)
def _inputs(self):
return nest.flatten(self._datasets)
@ -2270,7 +2306,6 @@ class ConcatenateDataset(DatasetV2):
def __init__(self, input_dataset, dataset_to_concatenate):
"""See `Dataset.concatenate()` for details."""
super(ConcatenateDataset, self).__init__()
self._input_dataset = input_dataset
self._dataset_to_concatenate = dataset_to_concatenate
@ -2298,17 +2333,15 @@ class ConcatenateDataset(DatasetV2):
output_types, output_shapes, output_classes)
self._input_datasets = [input_dataset, dataset_to_concatenate]
def _as_variant_tensor(self):
# pylint: disable=protected-access
return gen_dataset_ops.concatenate_dataset(
self._input_dataset._as_variant_tensor(),
self._dataset_to_concatenate._as_variant_tensor(),
variant_tensor = gen_dataset_ops.concatenate_dataset(
input_dataset._variant_tensor, dataset_to_concatenate._variant_tensor,
**flat_structure(self))
# pylint: enable=protected-access
super(ConcatenateDataset, self).__init__(variant_tensor)
def _inputs(self):
return [self._input_dataset, self._dataset_to_concatenate]
return self._input_datasets
@property
def _element_structure(self):
@ -2320,19 +2353,17 @@ class RepeatDataset(UnaryUnchangedStructureDataset):
def __init__(self, input_dataset, count):
"""See `Dataset.repeat()` for details."""
super(RepeatDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if count is None:
self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
else:
self._count = ops.convert_to_tensor(
count, dtype=dtypes.int64, name="count")
def _as_variant_tensor(self):
return gen_dataset_ops.repeat_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.repeat_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
count=self._count,
**flat_structure(self))
super(RepeatDataset, self).__init__(input_dataset, variant_tensor)
class RangeDataset(DatasetSource):
@ -2340,8 +2371,13 @@ class RangeDataset(DatasetSource):
def __init__(self, *args):
"""See `Dataset.range()` for details."""
super(RangeDataset, self).__init__()
self._parse_args(*args)
variant_tensor = gen_dataset_ops.range_dataset(
start=self._start,
stop=self._stop,
step=self._step,
**flat_structure(self))
super(RangeDataset, self).__init__(variant_tensor)
def _parse_args(self, *args):
"""Parse arguments according to the same rules as the `range()` builtin."""
@ -2363,13 +2399,6 @@ class RangeDataset(DatasetSource):
def _build_tensor(self, int64_value, name):
return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
def _as_variant_tensor(self):
return gen_dataset_ops.range_dataset(
start=self._start,
stop=self._stop,
step=self._step,
**flat_structure(self))
@property
def _element_structure(self):
return structure_lib.TensorStructure(dtypes.int64, [])
@ -2380,16 +2409,14 @@ class CacheDataset(UnaryUnchangedStructureDataset):
def __init__(self, input_dataset, filename):
"""See `Dataset.cache()` for details."""
super(CacheDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._filename = ops.convert_to_tensor(
filename, dtype=dtypes.string, name="filename")
def _as_variant_tensor(self):
return gen_dataset_ops.cache_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.cache_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
filename=self._filename,
**flat_structure(self))
super(CacheDataset, self).__init__(input_dataset, variant_tensor)
class ShuffleDataset(UnaryUnchangedStructureDataset):
@ -2420,7 +2447,6 @@ class ShuffleDataset(UnaryUnchangedStructureDataset):
Raises:
ValueError: if invalid arguments are provided.
"""
super(ShuffleDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
@ -2430,15 +2456,14 @@ class ShuffleDataset(UnaryUnchangedStructureDataset):
self._reshuffle_each_iteration = True
else:
self._reshuffle_each_iteration = reshuffle_each_iteration
def _as_variant_tensor(self):
return gen_dataset_ops.shuffle_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.shuffle_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
buffer_size=self._buffer_size,
seed=self._seed,
seed2=self._seed2,
reshuffle_each_iteration=self._reshuffle_each_iteration,
**flat_structure(self))
super(ShuffleDataset, self).__init__(input_dataset, variant_tensor)
class TakeDataset(UnaryUnchangedStructureDataset):
@ -2446,15 +2471,13 @@ class TakeDataset(UnaryUnchangedStructureDataset):
def __init__(self, input_dataset, count):
"""See `Dataset.take()` for details."""
super(TakeDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
def _as_variant_tensor(self):
return gen_dataset_ops.take_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.take_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
count=self._count,
**flat_structure(self))
super(TakeDataset, self).__init__(input_dataset, variant_tensor)
class SkipDataset(UnaryUnchangedStructureDataset):
@ -2462,15 +2485,13 @@ class SkipDataset(UnaryUnchangedStructureDataset):
def __init__(self, input_dataset, count):
"""See `Dataset.skip()` for details."""
super(SkipDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
def _as_variant_tensor(self):
return gen_dataset_ops.skip_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.skip_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
count=self._count,
**flat_structure(self))
super(SkipDataset, self).__init__(input_dataset, variant_tensor)
class BatchDataset(UnaryDataset):
@ -2478,7 +2499,6 @@ class BatchDataset(UnaryDataset):
def __init__(self, input_dataset, batch_size, drop_remainder):
"""See `Dataset.batch()` for details."""
super(BatchDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
@ -2494,13 +2514,12 @@ class BatchDataset(UnaryDataset):
tensor_util.constant_value(self._batch_size))
else:
self._structure = input_dataset._element_structure._batch(None)
def _as_variant_tensor(self):
return gen_dataset_ops.batch_dataset_v2(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.batch_dataset_v2(
input_dataset._variant_tensor, # pylint: disable=protected-access
batch_size=self._batch_size,
drop_remainder=self._drop_remainder,
**flat_structure(self))
super(BatchDataset, self).__init__(input_dataset, variant_tensor)
@property
def _element_structure(self):
@ -2622,7 +2641,7 @@ class PaddedBatchDataset(UnaryDataset):
def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
drop_remainder):
"""See `Dataset.batch()` for details."""
super(PaddedBatchDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if sparse.any_sparse(input_dataset.output_classes):
# TODO(b/63669786): support batching of sparse tensors
raise TypeError(
@ -2665,12 +2684,11 @@ class PaddedBatchDataset(UnaryDataset):
self._input_dataset.output_types, output_shapes,
self._input_dataset.output_classes)
def _as_variant_tensor(self):
# pylint: disable=protected-access
# TODO(jsimsa): Switch to using v2 only any time after 6/30/2018.
if smart_cond.smart_constant_value(self._drop_remainder) is False:
return gen_dataset_ops.padded_batch_dataset(
self._input_dataset._as_variant_tensor(),
variant_tensor = gen_dataset_ops.padded_batch_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
batch_size=self._batch_size,
padded_shapes=[
ops.convert_to_tensor(s, dtype=dtypes.int64)
@ -2679,8 +2697,8 @@ class PaddedBatchDataset(UnaryDataset):
padding_values=nest.flatten(self._padding_values),
output_shapes=self._structure._flat_shapes)
else:
return gen_dataset_ops.padded_batch_dataset_v2(
self._input_dataset._as_variant_tensor(),
variant_tensor = gen_dataset_ops.padded_batch_dataset_v2(
input_dataset._variant_tensor, # pylint: disable=protected-access
batch_size=self._batch_size,
padded_shapes=[
ops.convert_to_tensor(s, dtype=dtypes.int64)
@ -2689,6 +2707,7 @@ class PaddedBatchDataset(UnaryDataset):
padding_values=nest.flatten(self._padding_values),
drop_remainder=self._drop_remainder,
output_shapes=self._structure._flat_shapes)
super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor)
@property
def _element_structure(self):
@ -2727,22 +2746,19 @@ class MapDataset(UnaryDataset):
use_inter_op_parallelism=True,
preserve_cardinality=False):
"""See `Dataset.map()` for details."""
super(MapDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._use_inter_op_parallelism = use_inter_op_parallelism
self._preserve_cardinality = preserve_cardinality
self._map_func = StructuredFunctionWrapper(
map_func, self._transformation_name(), dataset=input_dataset)
def _as_variant_tensor(self):
input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
return gen_dataset_ops.map_dataset(
input_t,
variant_tensor = gen_dataset_ops.map_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
f=self._map_func.function,
use_inter_op_parallelism=self._use_inter_op_parallelism,
preserve_cardinality=self._preserve_cardinality,
**flat_structure(self))
super(MapDataset, self).__init__(input_dataset, variant_tensor)
def _functions(self):
return [self._map_func]
@ -2755,7 +2771,7 @@ class MapDataset(UnaryDataset):
return "Dataset.map()"
class ParallelMapDataset(MapDataset):
class ParallelMapDataset(UnaryDataset):
"""A `Dataset` that maps a function over elements in its input in parallel."""
def __init__(self,
@ -2765,23 +2781,32 @@ class ParallelMapDataset(MapDataset):
use_inter_op_parallelism=True,
preserve_cardinality=False):
"""See `Dataset.map()` for details."""
super(ParallelMapDataset, self).__init__(
input_dataset, map_func, use_inter_op_parallelism, preserve_cardinality)
self._input_dataset = input_dataset
self._use_inter_op_parallelism = use_inter_op_parallelism
self._map_func = StructuredFunctionWrapper(
map_func, self._transformation_name(), dataset=input_dataset)
self._num_parallel_calls = ops.convert_to_tensor(
num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")
def _as_variant_tensor(self):
# pylint: disable=protected-access
input_t = self._input_dataset._as_variant_tensor()
return gen_dataset_ops.parallel_map_dataset(
input_t,
self._preserve_cardinality = preserve_cardinality
variant_tensor = gen_dataset_ops.parallel_map_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
f=self._map_func.function,
num_parallel_calls=self._num_parallel_calls,
use_inter_op_parallelism=self._use_inter_op_parallelism,
preserve_cardinality=self._preserve_cardinality,
**flat_structure(self))
super(ParallelMapDataset, self).__init__(input_dataset, variant_tensor)
def _functions(self):
return [self._map_func]
@property
def _element_structure(self):
return self._map_func.output_structure
def _transformation_name(self):
return "Dataset.map()"
class FlatMapDataset(UnaryDataset):
@ -2789,24 +2814,21 @@ class FlatMapDataset(UnaryDataset):
def __init__(self, input_dataset, map_func):
"""See `Dataset.flat_map()` for details."""
super(FlatMapDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._map_func = StructuredFunctionWrapper(
map_func, self._transformation_name(), dataset=input_dataset)
if not isinstance(self._map_func.output_structure, DatasetStructure):
raise TypeError("`map_func` must return a `Dataset` object.")
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
def _functions(self):
return [self._map_func]
def _as_variant_tensor(self):
return gen_dataset_ops.flat_map_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.flat_map_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
f=self._map_func.function,
**flat_structure(self))
super(FlatMapDataset, self).__init__(input_dataset, variant_tensor)
def _functions(self):
return [self._map_func]
@property
def _element_structure(self):
@ -2816,58 +2838,79 @@ class FlatMapDataset(UnaryDataset):
return "Dataset.flat_map()"
class InterleaveDataset(FlatMapDataset):
class InterleaveDataset(UnaryDataset):
"""A `Dataset` that maps a function over its input and interleaves the result.
"""
def __init__(self, input_dataset, map_func, cycle_length, block_length):
"""See `Dataset.interleave()` for details."""
super(InterleaveDataset, self).__init__(input_dataset, map_func)
self._input_dataset = input_dataset
self._map_func = StructuredFunctionWrapper(
map_func, self._transformation_name(), dataset=input_dataset)
if not isinstance(self._map_func.output_structure, DatasetStructure):
raise TypeError("`map_func` must return a `Dataset` object.")
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
self._cycle_length = ops.convert_to_tensor(
cycle_length, dtype=dtypes.int64, name="cycle_length")
self._block_length = ops.convert_to_tensor(
block_length, dtype=dtypes.int64, name="block_length")
def _as_variant_tensor(self):
# pylint: disable=protected-access
return gen_dataset_ops.interleave_dataset(
self._input_dataset._as_variant_tensor(),
self._map_func.function.captured_inputs,
variant_tensor = gen_dataset_ops.interleave_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs, # pylint: disable=protected-access
self._cycle_length,
self._block_length,
f=self._map_func.function,
**flat_structure(self))
super(InterleaveDataset, self).__init__(input_dataset, variant_tensor)
def _functions(self):
return [self._map_func]
@property
def _element_structure(self):
return self._structure
def _transformation_name(self):
return "Dataset.interleave()"
class ParallelInterleaveDataset(FlatMapDataset):
class ParallelInterleaveDataset(UnaryDataset):
"""A `Dataset` that maps a function over its input and interleaves the result.
"""
def __init__(self, input_dataset, map_func, cycle_length, block_length,
num_parallel_calls):
"""See `Dataset.interleave()` for details."""
super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func)
self._input_dataset = input_dataset
self._map_func = StructuredFunctionWrapper(
map_func, self._transformation_name(), dataset=input_dataset)
if not isinstance(self._map_func.output_structure, DatasetStructure):
raise TypeError("`map_func` must return a `Dataset` object.")
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
self._cycle_length = ops.convert_to_tensor(
cycle_length, dtype=dtypes.int64, name="cycle_length")
self._block_length = ops.convert_to_tensor(
block_length, dtype=dtypes.int64, name="block_length")
self._num_parallel_calls = ops.convert_to_tensor(
num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
def _as_variant_tensor(self):
# pylint: disable=protected-access
return gen_dataset_ops.parallel_interleave_dataset_v2(
self._input_dataset._as_variant_tensor(),
self._map_func.function.captured_inputs,
variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v2(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs, # pylint: disable=protected-access
self._cycle_length,
self._block_length,
self._num_parallel_calls,
f=self._map_func.function,
**flat_structure(self))
super(ParallelInterleaveDataset, self).__init__(input_dataset,
variant_tensor)
def _functions(self):
return [self._map_func]
@property
def _element_structure(self):
return self._structure
def _transformation_name(self):
return "Dataset.interleave()"
@ -2878,7 +2921,6 @@ class FilterDataset(UnaryUnchangedStructureDataset):
def __init__(self, input_dataset, predicate):
"""See `Dataset.filter()` for details."""
super(FilterDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
wrapped_func = StructuredFunctionWrapper(
predicate, self._transformation_name(), dataset=input_dataset)
@ -2886,16 +2928,15 @@ class FilterDataset(UnaryUnchangedStructureDataset):
structure_lib.TensorStructure(dtypes.bool, [])):
raise ValueError("`predicate` must return a scalar boolean tensor.")
self._predicate = wrapped_func
def _functions(self):
return [self._predicate]
def _as_variant_tensor(self):
return gen_dataset_ops.filter_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.filter_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
other_arguments=self._predicate.function.captured_inputs,
predicate=self._predicate.function,
**flat_structure(self))
super(FilterDataset, self).__init__(input_dataset, variant_tensor)
def _functions(self):
return [self._predicate]
def _transformation_name(self):
return "Dataset.filter()"
@ -2906,18 +2947,16 @@ class PrefetchDataset(UnaryUnchangedStructureDataset):
def __init__(self, input_dataset, buffer_size):
"""See `Dataset.prefetch()` for details."""
super(PrefetchDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if buffer_size is None:
buffer_size = -1 # This is the sentinel for auto-tuning.
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
def _as_variant_tensor(self):
return gen_dataset_ops.prefetch_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.prefetch_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
buffer_size=self._buffer_size,
**flat_structure(self))
super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)
class WindowDataset(UnaryDataset):
@ -2925,7 +2964,6 @@ class WindowDataset(UnaryDataset):
def __init__(self, input_dataset, size, shift, stride, drop_remainder):
"""See `window_dataset()` for more details."""
super(WindowDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
@ -2944,15 +2982,14 @@ class WindowDataset(UnaryDataset):
nest.flatten(input_dataset.output_types))
])
self._structure = structure_lib.NestedStructure(nest_of_structures)
def _as_variant_tensor(self):
return gen_dataset_ops.window_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.window_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._size,
self._shift,
self._stride,
self._drop_remainder,
**flat_structure(self))
super(WindowDataset, self).__init__(input_dataset, variant_tensor)
@property
def _element_structure(self):
@ -2963,16 +3000,14 @@ class _OptionsDataset(UnaryUnchangedStructureDataset):
"""An identity `Dataset` that stores options."""
def __init__(self, input_dataset, options):
super(_OptionsDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._options = input_dataset.options()
if self._options:
self._options = self._options.merge(options)
else:
self._options = options
def _as_variant_tensor(self):
return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
variant_tensor = input_dataset._variant_tensor # pylint: disable=protected-access
super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
def options(self):
return self._options
@ -2983,13 +3018,11 @@ class _ModelDataset(UnaryUnchangedStructureDataset):
def __init__(self, input_dataset):
"""See `optimize()` for details."""
super(_ModelDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
def _as_variant_tensor(self):
return gen_dataset_ops.model_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.model_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
**flat_structure(self))
super(_ModelDataset, self).__init__(input_dataset, variant_tensor)
class _OptimizeDataset(UnaryUnchangedStructureDataset):
@ -2997,68 +3030,63 @@ class _OptimizeDataset(UnaryUnchangedStructureDataset):
def __init__(self, input_dataset, optimizations):
"""See `optimize()` for details."""
super(_OptimizeDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if optimizations is None:
optimizations = []
self._optimizations = ops.convert_to_tensor(
optimizations, dtype=dtypes.string, name="optimizations")
def _as_variant_tensor(self):
return gen_dataset_ops.optimize_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.optimize_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._optimizations,
**flat_structure(self))
super(_OptimizeDataset, self).__init__(input_dataset, variant_tensor)
class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset):
"""A `Dataset` that acts as an identity, and sets a stats aggregator."""
def __init__(self, input_dataset, aggregator, prefix, counter_prefix):
super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._stats_aggregator = aggregator
self._prefix = prefix
self._counter_prefix = counter_prefix
def _as_variant_tensor(self):
return ged_ops.experimental_set_stats_aggregator_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = ged_ops.experimental_set_stats_aggregator_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._stats_aggregator._resource, # pylint: disable=protected-access
self._prefix,
self._counter_prefix,
**flat_structure(self))
super(_SetStatsAggregatorDataset, self).__init__(input_dataset,
variant_tensor)
class _MaxIntraOpParallelismDataset(UnaryUnchangedStructureDataset):
"""A `Dataset` that acts as an identity, overriding intra-op parallelism."""
def __init__(self, input_dataset, max_intra_op_parallelism):
super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._max_intra_op_parallelism = ops.convert_to_tensor(
max_intra_op_parallelism,
dtype=dtypes.int64,
name="max_intra_op_parallelism")
def _as_variant_tensor(self):
return ged_ops.experimental_max_intra_op_parallelism_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = ged_ops.experimental_max_intra_op_parallelism_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._max_intra_op_parallelism,
**flat_structure(self))
super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset,
variant_tensor)
class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
"""A `Dataset` that acts as an identity, setting a private threadpool."""
def __init__(self, input_dataset, num_threads):
super(_PrivateThreadPoolDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._num_threads = ops.convert_to_tensor(
num_threads, dtype=dtypes.int64, name="num_threads")
def _as_variant_tensor(self):
return ged_ops.experimental_private_thread_pool_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
variant_tensor = ged_ops.experimental_private_thread_pool_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._num_threads,
**flat_structure(self))
super(_PrivateThreadPoolDataset, self).__init__(input_dataset,
variant_tensor)

View File

@ -357,7 +357,7 @@ class Iterator(checkpointable.CheckpointableBase):
(self.output_shapes, dataset.output_shapes))
with ops.colocate_with(self._iterator_resource):
return gen_dataset_ops.make_iterator(
dataset._as_variant_tensor(), self._iterator_resource, name=name) # pylint: disable=protected-access
dataset._variant_tensor, self._iterator_resource, name=name) # pylint: disable=protected-access
def get_next(self, name=None):
"""Returns a nested structure of `tf.Tensor`s representing the next element.
@ -524,7 +524,7 @@ class EagerIterator(checkpointable.CheckpointableBase):
with ops.device("/cpu:0"):
# pylint: disable=protected-access
dataset = dataset._apply_options()
ds_variant = dataset._as_variant_tensor()
ds_variant = dataset._variant_tensor
self._structure = structure_lib.convert_legacy_structure(
dataset.output_types, dataset.output_shapes, dataset.output_classes)
self._flat_output_types = self._structure._flat_types

View File

@ -30,12 +30,11 @@ from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
class _PerDeviceGenerator(dataset_ops.Dataset):
class _PerDeviceGenerator(dataset_ops.DatasetV2):
"""A `dummy` generator dataset."""
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
source_device, target_device, element_structure):
super(_PerDeviceGenerator, self).__init__()
self._target_device = target_device
self._structure = element_structure
@ -108,9 +107,8 @@ class _PerDeviceGenerator(dataset_ops.Dataset):
)
self._finalize_captured_args = self._finalize_func.captured_inputs
def _as_variant_tensor(self):
with ops.device(self._target_device):
return gen_dataset_ops.generator_dataset(
variant_tensor = gen_dataset_ops.generator_dataset(
self._init_captured_args,
self._next_captured_args,
self._finalize_captured_args,
@ -118,6 +116,7 @@ class _PerDeviceGenerator(dataset_ops.Dataset):
next_func=self._next_func,
finalize_func=self._finalize_func,
**dataset_ops.flat_structure(self))
super(_PerDeviceGenerator, self).__init__(variant_tensor)
def _inputs(self):
# TODO(b/116506223): Determine which datasets should be used as inputs here.
@ -177,7 +176,7 @@ class MultiDeviceIterator(object):
# The incarnation ID is used to ensure consistency between the per-device
# iterators and the multi-device iterator.
self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
self._dataset._as_variant_tensor(), # pylint: disable=protected-access
self._dataset._variant_tensor, # pylint: disable=protected-access
self._multi_device_iterator_resource,
max_buffer_size=max_buffer_size)
@ -200,7 +199,8 @@ class MultiDeviceIterator(object):
options.experimental_optimization.apply_default_optimizations = False
ds = ds.with_options(options)
with ops.device(device):
self._device_iterators.append(ds.make_initializable_iterator())
self._device_iterators.append(
dataset_ops.make_initializable_iterator(ds))
device_iterator_initializers = [
iterator.initializer for iterator in self._device_iterators

View File

@ -49,7 +49,6 @@ class TextLineDatasetV2(dataset_ops.DatasetSource):
to buffer. A value of 0 results in the default buffering values chosen
based on the compression type.
"""
super(TextLineDatasetV2, self).__init__()
self._filenames = ops.convert_to_tensor(
filenames, dtype=dtypes.string, name="filenames")
self._compression_type = convert.optional_param_to_tensor(
@ -59,10 +58,9 @@ class TextLineDatasetV2(dataset_ops.DatasetSource):
argument_dtype=dtypes.string)
self._buffer_size = convert.optional_param_to_tensor(
"buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
def _as_variant_tensor(self):
return gen_dataset_ops.text_line_dataset(
variant_tensor = gen_dataset_ops.text_line_dataset(
self._filenames, self._compression_type, self._buffer_size)
super(TextLineDatasetV2, self).__init__(variant_tensor)
@property
def _element_structure(self):
@ -100,7 +98,6 @@ class _TFRecordDataset(dataset_ops.DatasetSource):
buffer_size: (Optional.) A `tf.int64` scalar representing the number of
bytes in the read buffer. 0 means no buffering.
"""
super(_TFRecordDataset, self).__init__()
# Force the type to string even if filenames is an empty list.
self._filenames = ops.convert_to_tensor(
filenames, dtypes.string, name="filenames")
@ -113,24 +110,32 @@ class _TFRecordDataset(dataset_ops.DatasetSource):
"buffer_size",
buffer_size,
argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)
def _as_variant_tensor(self):
return gen_dataset_ops.tf_record_dataset(
variant_tensor = gen_dataset_ops.tf_record_dataset(
self._filenames, self._compression_type, self._buffer_size)
super(_TFRecordDataset, self).__init__(variant_tensor)
@property
def _element_structure(self):
return structure.TensorStructure(dtypes.string, [])
class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that maps a function over its input and flattens the result."""
def __init__(self, input_dataset, map_func, cycle_length, block_length,
sloppy, buffer_output_elements, prefetch_input_elements):
"""See `tf.data.experimental.parallel_interleave()` for details."""
super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func,
cycle_length, block_length)
self._input_dataset = input_dataset
self._map_func = dataset_ops.StructuredFunctionWrapper(
map_func, self._transformation_name(), dataset=input_dataset)
if not isinstance(self._map_func.output_structure,
dataset_ops.DatasetStructure):
raise TypeError("`map_func` must return a `Dataset` object.")
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
self._cycle_length = ops.convert_to_tensor(
cycle_length, dtype=dtypes.int64, name="cycle_length")
self._block_length = ops.convert_to_tensor(
block_length, dtype=dtypes.int64, name="block_length")
self._sloppy = ops.convert_to_tensor(
sloppy, dtype=dtypes.bool, name="sloppy")
self._buffer_output_elements = convert.optional_param_to_tensor(
@ -141,11 +146,8 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
"prefetch_input_elements",
prefetch_input_elements,
argument_default=2 * cycle_length)
def _as_variant_tensor(self):
# pylint: disable=protected-access
return ged_ops.experimental_parallel_interleave_dataset(
self._input_dataset._as_variant_tensor(),
variant_tensor = ged_ops.experimental_parallel_interleave_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
self._cycle_length,
self._block_length,
@ -154,7 +156,15 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
self._prefetch_input_elements,
f=self._map_func.function,
**dataset_ops.flat_structure(self))
# pylint: enable=protected-access
super(ParallelInterleaveDataset, self).__init__(input_dataset,
variant_tensor)
def _functions(self):
return [self._map_func]
@property
def _element_structure(self):
return self._structure
def _transformation_name(self):
return "tf.data.experimental.parallel_interleave()"
@ -186,7 +196,6 @@ class TFRecordDatasetV2(dataset_ops.DatasetV2):
TypeError: If any argument does not have the expected type.
ValueError: If any argument does not have the expected shape.
"""
super(TFRecordDatasetV2, self).__init__()
if isinstance(filenames, dataset_ops.DatasetV2):
if filenames.output_types != dtypes.string:
raise TypeError(
@ -215,6 +224,8 @@ class TFRecordDatasetV2(dataset_ops.DatasetV2):
filenames, read_one_file, cycle_length=num_parallel_reads,
block_length=1, sloppy=False, buffer_output_elements=None,
prefetch_input_elements=None)
variant_tensor = self._impl._variant_tensor # pylint: disable=protected-access
super(TFRecordDatasetV2, self).__init__(variant_tensor)
def _clone(self,
filenames=None,
@ -226,9 +237,6 @@ class TFRecordDatasetV2(dataset_ops.DatasetV2):
buffer_size or self._buffer_size,
num_parallel_reads or self._num_parallel_reads)
def _as_variant_tensor(self):
return self._impl._as_variant_tensor() # pylint: disable=protected-access
def _inputs(self):
return self._impl._inputs() # pylint: disable=protected-access
@ -295,7 +303,6 @@ class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource):
compression_type: (Optional.) A `tf.string` scalar evaluating to one of
`""` (no compression), `"ZLIB"`, or `"GZIP"`.
"""
super(FixedLengthRecordDatasetV2, self).__init__()
self._filenames = ops.convert_to_tensor(
filenames, dtype=dtypes.string, name="filenames")
self._record_bytes = ops.convert_to_tensor(
@ -312,17 +319,16 @@ class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource):
compression_type,
argument_default="",
argument_dtype=dtypes.string)
def _as_variant_tensor(self):
if (self._compression_type is not None or
compat.forward_compatible(2018, 11, 30)):
return gen_dataset_ops.fixed_length_record_dataset_v2(
variant_tensor = gen_dataset_ops.fixed_length_record_dataset_v2(
self._filenames, self._header_bytes, self._record_bytes,
self._footer_bytes, self._buffer_size, self._compression_type)
else:
return gen_dataset_ops.fixed_length_record_dataset(
variant_tensor = gen_dataset_ops.fixed_length_record_dataset(
self._filenames, self._header_bytes, self._record_bytes,
self._footer_bytes, self._buffer_size)
super(FixedLengthRecordDatasetV2, self).__init__(variant_tensor)
@property
def _element_structure(self):

View File

@ -163,3 +163,24 @@ py_test(
"//tensorflow/python:util",
],
)
py_library(
name = "traverse",
srcs = ["traverse.py"],
srcs_version = "PY2AND3",
deps = [
],
)
py_test(
name = "traverse_test",
size = "small",
srcs = ["traverse_test.py"],
srcs_version = "PY2AND3",
deps = [
":traverse",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python/data/ops:dataset_ops",
],
)

View File

@ -0,0 +1,56 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to traverse the Dataset dependency structure."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import queue as Queue # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes
def obtain_all_variant_tensor_ops(dataset):
"""Given an input dataset, finds all dataset ops used for construction.
A series of transformations would have created this dataset, with each
transformation contributing zero or more dataset ops, each of which produces
a dataset variant tensor. This method returns all of those ops.
Args:
dataset: Dataset to find variant tensors for.
Returns:
A list of variant_tensor producing dataset ops used to construct this
dataset.
"""
all_variant_tensor_ops = []
bfs_q = Queue.Queue()
bfs_q.put(dataset._variant_tensor.op) # pylint: disable=protected-access
visited = []
while not bfs_q.empty():
op = bfs_q.get()
visited.append(op)
# We look for all ops that produce variant tensors as output. This is a bit
# of overkill but the other dataset _inputs() traversal strategies can't
# cover the case of function inputs that capture dataset variants.
# TODO(b/120873778): Make this more efficient.
if op.outputs[0].dtype == dtypes.variant:
all_variant_tensor_ops.append(op)
for i in op.inputs:
input_op = i.op
if input_op not in visited:
bfs_q.put(input_op)
return all_variant_tensor_ops
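A short usage sketch (illustrative, not part of this change) showing how the ops returned here feed the function-capture whitelisting used by `make_one_shot_iterator` above; the pipeline is arbitrary:

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import traverse
from tensorflow.python.framework import function
from tensorflow.python.framework import ops

ds = dataset_ops.Dataset.range(10).map(lambda x: x * 2).batch(4)
whitelisted = traverse.obtain_all_variant_tensor_ops(ds)

@function.Defun(capture_by_value=True, whitelisted_stateful_ops=whitelisted)
def _make_dataset():
  # Capture by value recreates the whitelisted dataset ops inside the
  # function body instead of rejecting them as stateful captures.
  return ds._variant_tensor  # pylint: disable=protected-access

_make_dataset.add_to_graph(ops.get_default_graph())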

View File

@ -0,0 +1,109 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utilities for traversing the dataset construction graph."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import traverse
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class _TestDataset(dataset_ops.UnaryUnchangedStructureDataset):
def __init__(self, input_dataset):
self._input_dataset = input_dataset
temp_variant_tensor = gen_dataset_ops.prefetch_dataset(
input_dataset._variant_tensor,
buffer_size=1,
**dataset_ops.flat_structure(self))
variant_tensor = gen_dataset_ops.model_dataset(
temp_variant_tensor, **dataset_ops.flat_structure(self))
super(_TestDataset, self).__init__(input_dataset, variant_tensor)
class TraverseTest(test.TestCase):
@test_util.run_deprecated_v1
def testOnlySource(self):
ds = dataset_ops.Dataset.range(10)
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
self.assertAllEqual(["RangeDataset"], [x.name for x in variant_tensor_ops])
@test_util.run_deprecated_v1
def testSimplePipeline(self):
ds = dataset_ops.Dataset.range(10).map(math_ops.square)
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
self.assertSetEqual(
set(["MapDataset", "RangeDataset"]),
set([x.name for x in variant_tensor_ops]))
@test_util.run_deprecated_v1
def testConcat(self):
ds1 = dataset_ops.Dataset.range(10)
ds2 = dataset_ops.Dataset.range(10)
ds = ds1.concatenate(ds2)
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
self.assertSetEqual(
set(["ConcatenateDataset", "RangeDataset", "RangeDataset_1"]),
set([x.name for x in variant_tensor_ops]))
@test_util.run_deprecated_v1
def testZip(self):
ds1 = dataset_ops.Dataset.range(10)
ds2 = dataset_ops.Dataset.range(10)
ds = dataset_ops.Dataset.zip((ds1, ds2))
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
self.assertSetEqual(
set(["ZipDataset", "RangeDataset", "RangeDataset_1"]),
set([x.name for x in variant_tensor_ops]))
@test_util.run_deprecated_v1
def testMultipleVariantTensors(self):
ds = dataset_ops.Dataset.range(10)
ds = _TestDataset(ds)
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
self.assertSetEqual(
set(["RangeDataset", "ModelDataset", "PrefetchDataset"]),
set([x.name for x in variant_tensor_ops]))
@test_util.run_deprecated_v1
def testFlatMap(self):
ds1 = dataset_ops.Dataset.range(10).repeat(10)
def map_fn(ds):
def _map(x):
return ds.batch(x)
return _map
ds2 = dataset_ops.Dataset.range(20).prefetch(1)
ds2 = ds2.flat_map(map_fn(ds1))
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds2)
self.assertSetEqual(
set([
"FlatMapDataset", "PrefetchDataset", "RepeatDataset",
"RangeDataset", "RangeDataset_1"
]), set([x.name for x in variant_tensor_ops]))
if __name__ == "__main__":
test.main()
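For orientation, the walk these tests exercise can be pictured as a backwards traversal from the dataset's output variant tensor over every DT_VARIANT input. The sketch below is illustrative only, assuming the `obtain_all_variant_tensor_ops` helper in the file above; the function name `sketch_obtain_all_variant_tensor_ops` and the `pending`/`found` names are not part of the change.

from tensorflow.python.framework import dtypes


def sketch_obtain_all_variant_tensor_ops(dataset):
  """Illustrative walk over the ops that produce a dataset's variant tensors."""
  # Start from the op that produced the dataset's variant tensor and follow
  # every DT_VARIANT input backwards, collecting the ops along the way.
  pending = [dataset._variant_tensor.op]  # pylint: disable=protected-access
  found = []
  while pending:
    op = pending.pop()
    found.append(op)
    for tensor in op.inputs:
      if (tensor.dtype == dtypes.variant and
          tensor.op not in found and tensor.op not in pending):
        pending.append(tensor.op)
  return found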

View File

@ -270,6 +270,7 @@ cuda_py_test(
":input_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/data/util:traverse",
"//tensorflow/python:errors",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",

View File

@ -18,15 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import filter_for_shard_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import traverse
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging
# TODO(priyag): Any other reader datasets to consider here?
_READER_DATASET_OPS = [
"TextLineDataset", "TFRecordDataset", "FixedLengthRecordDataset",
@ -53,100 +51,57 @@ def auto_shard_dataset(dataset, num_shards, index):
determine a good way to shard the input dataset.
"""
# TODO(priyag): Clone datasets instead of updating in place, similar to the
# clone method for TFRecordDataset.
def _auto_shard_impl(dataset, found_reader_op):
"""Recursive implementation of auto sharding."""
# TODO(rohanj): b/120673685 to track re-enabling auto sharding.
tf_logging.warn("Autosharding is currently disabled. Please shard your input "
"manually.")
del num_shards, index
return dataset
if not found_reader_op:
# TODO(priyag): Make this check more robust by enforcing some common
# property on reader datasets.
if (isinstance(dataset, readers.TextLineDataset) or
isinstance(dataset, readers.FixedLengthRecordDataset)):
filenames_tensor = dataset._filenames
num_files = array_ops.size(filenames_tensor)
sharded_filenames_tensor = array_ops.gather(
filenames_tensor, math_ops.range(index, num_files, num_shards))
dataset._filenames = sharded_filenames_tensor
return dataset
elif isinstance(dataset, readers.TFRecordDataset):
      # `TFRecordDataset` needs to be handled separately from other readers
# because it converts filenames to a dataset first. Also, we clone it
# instead of updating in place because it has special logic in the
# constructor. Eventually we will change all cases to clone datasets
# instead of updating in-place.
return dataset._clone(
filenames=dataset._filenames.apply(
filter_for_shard_ops.filter_for_shard(num_shards, index)))
elif isinstance(dataset, dataset_ops.RangeDataset):
return dataset.apply(
filter_for_shard_ops.filter_for_shard(num_shards, index))
elif hasattr(dataset, "_map_func"):
# TODO(priyag): Make this check more robust by enforcing some common
# property on all map/flatmap/interleave datasets.
map_func_def = dataset._map_func.function.definition
for node in map_func_def.node_def:
if node.op in _READER_DATASET_OPS:
found_reader_op = True
break
elif node.op == "FlatMapDataset":
# TODO(priyag): Should this check for other map datasets? Should it
          # be recursive? It is too specific to the implementation of
# TFRecordDataset right now.
nested_func_name = node.attr["f"].func.name
nested_func = ops.get_default_graph()._functions[nested_func_name]
for nested_node in nested_func.definition.node_def:
if nested_node.op in _READER_DATASET_OPS:
found_reader_op = True
break
if found_reader_op:
break
if found_reader_op:
dataset._input_dataset = _auto_shard_impl(
dataset._input_dataset, found_reader_op)
return dataset
if isinstance(dataset, dataset_ops.DatasetV1Adapter):
dataset._dataset = _auto_shard_impl(
dataset._dataset, found_reader_op)
return dataset
def _clone_dataset(dataset):
"""Returns a cloned version of `dataset`."""
variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(dataset)
remap_dict = _clone_helper(dataset._variant_tensor.op, variant_tensor_ops)
new_variant_tensor = remap_dict[dataset._variant_tensor.op].outputs[0]
return dataset_ops._VariantDataset(new_variant_tensor,
dataset._element_structure)
# TODO(priyag): Make _input_dataset(s) a common property of all datasets to
# make this check more robust.
if hasattr(dataset, "_input_dataset"):
dataset._input_dataset = _auto_shard_impl(
dataset._input_dataset, found_reader_op)
if hasattr(dataset, "_dataset_to_concatenate"):
        # Special case for `ConcatenateDataset`. We want to shard all input
# datasets.
dataset._dataset_to_concatenate = _auto_shard_impl(
dataset._dataset_to_concatenate, found_reader_op)
return dataset
if hasattr(dataset, "_datasets"):
# Special case for `ZipDataset`.
dataset._datasets = nest.pack_sequence_as(dataset._datasets, [
_auto_shard_impl(ds, found_reader_op)
for ds in nest.flatten(dataset._datasets)
])
return dataset
def _get_op_def(op):
return op.op_def or op_def_registry.get_registered_ops()[op.type]
if not found_reader_op:
      tf_logging.warn(
          "Could not find a standard reader in the input pipeline "
          "(one of TextLineDataset, TFRecordDataset, FixedLengthRecordDataset). "
"So auto-sharding is not done. Please verify correctness of "
"auto-sharding for your input.")
# TODO(yuefengz): maybe still shard it?
return dataset
# TODO(priyag): What do we want to do if the number of filenames is
# uneven in the number of shards? By default, this will just return as
      # many items as it can before throwing OutOfRangeError.
# TODO(priyag): This will shard the filenames before any shuffling of the
# filename dataset. It might be desirable to shard after shuffling
# filenames? If so, how do we achieve that?
return dataset.apply(
filter_for_shard_ops.filter_for_shard(num_shards, index))
def _clone_helper(op_to_clone, variant_tensor_ops):
"""Helper method that recursively clones `op_to_clone`.
return _auto_shard_impl(dataset=dataset, found_reader_op=False)
Args:
op_to_clone: The op we want to clone.
variant_tensor_ops: A list of ops that we have to clone along the way.
Returns:
A dictionary mapping old_ops to new_ops created. Includes op_to_clone
as a key.
"""
remap_dict = {}
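  # First clone, recursively, every upstream input op that itself produces a
  # dataset variant tensor, recording the old-op -> new-op mapping.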
for input_tensor in op_to_clone.inputs:
input_tensor_op = input_tensor.op
if input_tensor_op in variant_tensor_ops:
recursive_map = _clone_helper(input_tensor_op, variant_tensor_ops)
remap_dict.update(recursive_map)
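  # Build the input list for the cloned op, substituting the outputs of freshly
  # cloned ops where available and reusing the original tensors otherwise.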
inputs_list = []
for input_tensor in op_to_clone.inputs:
input_tensor_op = input_tensor.op
if input_tensor_op in remap_dict:
remapped_input = remap_dict[input_tensor_op].outputs[0]
inputs_list.append(remapped_input)
else:
inputs_list.append(input_tensor_op.outputs[input_tensor.value_index])
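  # Re-create `op_to_clone` in the default graph with the same type, attrs and
  # op_def, wired to the (possibly remapped) inputs.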
g = ops.get_default_graph()
new_op = g.create_op(
op_to_clone.type,
inputs_list, [o.dtype for o in op_to_clone.outputs],
name=op_to_clone.name,
attrs=op_to_clone.node_def.attr,
op_def=_get_op_def(op_to_clone))
remap_dict[op_to_clone] = new_op
return remap_dict
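The cloning helpers above are consumed later in this change by the multi-worker iterator code, which clones the dataset inside each worker's device scope in graph mode. A minimal sketch of that pattern follows; the helper name `make_iterator_on` and the device string are illustrative assumptions, not part of the change.

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import input_ops
from tensorflow.python.framework import ops


def make_iterator_on(dataset, device="/job:worker/task:0"):
  """Hypothetical helper: clone `dataset` and build an iterator on `device`."""
  with ops.device(device):
    # _clone_dataset re-creates the dataset kernels inside the current device
    # scope, so this worker gets its own copy of the pipeline.
    cloned = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
    return dataset_ops.make_initializable_iterator(cloned)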

View File

@ -26,6 +26,8 @@ from tensorflow.python.distribute import input_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@ -90,7 +92,7 @@ class AutoShardDatasetTest(test.TestCase):
def _verifySimpleShardingOutput(self, dataset, record_fn):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
with self.cached_session():
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
self.assertAllEqual(record_fn(r, f), self.evaluate(next_element))
@ -98,7 +100,7 @@ class AutoShardDatasetTest(test.TestCase):
self.evaluate(next_element)
@test_util.run_deprecated_v1
def testTFRecordDataset(self):
def DISABLED_testTFRecordDataset(self):
dataset = readers.TFRecordDataset(self._createTFRecordFiles())
dataset = input_ops.auto_shard_dataset(
dataset, self._num_shards, self._shard_index)
@ -106,7 +108,7 @@ class AutoShardDatasetTest(test.TestCase):
self._verifySimpleShardingOutput(dataset, self._record)
@test_util.run_deprecated_v1
def testFlatMap(self):
def DISABLED_testFlatMap(self):
dataset = dataset_ops.Dataset.from_tensor_slices(
self._createTFRecordFiles())
dataset = dataset.flat_map(readers.TFRecordDataset)
@ -116,7 +118,7 @@ class AutoShardDatasetTest(test.TestCase):
self._verifySimpleShardingOutput(dataset, self._record)
@test_util.run_deprecated_v1
def testInterleave(self):
def DISABLED_testInterleave(self):
dataset = dataset_ops.Dataset.from_tensor_slices(
self._createTFRecordFiles())
dataset = dataset.interleave(
@ -129,7 +131,7 @@ class AutoShardDatasetTest(test.TestCase):
self._verifySimpleShardingOutput(dataset, self._record)
@test_util.run_deprecated_v1
def testListfiles(self):
def DISABLED_testListfiles(self):
filenames = self._createTFRecordFiles()
file_pattern = filenames[0].rsplit(os.sep, 1)[0] + "/tf_record.*.txt"
dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
@ -139,7 +141,7 @@ class AutoShardDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
with self.cached_session():
actual, expected = [], []
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
@ -150,7 +152,7 @@ class AutoShardDatasetTest(test.TestCase):
self.assertAllEqual(expected, actual)
@test_util.run_deprecated_v1
def testComplexPipeline(self):
def DISABLED_testComplexPipeline(self):
    # Set up a complex input pipeline.
batch_size = 2
num_epochs = 5
@ -172,7 +174,7 @@ class AutoShardDatasetTest(test.TestCase):
# Verify output.
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
with self.cached_session():
actual = []
num_iterations = (self._num_files * self._num_records * num_epochs) // (
self._num_shards * batch_size)
@ -190,7 +192,7 @@ class AutoShardDatasetTest(test.TestCase):
self.assertAllEqual(sorted(expected), sorted(actual))
@test_util.run_deprecated_v1
def testZip(self):
def DISABLED_testZip(self):
dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
dataset2 = readers.TextLineDataset(self._createTextFiles())
dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
@ -201,7 +203,7 @@ class AutoShardDatasetTest(test.TestCase):
self._verifySimpleShardingOutput(dataset, record_fn)
@test_util.run_deprecated_v1
def testConcat(self):
def DISABLED_testConcat(self):
dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
dataset2 = readers.TextLineDataset(self._createTextFiles())
dataset = dataset1.concatenate(dataset2)
@ -222,7 +224,7 @@ class AutoShardDatasetTest(test.TestCase):
self.evaluate(next_element)
@test_util.run_deprecated_v1
def testTextLineReader(self):
def DISABLED_testTextLineReader(self):
dataset = readers.TextLineDataset(self._createTextFiles())
dataset = input_ops.auto_shard_dataset(
dataset, self._num_shards, self._shard_index)
@ -230,7 +232,7 @@ class AutoShardDatasetTest(test.TestCase):
self._verifySimpleShardingOutput(dataset, self._text_line)
@test_util.run_deprecated_v1
def testTextLineReaderWithFlatMap(self):
def DISABLED_testTextLineReaderWithFlatMap(self):
dataset = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles())
dataset = dataset.flat_map(readers.TextLineDataset)
dataset = input_ops.auto_shard_dataset(
@ -239,7 +241,7 @@ class AutoShardDatasetTest(test.TestCase):
self._verifySimpleShardingOutput(dataset, self._text_line)
@test_util.run_deprecated_v1
def testFixedLengthReader(self):
def DISABLED_testFixedLengthReader(self):
dataset = readers.FixedLengthRecordDataset(
self._createFixedLengthRecordFiles(), self._record_bytes)
dataset = input_ops.auto_shard_dataset(
@ -248,7 +250,7 @@ class AutoShardDatasetTest(test.TestCase):
self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
@test_util.run_deprecated_v1
def testFixedLengthReaderWithFlatMap(self):
def DISABLED_testFixedLengthReaderWithFlatMap(self):
dataset = dataset_ops.Dataset.from_tensor_slices(
self._createFixedLengthRecordFiles())
dataset = dataset.flat_map(
@ -258,5 +260,77 @@ class AutoShardDatasetTest(test.TestCase):
self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
# A dataset that creates two variant tensors.
class _TestDataset(dataset_ops.UnaryUnchangedStructureDataset):
def __init__(self, input_dataset):
self._input_dataset = input_dataset
temp_variant_tensor = gen_dataset_ops.prefetch_dataset(
input_dataset._variant_tensor,
buffer_size=1,
**dataset_ops.flat_structure(self))
variant_tensor = gen_dataset_ops.model_dataset(
temp_variant_tensor, **dataset_ops.flat_structure(self))
super(_TestDataset, self).__init__(input_dataset, variant_tensor)
class CloneDatasetTest(test.TestCase):
def _assert_datasets_equal(self, ds1, ds2):
    # First let's assert that the structure is the same.
self.assertTrue(
ds1._element_structure.is_compatible_with(ds2._element_structure))
self.assertTrue(
ds2._element_structure.is_compatible_with(ds1._element_structure))
# Now create iterators on both and assert they produce the same values.
it1 = dataset_ops.make_initializable_iterator(ds1)
it2 = dataset_ops.make_initializable_iterator(ds2)
get_next1 = it1.get_next()
get_next2 = it2.get_next()
with self.cached_session():
self.evaluate([it1.initializer, it2.initializer])
val1, val2 = self.evaluate([get_next1, get_next2])
self.assertEqual(val1, val2)
@test_util.run_deprecated_v1
def testOnlySource(self):
ds = dataset_ops.Dataset.range(10)
cloned_ds = input_ops._clone_dataset(ds)
self._assert_datasets_equal(ds, cloned_ds)
@test_util.run_deprecated_v1
def testSimplePipeline(self):
ds = dataset_ops.Dataset.range(10).map(math_ops.square)
cloned_ds = input_ops._clone_dataset(ds)
self._assert_datasets_equal(ds, cloned_ds)
@test_util.run_deprecated_v1
def testConcat(self):
ds1 = dataset_ops.Dataset.range(10)
ds2 = dataset_ops.Dataset.range(10)
ds = ds1.concatenate(ds2)
cloned_ds = input_ops._clone_dataset(ds)
self._assert_datasets_equal(ds, cloned_ds)
@test_util.run_deprecated_v1
def testZip(self):
ds1 = dataset_ops.Dataset.range(10)
ds2 = dataset_ops.Dataset.range(10)
ds = dataset_ops.Dataset.zip((ds1, ds2))
cloned_ds = input_ops._clone_dataset(ds)
self._assert_datasets_equal(ds, cloned_ds)
@test_util.run_deprecated_v1
def testMultipleVariantTensors(self):
ds = dataset_ops.Dataset.range(10)
ds = _TestDataset(ds)
cloned_ds = input_ops._clone_dataset(ds)
self._assert_datasets_equal(ds, cloned_ds)
if __name__ == "__main__":
test.main()

View File

@ -1600,9 +1600,9 @@ class MultiWorkerDataset(object):
if len(dataset_fn) != input_workers.num_workers:
raise ValueError("If `dataset_fn` is a list, it must have one entry "
"per worker")
if auto_shard:
raise ValueError(
"If `dataset_fn` is a list, `auto_shard` is not supported.")
# TODO(rohanj): b/120673685 to track re-enabling auto sharding.
if auto_shard:
raise ValueError("Currently autosharding is not supported.")
self._input_workers = input_workers
self._datasets = []
# TODO(yuefengz, priyag): support different set of jobs for input
@ -1613,9 +1613,6 @@ class MultiWorkerDataset(object):
worker_input = dataset_fn[i]()
else:
worker_input = dataset_fn()
if auto_shard:
worker_input = input_ops.auto_shard_dataset(
worker_input, input_workers.num_workers, i)
dataset = PerReplicaDataset(worker_input, input_workers, i,
prefetch_on_device=prefetch_on_device)
self._datasets.append((worker, dataset))
@ -1805,7 +1802,11 @@ class DatasetIterator(InputIteratorImpl):
for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker):
worker_devices = input_workers.compute_devices_for_worker(i)
iterator = _SingleWorkerDatasetIterator(dataset, worker, worker_devices)
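        # In graph mode, re-create the dataset's kernels inside this worker's
        # device scope by cloning; in eager mode the dataset is used directly.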
cloned_dataset = dataset
if not context.executing_eagerly():
cloned_dataset = input_ops._clone_dataset(dataset) # pylint: disable=protected-access
iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
worker_devices)
iterators.append(iterator)
super(DatasetIterator, self).__init__(input_workers, iterators)

View File

@ -16,7 +16,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'variant_tensor\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "apply"