[tf.data] Bug fix: make_csv_dataset should not modify mutable parameters passed into it
Fixes #39186. PiperOrigin-RevId: 310462826 Change-Id: I8b8dd6f16f9fab6b1e02410dc6e8c91f748772f6
This commit is contained in:
parent
797ac7bf87
commit
79acb0824b
@ -41,9 +41,9 @@ class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
|
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
|
||||||
filenames = []
|
filenames = []
|
||||||
for i, ip in enumerate(inputs):
|
for i, file_rows in enumerate(inputs):
|
||||||
fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i)
|
fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i)
|
||||||
contents = linebreak.join(ip).encode('utf-8')
|
contents = linebreak.join(file_rows).encode('utf-8')
|
||||||
if compression_type is None:
|
if compression_type is None:
|
||||||
with open(fn, 'wb') as f:
|
with open(fn, 'wb') as f:
|
||||||
f.write(contents)
|
f.write(contents)
|
||||||
@ -580,6 +580,13 @@ class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
|
inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
|
||||||
record_defaults=record_defaults)
|
record_defaults=record_defaults)
|
||||||
|
|
||||||
|
def testCsvDataset_immutableParams(self):
|
||||||
|
inputs = [['a,b,c', '1,2,3', '4,5,6']]
|
||||||
|
filenames = self._setup_files(inputs)
|
||||||
|
select_cols = ['a', 'c']
|
||||||
|
_ = readers.make_csv_dataset(
|
||||||
|
filenames, batch_size=1, select_columns=select_cols)
|
||||||
|
self.assertAllEqual(select_cols, ['a', 'c'])
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -183,24 +183,30 @@ def _get_sorted_col_indices(select_columns, column_names):
|
|||||||
"""Transforms select_columns argument into sorted column indices."""
|
"""Transforms select_columns argument into sorted column indices."""
|
||||||
names_to_indices = {n: i for i, n in enumerate(column_names)}
|
names_to_indices = {n: i for i, n in enumerate(column_names)}
|
||||||
num_cols = len(column_names)
|
num_cols = len(column_names)
|
||||||
for i, v in enumerate(select_columns):
|
|
||||||
|
results = []
|
||||||
|
for v in select_columns:
|
||||||
|
# If value is already an int, check if it's valid.
|
||||||
if isinstance(v, int):
|
if isinstance(v, int):
|
||||||
if v < 0 or v >= num_cols:
|
if v < 0 or v >= num_cols:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Column index %d specified in select_columns out of valid range." %
|
"Column index %d specified in select_columns out of valid range." %
|
||||||
v)
|
v)
|
||||||
continue
|
results.append(v)
|
||||||
if v not in names_to_indices:
|
# Otherwise, check that it's a valid column name and convert to the
|
||||||
|
# relevant column index.
|
||||||
|
elif v not in names_to_indices:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Value '%s' specified in select_columns not a valid column index or "
|
"Value '%s' specified in select_columns not a valid column index or "
|
||||||
"name." % v)
|
"name." % v)
|
||||||
select_columns[i] = names_to_indices[v]
|
else:
|
||||||
|
results.append(names_to_indices[v])
|
||||||
|
|
||||||
# Sort and ensure there are no duplicates
|
# Sort and ensure there are no duplicates
|
||||||
result = sorted(set(select_columns))
|
results = sorted(set(results))
|
||||||
if len(result) != len(select_columns):
|
if len(results) != len(select_columns):
|
||||||
raise ValueError("select_columns contains duplicate columns")
|
raise ValueError("select_columns contains duplicate columns")
|
||||||
return result
|
return results
|
||||||
|
|
||||||
|
|
||||||
def _maybe_shuffle_and_repeat(
|
def _maybe_shuffle_and_repeat(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user