diff --git a/tensorflow/contrib/learn/python/learn/dataframe/dataframe.py b/tensorflow/contrib/learn/python/learn/dataframe/dataframe.py index 31093b9937a..6e03f086425 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/dataframe.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/dataframe.py @@ -117,10 +117,11 @@ class DataFrame(object): value = [value] self.assign(**dict(zip(key, value))) - def build(self): + def build(self, **kwargs): # We do not allow passing a cache here, because that would encourage # working around the rule that DataFrames cannot be expected to be # synced with each other (e.g., they shuffle independently). cache = {} - tensors = {name: c.build(cache) for name, c in self._columns.items()} + tensors = {name: c.build(cache, **kwargs) + for name, c in self._columns.items()} return tensors diff --git a/tensorflow/contrib/learn/python/learn/dataframe/estimator_utils.py b/tensorflow/contrib/learn/python/learn/dataframe/estimator_utils.py index bff0c4e4af0..313ae41cfe8 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/estimator_utils.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/estimator_utils.py @@ -91,7 +91,8 @@ def _build_alternate_universe( def to_feature_columns_and_input_fn(dataframe, base_input_keys_with_defaults, feature_keys, - target_keys=None): + target_keys=None, + **kwargs): """Build a list of FeatureColumns and an input_fn for use with Estimator. Args: @@ -103,6 +104,7 @@ def to_feature_columns_and_input_fn(dataframe, These may include base features and/or derived features. target_keys: the names of columns to be used as targets. None is acceptable for unsupervised learning. + **kwargs: Additional keyword arguments, unused here. Returns: A tuple of two elements: @@ -155,10 +157,11 @@ def to_feature_columns_and_input_fn(dataframe, # Build an input_fn suitable for use with Estimator. def input_fn(): + """An input_fn() for feeding the given set of DataFrameColumns.""" # It's important to build all the tensors together in one DataFrame. # If we did df.select() for both key sets and then build those, the two # resulting DataFrames would be shuffled independently. - tensors = limited_dataframe.build() + tensors = limited_dataframe.build(**kwargs) base_input_features = {key: tensors[key] for key in base_input_keys} targets = {key: tensors[key] for key in target_keys} diff --git a/tensorflow/contrib/learn/python/learn/dataframe/series.py b/tensorflow/contrib/learn/python/learn/dataframe/series.py index 12daa7d7cb8..5893db3aad2 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/series.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/series.py @@ -98,7 +98,7 @@ class Series(object): return transform_cls return register - def build(self, cache): + def build(self, cache, **kwargs): """Returns a Tensor.""" raise NotImplementedError() @@ -122,7 +122,7 @@ class PredefinedSeries(Series): def required_base_features(self): return {self.name: self.feature_spec} - def build(self, cache): + def build(self, cache, **kwargs): try: return cache[self.name] except KeyError: @@ -171,10 +171,11 @@ class TransformedSeries(Series): result.update(s.required_base_features) return result - def build(self, cache=None): + def build(self, cache=None, **kwargs): if cache is None: cache = {} - all_outputs = self._transform.build_transitive(self._input_series, cache) + all_outputs = self._transform.build_transitive( + self._input_series, cache, **kwargs) return getattr(all_outputs, self._output_name) def __repr__(self): diff --git a/tensorflow/contrib/learn/python/learn/dataframe/tensorflow_dataframe.py b/tensorflow/contrib/learn/python/learn/dataframe/tensorflow_dataframe.py index 45df3ac16d5..4b6091dc16c 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/tensorflow_dataframe.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/tensorflow_dataframe.py @@ -83,7 +83,8 @@ class TensorFlowDataFrame(df.DataFrame): graph=None, session=None, start_queues=True, - initialize_variables=True): + initialize_variables=True, + **kwargs): """Builds and runs the columns of the `DataFrame` and yields batches. This is a generator that yields a dictionary mapping column names to @@ -97,6 +98,7 @@ class TensorFlowDataFrame(df.DataFrame): start_queues: if true, queues will be started before running and halted after producting `n` batches. initialize_variables: if true, variables will be initialized. + **kwargs: Additional keyword arguments, unused here. Yields: A dictionary, mapping column names to the values resulting from running @@ -107,7 +109,7 @@ class TensorFlowDataFrame(df.DataFrame): with graph.as_default(): if session is None: session = sess.Session() - self_built = self.build() + self_built = self.build(**kwargs) keys = list(self_built.keys()) cols = list(self_built.values()) if initialize_variables: @@ -208,7 +210,7 @@ class TensorFlowDataFrame(df.DataFrame): @classmethod def _from_csv_base(cls, filepatterns, get_default_values, has_header, - column_names, num_epochs, num_threads, enqueue_size, + column_names, num_threads, enqueue_size, batch_size, queue_capacity, min_after_dequeue, shuffle, seed): """Create a `DataFrame` from CSV files. @@ -223,9 +225,6 @@ class TensorFlowDataFrame(df.DataFrame): each column, given the column names. has_header: whether or not the CSV files have headers. column_names: a list of names for the columns in the CSV files. - num_epochs: the number of times that the reader should loop through all - the file names. If set to `None`, then the reader will continue - indefinitely. num_threads: the number of readers that will work in parallel. enqueue_size: block size for each read operation. batch_size: desired batch size. @@ -265,7 +264,6 @@ class TensorFlowDataFrame(df.DataFrame): reader_kwargs=reader_kwargs, enqueue_size=enqueue_size, batch_size=batch_size, - num_epochs=num_epochs, queue_capacity=queue_capacity, shuffle=shuffle, min_after_dequeue=min_after_dequeue, @@ -287,7 +285,6 @@ class TensorFlowDataFrame(df.DataFrame): default_values, has_header=True, column_names=None, - num_epochs=None, num_threads=1, enqueue_size=None, batch_size=32, @@ -306,9 +303,6 @@ class TensorFlowDataFrame(df.DataFrame): default_values: a list of default values for each column. has_header: whether or not the CSV files have headers. column_names: a list of names for the columns in the CSV files. - num_epochs: the number of times that the reader should loop through all - the file names. If set to `None`, then the reader will continue - indefinitely. num_threads: the number of readers that will work in parallel. enqueue_size: block size for each read operation. batch_size: desired batch size. @@ -332,7 +326,7 @@ class TensorFlowDataFrame(df.DataFrame): return default_values return cls._from_csv_base(filepatterns, get_default_values, has_header, - column_names, num_epochs, num_threads, + column_names, num_threads, enqueue_size, batch_size, queue_capacity, min_after_dequeue, shuffle, seed) @@ -342,7 +336,6 @@ class TensorFlowDataFrame(df.DataFrame): feature_spec, has_header=True, column_names=None, - num_epochs=None, num_threads=1, enqueue_size=None, batch_size=32, @@ -362,9 +355,6 @@ class TensorFlowDataFrame(df.DataFrame): `VarLenFeature`. has_header: whether or not the CSV files have headers. column_names: a list of names for the columns in the CSV files. - num_epochs: the number of times that the reader should loop through all - the file names. If set to `None`, then the reader will continue - indefinitely. num_threads: the number of readers that will work in parallel. enqueue_size: block size for each read operation. batch_size: desired batch size. @@ -387,7 +377,7 @@ class TensorFlowDataFrame(df.DataFrame): return [_get_default_value(feature_spec[name]) for name in column_names] dataframe = cls._from_csv_base(filepatterns, get_default_values, has_header, - column_names, num_epochs, num_threads, + column_names, num_threads, enqueue_size, batch_size, queue_capacity, min_after_dequeue, shuffle, seed) @@ -405,7 +395,6 @@ class TensorFlowDataFrame(df.DataFrame): filepatterns, features, reader_cls=io_ops.TFRecordReader, - num_epochs=None, num_threads=1, enqueue_size=None, batch_size=32, @@ -421,9 +410,6 @@ class TensorFlowDataFrame(df.DataFrame): `FixedLenFeature`. reader_cls: a subclass of `tensorflow.ReaderBase` that will be used to read the `Example`s. - num_epochs: the number of times that the reader should loop through all - the file names. If set to `None`, then the reader will continue - indefinitely. num_threads: the number of readers that will work in parallel. enqueue_size: block size for each read operation. batch_size: desired batch size. @@ -454,7 +440,6 @@ class TensorFlowDataFrame(df.DataFrame): filenames, enqueue_size=enqueue_size, batch_size=batch_size, - num_epochs=num_epochs, queue_capacity=queue_capacity, shuffle=shuffle, min_after_dequeue=min_after_dequeue, diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transform.py b/tensorflow/contrib/learn/python/learn/dataframe/transform.py index 745d556f929..bbb97d2f290 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/transform.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/transform.py @@ -223,13 +223,14 @@ class Transform(object): # pylint: disable=not-callable return self.return_type(*output_series) - def build_transitive(self, input_series, cache=None): + def build_transitive(self, input_series, cache=None, **kwargs): """Apply this `Transform` to the provided `Series`, producing 'Tensor's. Args: input_series: None, a `Series`, or a list of input `Series`, acting as positional arguments. cache: a dict from Series reprs to Tensors. + **kwargs: Additional keyword arguments, unused here. Returns: A namedtuple of the output Tensors. @@ -244,7 +245,7 @@ class Transform(object): if len(input_series) != self.input_valency: raise ValueError("Expected %s input Series but received %s." % (self.input_valency, len(input_series))) - input_tensors = [series.build(cache) for series in input_series] + input_tensors = [series.build(cache, **kwargs) for series in input_series] # Note we cache each output individually, not just the entire output # tuple. This allows using the graph as the cache, since it can sensibly @@ -254,7 +255,7 @@ class Transform(object): output_tensors = [cache.get(output_repr) for output_repr in output_reprs] if None in output_tensors: - result = self._apply_transform(input_tensors) + result = self._apply_transform(input_tensors, **kwargs) for output_name, output_repr in zip(self.output_names, output_reprs): cache[output_repr] = getattr(result, output_name) else: @@ -264,12 +265,13 @@ class Transform(object): return result @abstractmethod - def _apply_transform(self, input_tensors): + def _apply_transform(self, input_tensors, **kwargs): """Applies the transformation to the `transform_input`. Args: - input_tensors: a list of Tensors representing the input to + input_tensors: a list of Tensors representing the input to the Transform. + **kwargs: Additional keyword arguments, unused here. Returns: A namedtuple of Tensors representing the transformed output. diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/batch.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/batch.py index 352a028ee33..cf1585634ca 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/batch.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/batch.py @@ -72,7 +72,7 @@ class Batch(AbstractBatchTransform): def name(self): return "Batch" - def _apply_transform(self, transform_input): + def _apply_transform(self, transform_input, **kwargs): batched = input_ops.batch(transform_input, batch_size=self.batch_size, num_threads=self.num_threads, @@ -121,7 +121,7 @@ class ShuffleBatch(AbstractBatchTransform): def seed(self): return self._seed - def _apply_transform(self, transform_input): + def _apply_transform(self, transform_input, **kwargs): batched = input_ops.shuffle_batch(transform_input, batch_size=self.batch_size, capacity=self.queue_capacity, diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/binary_transforms.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/binary_transforms.py index 7d46fb6d05e..78a21250c9c 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/binary_transforms.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/binary_transforms.py @@ -1,4 +1,4 @@ -# Copyright 2016 Google Inc. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -53,7 +53,7 @@ class SeriesBinaryTransform(transform.Transform): def _output_names(self): return "output", - def _apply_transform(self, input_tensors): + def _apply_transform(self, input_tensors, **kwargs): # TODO(jamieas): consider supporting sparse inputs. if isinstance(input_tensors[0], ops.SparseTensor) or isinstance( input_tensors[1], ops.SparseTensor): @@ -87,7 +87,7 @@ class ScalarBinaryTransform(transform.Transform): def _output_names(self): return "output", - def _apply_transform(self, input_tensors): + def _apply_transform(self, input_tensors, **kwargs): input_tensor = input_tensors[0] if isinstance(input_tensor, ops.SparseTensor): result = ops.SparseTensor(input_tensor.indices, diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/boolean_mask.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/boolean_mask.py index f572cf137f7..758de866e21 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/boolean_mask.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/boolean_mask.py @@ -77,12 +77,13 @@ class BooleanMask(transform.Transform): def _output_names(self): return "output", - def _apply_transform(self, input_tensors): + def _apply_transform(self, input_tensors, **kwargs): """Applies the transformation to the `transform_input`. Args: - input_tensors: a list of Tensors representing the input to + input_tensors: a list of Tensors representing the input to the Transform. + **kwargs: Additional keyword arguments, unused here. Returns: A namedtuple of Tensors representing the transformed output. diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/csv_parser.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/csv_parser.py index caa83f5a966..d78b5652d6e 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/csv_parser.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/csv_parser.py @@ -58,7 +58,7 @@ class CSVParser(transform.Transform): def default_values(self): return self._default_values - def _apply_transform(self, input_tensors): + def _apply_transform(self, input_tensors, **kwargs): default_consts = [constant_op.constant(d, shape=[1]) for d in self._default_values] parsed_values = parsing_ops.decode_csv(input_tensors[0], diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/densify.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/densify.py index 2f389153178..0f0c1a08911 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/densify.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/densify.py @@ -47,12 +47,13 @@ class Densify(transform.Transform): def _output_names(self): return "output", - def _apply_transform(self, input_tensors): + def _apply_transform(self, input_tensors, **kwargs): """Applies the transformation to the `transform_input`. Args: - input_tensors: a list of Tensors representing the input to + input_tensors: a list of Tensors representing the input to the Transform. + **kwargs: Additional keyword arguments, unused here. Returns: A namedtuple of Tensors representing the transformed output. diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/difference.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/difference.py index d4e6c10094b..b585fceeb63 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/difference.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/difference.py @@ -1,4 +1,4 @@ -# Copyright 2016 Google Inc. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -50,7 +50,7 @@ class Difference(transform.Transform): def _output_names(self): return "output", - def _apply_transform(self, input_tensors): + def _apply_transform(self, input_tensors, **kwargs): pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor), isinstance(input_tensors[1], ops.SparseTensor)) diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/example_parser.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/example_parser.py index e22ef740ed9..c2c5e0cbed5 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/example_parser.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/example_parser.py @@ -61,7 +61,7 @@ class ExampleParser(transform.Transform): def feature_definitions(self): return self._ordered_features - def _apply_transform(self, input_tensors): + def _apply_transform(self, input_tensors, **kwargs): parsed_values = parsing_ops.parse_example(input_tensors[0], features=self._ordered_features) # pylint: disable=not-callable diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/in_memory_source.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/in_memory_source.py index 97453c30325..d96d53468a5 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/in_memory_source.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/in_memory_source.py @@ -89,7 +89,7 @@ class BaseInMemorySource(transform.Transform): def input_valency(self): return 0 - def _apply_transform(self, transform_input): + def _apply_transform(self, transform_input, **kwargs): queue = feeding_functions.enqueue_data(self.data, self.queue_capacity, self.shuffle, diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/reader_source.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/reader_source.py index 23556c40657..ddb2d321d1c 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/reader_source.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/reader_source.py @@ -32,7 +32,6 @@ class ReaderSource(transform.Transform): reader_kwargs=None, enqueue_size=None, batch_size=1, - num_epochs=None, queue_capacity=None, shuffle=False, min_after_dequeue=None, @@ -49,9 +48,6 @@ class ReaderSource(transform.Transform): is constructed. enqueue_size: block size for each read operation. batch_size: The desired batch size of output. Defaults to 1. - num_epochs: the number of times that the reader should loop through all - the file names. If set to `None`, then the reader will continue - indefinitely. queue_capacity: Capacity of the queue. Defaults to 10 * `batch_size`. shuffle: Whether records will be shuffled before returning. Defaults to false. @@ -73,7 +69,6 @@ class ReaderSource(transform.Transform): self._batch_size = batch_size self._queue_capacity = (batch_size * 10 if queue_capacity is None else queue_capacity) - self._num_epochs = num_epochs self._shuffle = shuffle self._min_after_dequeue = int(self.queue_capacity / 4 if min_after_dequeue is None else min_after_dequeue) @@ -100,10 +95,6 @@ class ReaderSource(transform.Transform): def batch_size(self): return self._batch_size - @transform.parameter - def num_epochs(self): - return self._num_epochs - @transform.parameter def queue_capacity(self): return self._queue_capacity @@ -136,11 +127,12 @@ class ReaderSource(transform.Transform): def _output_names(self): return ("index", "value") - def _apply_transform(self, transform_input): - filename_queue = input_ops.string_input_producer(self.work_units, - num_epochs=self.num_epochs, - shuffle=self.shuffle, - seed=self.seed) + def _apply_transform(self, transform_input, **kwargs): + filename_queue = input_ops.string_input_producer( + self.work_units, + num_epochs=kwargs.get("num_epochs"), + shuffle=self.shuffle, + seed=self.seed) reader_ops = [] for _ in range(self.num_threads): reader = self._reader_cls(**self._reader_kwargs) @@ -174,7 +166,6 @@ def TextFileSource(file_names, reader_kwargs=None, enqueue_size=1, batch_size=1, - num_epochs=None, queue_capacity=None, shuffle=False, min_after_dequeue=None, @@ -185,7 +176,6 @@ def TextFileSource(file_names, reader_kwargs=reader_kwargs, enqueue_size=enqueue_size, batch_size=batch_size, - num_epochs=num_epochs, queue_capacity=queue_capacity, shuffle=shuffle, min_after_dequeue=min_after_dequeue, @@ -197,7 +187,6 @@ def TFRecordSource(file_names, reader_kwargs=None, enqueue_size=1, batch_size=1, - num_epochs=None, queue_capacity=None, shuffle=False, min_after_dequeue=None, @@ -208,7 +197,6 @@ def TFRecordSource(file_names, reader_kwargs=reader_kwargs, enqueue_size=enqueue_size, batch_size=batch_size, - num_epochs=num_epochs, queue_capacity=queue_capacity, shuffle=shuffle, min_after_dequeue=min_after_dequeue, diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/sparsify.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/sparsify.py index 552012ea330..f3447c5d940 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/sparsify.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/sparsify.py @@ -52,12 +52,13 @@ class Sparsify(transform.Transform): def _output_names(self): return "output", - def _apply_transform(self, input_tensors): + def _apply_transform(self, input_tensors, **kwargs): """Applies the transformation to the `transform_input`. Args: - input_tensors: a list of Tensors representing the input to + input_tensors: a list of Tensors representing the input to the Transform. + **kwargs: Additional keyword arguments, unused here. Returns: A namedtuple of Tensors representing the transformed output. diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/sum.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/sum.py index 6b04166e09c..878b08f4b0a 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/sum.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/sum.py @@ -1,4 +1,4 @@ -# Copyright 2016 Google Inc. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -44,7 +44,7 @@ class Sum(transform.Transform): def _output_names(self): return "output", - def _apply_transform(self, input_tensors): + def _apply_transform(self, input_tensors, **kwargs): pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor), isinstance(input_tensors[1], ops.SparseTensor)) diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/unary_transforms.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/unary_transforms.py index 3fd8c2a6a90..058ce1ed248 100644 --- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/unary_transforms.py +++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/unary_transforms.py @@ -1,4 +1,4 @@ -# Copyright 2016 Google Inc. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -78,7 +78,7 @@ def register_unary_op(registered_name, operation): def _output_names(self): return "output" - def _apply_transform(self, input_tensors): + def _apply_transform(self, input_tensors, **kwargs): input_tensor = input_tensors[0] if isinstance(input_tensor, ops.SparseTensor): result = ops.SparseTensor(input_tensor.indices, diff --git a/tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py b/tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py index 14e283cb791..7e233f33849 100644 --- a/tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py @@ -208,10 +208,9 @@ class TensorFlowDataFrameTestCase(tf.test.TestCase): tensorflow_df = df.TensorFlowDataFrame.from_csv( [data_path], batch_size=batch_size, - num_epochs=num_epochs, shuffle=False, default_values=default_values) - actual_num_batches = len(list(tensorflow_df.run())) + actual_num_batches = len(list(tensorflow_df.run(num_epochs=num_epochs))) self.assertEqual(expected_num_batches, actual_num_batches) def testFromCSVWithFeatureSpec(self):