diff --git a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py index 941ca209848..13948305aea 100644 --- a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py @@ -41,9 +41,9 @@ class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def _setup_files(self, inputs, linebreak='\n', compression_type=None): filenames = [] - for i, ip in enumerate(inputs): + for i, file_rows in enumerate(inputs): fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i) - contents = linebreak.join(ip).encode('utf-8') + contents = linebreak.join(file_rows).encode('utf-8') if compression_type is None: with open(fn, 'wb') as f: f.write(contents) @@ -580,6 +580,13 @@ class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], record_defaults=record_defaults) + def testCsvDataset_immutableParams(self): + inputs = [['a,b,c', '1,2,3', '4,5,6']] + filenames = self._setup_files(inputs) + select_cols = ['a', 'c'] + _ = readers.make_csv_dataset( + filenames, batch_size=1, select_columns=select_cols) + self.assertAllEqual(select_cols, ['a', 'c']) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/data/experimental/ops/readers.py b/tensorflow/python/data/experimental/ops/readers.py index 8795a206bb1..b8f4c34f40e 100644 --- a/tensorflow/python/data/experimental/ops/readers.py +++ b/tensorflow/python/data/experimental/ops/readers.py @@ -183,24 +183,30 @@ def _get_sorted_col_indices(select_columns, column_names): """Transforms select_columns argument into sorted column indices.""" names_to_indices = {n: i for i, n in enumerate(column_names)} num_cols = len(column_names) - for i, v in enumerate(select_columns): + + results = [] + for v in select_columns: + # If value is already an int, check if it's valid. if isinstance(v, int): if v < 0 or v >= num_cols: raise ValueError( "Column index %d specified in select_columns out of valid range." % v) - continue - if v not in names_to_indices: + results.append(v) + # Otherwise, check that it's a valid column name and convert to the + # the relevant column index. + elif v not in names_to_indices: raise ValueError( "Value '%s' specified in select_columns not a valid column index or " "name." % v) - select_columns[i] = names_to_indices[v] + else: + results.append(names_to_indices[v]) # Sort and ensure there are no duplicates - result = sorted(set(select_columns)) - if len(result) != len(select_columns): + results = sorted(set(results)) + if len(results) != len(select_columns): raise ValueError("select_columns contains duplicate columns") - return result + return results def _maybe_shuffle_and_repeat(