[tf.data] Bug fix: make_csv_dataset should not modify mutable parameters passed into it
Fixes #39186. PiperOrigin-RevId: 310462826 Change-Id: I8b8dd6f16f9fab6b1e02410dc6e8c91f748772f6
This commit is contained in:
parent
797ac7bf87
commit
79acb0824b
@ -41,9 +41,9 @@ class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
|
||||
filenames = []
|
||||
for i, ip in enumerate(inputs):
|
||||
for i, file_rows in enumerate(inputs):
|
||||
fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i)
|
||||
contents = linebreak.join(ip).encode('utf-8')
|
||||
contents = linebreak.join(file_rows).encode('utf-8')
|
||||
if compression_type is None:
|
||||
with open(fn, 'wb') as f:
|
||||
f.write(contents)
|
||||
@ -580,6 +580,13 @@ class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
|
||||
record_defaults=record_defaults)
|
||||
|
||||
def testCsvDataset_immutableParams(self):
|
||||
inputs = [['a,b,c', '1,2,3', '4,5,6']]
|
||||
filenames = self._setup_files(inputs)
|
||||
select_cols = ['a', 'c']
|
||||
_ = readers.make_csv_dataset(
|
||||
filenames, batch_size=1, select_columns=select_cols)
|
||||
self.assertAllEqual(select_cols, ['a', 'c'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -183,24 +183,30 @@ def _get_sorted_col_indices(select_columns, column_names):
|
||||
"""Transforms select_columns argument into sorted column indices."""
|
||||
names_to_indices = {n: i for i, n in enumerate(column_names)}
|
||||
num_cols = len(column_names)
|
||||
for i, v in enumerate(select_columns):
|
||||
|
||||
results = []
|
||||
for v in select_columns:
|
||||
# If value is already an int, check if it's valid.
|
||||
if isinstance(v, int):
|
||||
if v < 0 or v >= num_cols:
|
||||
raise ValueError(
|
||||
"Column index %d specified in select_columns out of valid range." %
|
||||
v)
|
||||
continue
|
||||
if v not in names_to_indices:
|
||||
results.append(v)
|
||||
# Otherwise, check that it's a valid column name and convert to the
|
||||
# the relevant column index.
|
||||
elif v not in names_to_indices:
|
||||
raise ValueError(
|
||||
"Value '%s' specified in select_columns not a valid column index or "
|
||||
"name." % v)
|
||||
select_columns[i] = names_to_indices[v]
|
||||
else:
|
||||
results.append(names_to_indices[v])
|
||||
|
||||
# Sort and ensure there are no duplicates
|
||||
result = sorted(set(select_columns))
|
||||
if len(result) != len(select_columns):
|
||||
results = sorted(set(results))
|
||||
if len(results) != len(select_columns):
|
||||
raise ValueError("select_columns contains duplicate columns")
|
||||
return result
|
||||
return results
|
||||
|
||||
|
||||
def _maybe_shuffle_and_repeat(
|
||||
|
Loading…
x
Reference in New Issue
Block a user