[tf.data] Bug fix: make_csv_dataset should not modify mutable parameters passed into it

Fixes #39186.

PiperOrigin-RevId: 310462826
Change-Id: I8b8dd6f16f9fab6b1e02410dc6e8c91f748772f6
This commit is contained in:
Rachel Lim 2020-05-07 16:51:02 -07:00 committed by TensorFlower Gardener
parent 797ac7bf87
commit 79acb0824b
2 changed files with 22 additions and 9 deletions

View File

@ -41,9 +41,9 @@ class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
filenames = []
for i, ip in enumerate(inputs):
for i, file_rows in enumerate(inputs):
fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i)
contents = linebreak.join(ip).encode('utf-8')
contents = linebreak.join(file_rows).encode('utf-8')
if compression_type is None:
with open(fn, 'wb') as f:
f.write(contents)
@ -580,6 +580,13 @@ class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
record_defaults=record_defaults)
def testCsvDataset_immutableParams(self):
  """Checks make_csv_dataset does not modify `select_columns`."""
  filenames = self._setup_files([['a,b,c', '1,2,3', '4,5,6']])
  requested_cols = ['a', 'c']
  # Only the effect on the argument matters here; the returned dataset is
  # intentionally discarded.
  readers.make_csv_dataset(
      filenames, batch_size=1, select_columns=requested_cols)
  # The caller's list must still hold the original column names.
  self.assertAllEqual(requested_cols, ['a', 'c'])
if __name__ == '__main__':
test.main()

View File

@ -183,24 +183,30 @@ def _get_sorted_col_indices(select_columns, column_names):
"""Transforms select_columns argument into sorted column indices."""
names_to_indices = {n: i for i, n in enumerate(column_names)}
num_cols = len(column_names)
for i, v in enumerate(select_columns):
results = []
for v in select_columns:
# If value is already an int, check if it's valid.
if isinstance(v, int):
if v < 0 or v >= num_cols:
raise ValueError(
"Column index %d specified in select_columns out of valid range." %
v)
continue
if v not in names_to_indices:
results.append(v)
# Otherwise, check that it's a valid column name and convert to the
# the relevant column index.
elif v not in names_to_indices:
raise ValueError(
"Value '%s' specified in select_columns not a valid column index or "
"name." % v)
select_columns[i] = names_to_indices[v]
else:
results.append(names_to_indices[v])
# Sort and ensure there are no duplicates
result = sorted(set(select_columns))
if len(result) != len(select_columns):
results = sorted(set(results))
if len(results) != len(select_columns):
raise ValueError("select_columns contains duplicate columns")
return result
return results
def _maybe_shuffle_and_repeat(