Add kwargs to Transform.apply for num_epochs

Change: 128989804
This commit is contained in:
David Soergel 2016-08-01 07:57:35 -08:00 committed by TensorFlower Gardener
parent 52c0418614
commit 48e869f0e3
18 changed files with 57 additions and 75 deletions
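In short: `num_epochs` stops being a construction-time parameter of the CSV/TFRecord sources and becomes a run-time keyword argument, threaded through `DataFrame.build`, `Series.build`, and `Transform.build_transitive` via `**kwargs` until `ReaderSource._apply_transform` consumes it. A hypothetical before/after usage, paraphrasing the test change at the bottom of this diff (not the exact API surface):

# Before this commit: the epoch count was fixed when the DataFrame
# was constructed.
#   df = TensorFlowDataFrame.from_csv([data_path], num_epochs=5,
#                                     batch_size=32,
#                                     default_values=default_values)
#   batches = list(df.run())
#
# After: the epoch count is chosen per run() call and forwarded
# through **kwargs down to the reader's filename queue.
#   df = TensorFlowDataFrame.from_csv([data_path], batch_size=32,
#                                     default_values=default_values)
#   batches = list(df.run(num_epochs=5))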

View File

@ -117,10 +117,11 @@ class DataFrame(object):
value = [value]
self.assign(**dict(zip(key, value)))
def build(self):
def build(self, **kwargs):
# We do not allow passing a cache here, because that would encourage
# working around the rule that DataFrames cannot be expected to be
# synced with each other (e.g., they shuffle independently).
cache = {}
tensors = {name: c.build(cache) for name, c in self._columns.items()}
tensors = {name: c.build(cache, **kwargs)
for name, c in self._columns.items()}
return tensors
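A minimal pure-Python sketch of the threading pattern this hunk introduces; `Column` and `MiniDataFrame` are hypothetical stand-ins for the real `Series`/`DataFrame` machinery:

class Column(object):
    def __init__(self, fn):
        self._fn = fn

    def build(self, cache, **kwargs):
        # Forward kwargs untouched; only a transform that actually needs
        # them (e.g. a reader looking for num_epochs) inspects them.
        return self._fn(**kwargs)


class MiniDataFrame(object):
    def __init__(self, columns):
        self._columns = columns

    def build(self, **kwargs):
        cache = {}  # fresh per build: DataFrames must not share caches
        return {name: c.build(cache, **kwargs)
                for name, c in self._columns.items()}


df = MiniDataFrame({"x": Column(lambda **kw: kw.get("num_epochs"))})
assert df.build(num_epochs=3) == {"x": 3}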

View File

@ -91,7 +91,8 @@ def _build_alternate_universe(
def to_feature_columns_and_input_fn(dataframe,
base_input_keys_with_defaults,
feature_keys,
target_keys=None):
target_keys=None,
**kwargs):
"""Build a list of FeatureColumns and an input_fn for use with Estimator.
Args:
@ -103,6 +104,7 @@ def to_feature_columns_and_input_fn(dataframe,
These may include base features and/or derived features.
target_keys: the names of columns to be used as targets. None is
acceptable for unsupervised learning.
**kwargs: Additional keyword arguments, unused here.
Returns:
A tuple of two elements:
@ -155,10 +157,11 @@ def to_feature_columns_and_input_fn(dataframe,
# Build an input_fn suitable for use with Estimator.
def input_fn():
"""An input_fn() for feeding the given set of DataFrameColumns."""
# It's important to build all the tensors together in one DataFrame.
# If we did df.select() for both key sets and then build those, the two
# resulting DataFrames would be shuffled independently.
tensors = limited_dataframe.build()
tensors = limited_dataframe.build(**kwargs)
base_input_features = {key: tensors[key] for key in base_input_keys}
targets = {key: tensors[key] for key in target_keys}
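`input_fn` must take no arguments (that is the contract with `Estimator`), so any run-time options have to be captured from the enclosing scope. A sketch of that closure pattern, with a hypothetical helper name:

def make_input_fn(dataframe, **kwargs):
    def input_fn():
        # kwargs (e.g. num_epochs) were frozen in when make_input_fn
        # was called; Estimator invokes input_fn() with no arguments.
        return dataframe.build(**kwargs)
    return input_fn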

View File

@ -98,7 +98,7 @@ class Series(object):
return transform_cls
return register
def build(self, cache):
def build(self, cache, **kwargs):
"""Returns a Tensor."""
raise NotImplementedError()
@ -122,7 +122,7 @@ class PredefinedSeries(Series):
def required_base_features(self):
return {self.name: self.feature_spec}
def build(self, cache):
def build(self, cache, **kwargs):
try:
return cache[self.name]
except KeyError:
@ -171,10 +171,11 @@ class TransformedSeries(Series):
result.update(s.required_base_features)
return result
def build(self, cache=None):
def build(self, cache=None, **kwargs):
if cache is None:
cache = {}
all_outputs = self._transform.build_transitive(self._input_series, cache)
all_outputs = self._transform.build_transitive(
self._input_series, cache, **kwargs)
return getattr(all_outputs, self._output_name)
def __repr__(self):

View File

@ -83,7 +83,8 @@ class TensorFlowDataFrame(df.DataFrame):
graph=None,
session=None,
start_queues=True,
initialize_variables=True):
initialize_variables=True,
**kwargs):
"""Builds and runs the columns of the `DataFrame` and yields batches.
This is a generator that yields a dictionary mapping column names to
@ -97,6 +98,7 @@ class TensorFlowDataFrame(df.DataFrame):
start_queues: if true, queues will be started before running and halted
after producing `n` batches.
initialize_variables: if true, variables will be initialized.
**kwargs: Additional keyword arguments, unused here.
Yields:
A dictionary, mapping column names to the values resulting from running
@ -107,7 +109,7 @@ class TensorFlowDataFrame(df.DataFrame):
with graph.as_default():
if session is None:
session = sess.Session()
self_built = self.build()
self_built = self.build(**kwargs)
keys = list(self_built.keys())
cols = list(self_built.values())
if initialize_variables:
@ -208,7 +210,7 @@ class TensorFlowDataFrame(df.DataFrame):
@classmethod
def _from_csv_base(cls, filepatterns, get_default_values, has_header,
column_names, num_epochs, num_threads, enqueue_size,
column_names, num_threads, enqueue_size,
batch_size, queue_capacity, min_after_dequeue, shuffle,
seed):
"""Create a `DataFrame` from CSV files.
@ -223,9 +225,6 @@ class TensorFlowDataFrame(df.DataFrame):
each column, given the column names.
has_header: whether or not the CSV files have headers.
column_names: a list of names for the columns in the CSV files.
num_epochs: the number of times that the reader should loop through all
the file names. If set to `None`, then the reader will continue
indefinitely.
num_threads: the number of readers that will work in parallel.
enqueue_size: block size for each read operation.
batch_size: desired batch size.
@ -265,7 +264,6 @@ class TensorFlowDataFrame(df.DataFrame):
reader_kwargs=reader_kwargs,
enqueue_size=enqueue_size,
batch_size=batch_size,
num_epochs=num_epochs,
queue_capacity=queue_capacity,
shuffle=shuffle,
min_after_dequeue=min_after_dequeue,
@ -287,7 +285,6 @@ class TensorFlowDataFrame(df.DataFrame):
default_values,
has_header=True,
column_names=None,
num_epochs=None,
num_threads=1,
enqueue_size=None,
batch_size=32,
@ -306,9 +303,6 @@ class TensorFlowDataFrame(df.DataFrame):
default_values: a list of default values for each column.
has_header: whether or not the CSV files have headers.
column_names: a list of names for the columns in the CSV files.
num_epochs: the number of times that the reader should loop through all
the file names. If set to `None`, then the reader will continue
indefinitely.
num_threads: the number of readers that will work in parallel.
enqueue_size: block size for each read operation.
batch_size: desired batch size.
@ -332,7 +326,7 @@ class TensorFlowDataFrame(df.DataFrame):
return default_values
return cls._from_csv_base(filepatterns, get_default_values, has_header,
column_names, num_epochs, num_threads,
column_names, num_threads,
enqueue_size, batch_size, queue_capacity,
min_after_dequeue, shuffle, seed)
@ -342,7 +336,6 @@ class TensorFlowDataFrame(df.DataFrame):
feature_spec,
has_header=True,
column_names=None,
num_epochs=None,
num_threads=1,
enqueue_size=None,
batch_size=32,
@ -362,9 +355,6 @@ class TensorFlowDataFrame(df.DataFrame):
`VarLenFeature`.
has_header: whether or not the CSV files have headers.
column_names: a list of names for the columns in the CSV files.
num_epochs: the number of times that the reader should loop through all
the file names. If set to `None`, then the reader will continue
indefinitely.
num_threads: the number of readers that will work in parallel.
enqueue_size: block size for each read operation.
batch_size: desired batch size.
@ -387,7 +377,7 @@ class TensorFlowDataFrame(df.DataFrame):
return [_get_default_value(feature_spec[name]) for name in column_names]
dataframe = cls._from_csv_base(filepatterns, get_default_values, has_header,
column_names, num_epochs, num_threads,
column_names, num_threads,
enqueue_size, batch_size, queue_capacity,
min_after_dequeue, shuffle, seed)
@ -405,7 +395,6 @@ class TensorFlowDataFrame(df.DataFrame):
filepatterns,
features,
reader_cls=io_ops.TFRecordReader,
num_epochs=None,
num_threads=1,
enqueue_size=None,
batch_size=32,
@ -421,9 +410,6 @@ class TensorFlowDataFrame(df.DataFrame):
`FixedLenFeature`.
reader_cls: a subclass of `tensorflow.ReaderBase` that will be used to
read the `Example`s.
num_epochs: the number of times that the reader should loop through all
the file names. If set to `None`, then the reader will continue
indefinitely.
num_threads: the number of readers that will work in parallel.
enqueue_size: block size for each read operation.
batch_size: desired batch size.
@ -454,7 +440,6 @@ class TensorFlowDataFrame(df.DataFrame):
filenames,
enqueue_size=enqueue_size,
batch_size=batch_size,
num_epochs=num_epochs,
queue_capacity=queue_capacity,
shuffle=shuffle,
min_after_dequeue=min_after_dequeue,

View File

@ -223,13 +223,14 @@ class Transform(object):
# pylint: disable=not-callable
return self.return_type(*output_series)
def build_transitive(self, input_series, cache=None):
def build_transitive(self, input_series, cache=None, **kwargs):
"""Apply this `Transform` to the provided `Series`, producing 'Tensor's.
Args:
input_series: None, a `Series`, or a list of input `Series`, acting as
positional arguments.
cache: a dict from Series reprs to Tensors.
**kwargs: Additional keyword arguments, unused here.
Returns:
A namedtuple of the output Tensors.
@ -244,7 +245,7 @@ class Transform(object):
if len(input_series) != self.input_valency:
raise ValueError("Expected %s input Series but received %s." %
(self.input_valency, len(input_series)))
input_tensors = [series.build(cache) for series in input_series]
input_tensors = [series.build(cache, **kwargs) for series in input_series]
# Note we cache each output individually, not just the entire output
# tuple. This allows using the graph as the cache, since it can sensibly
@ -254,7 +255,7 @@ class Transform(object):
output_tensors = [cache.get(output_repr) for output_repr in output_reprs]
if None in output_tensors:
result = self._apply_transform(input_tensors)
result = self._apply_transform(input_tensors, **kwargs)
for output_name, output_repr in zip(self.output_names, output_reprs):
cache[output_repr] = getattr(result, output_name)
else:
@ -264,12 +265,13 @@ class Transform(object):
return result
@abstractmethod
def _apply_transform(self, input_tensors):
def _apply_transform(self, input_tensors, **kwargs):
"""Applies the transformation to the `transform_input`.
Args:
input_tensors: a list of Tensors representing the input to
the Transform.
**kwargs: Additional keyword arguments, unused here.
Returns:
A namedtuple of Tensors representing the transformed output.
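A self-contained sketch of the cache-then-apply flow in `build_transitive`, with kwargs forwarded into `_apply_transform`. The classes and reprs here are hypothetical simplifications; the real code caches each output Tensor under its output Series' repr, which is why the graph itself could sensibly serve as the cache:

import collections

Output = collections.namedtuple("Output", ["output"])

class MiniTransform(object):
    output_names = ("output",)
    return_type = Output

    def _apply_transform(self, input_tensors, **kwargs):
        return Output(sum(input_tensors) + kwargs.get("offset", 0))

    def build_transitive(self, input_series, cache=None, **kwargs):
        if cache is None:
            cache = {}
        # The real code builds each input Series recursively, forwarding
        # kwargs downward; plain values stand in for Tensors here.
        input_tensors = list(input_series)
        output_reprs = ["%r.%s" % (self, name) for name in self.output_names]
        outputs = [cache.get(r) for r in output_reprs]
        if None in outputs:
            result = self._apply_transform(input_tensors, **kwargs)
            for name, r in zip(self.output_names, output_reprs):
                cache[r] = getattr(result, name)
        else:
            result = self.return_type(*outputs)
        return result

cache = {}
t = MiniTransform()
print(t.build_transitive([1, 2], cache, offset=10).output)  # 13 (applied)
print(t.build_transitive([1, 2], cache, offset=10).output)  # 13 (cached)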

View File

@ -72,7 +72,7 @@ class Batch(AbstractBatchTransform):
def name(self):
return "Batch"
def _apply_transform(self, transform_input):
def _apply_transform(self, transform_input, **kwargs):
batched = input_ops.batch(transform_input,
batch_size=self.batch_size,
num_threads=self.num_threads,
@ -121,7 +121,7 @@ class ShuffleBatch(AbstractBatchTransform):
def seed(self):
return self._seed
def _apply_transform(self, transform_input):
def _apply_transform(self, transform_input, **kwargs):
batched = input_ops.shuffle_batch(transform_input,
batch_size=self.batch_size,
capacity=self.queue_capacity,

View File

@ -1,4 +1,4 @@
# Copyright 2016 Google Inc. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -53,7 +53,7 @@ class SeriesBinaryTransform(transform.Transform):
def _output_names(self):
return "output",
def _apply_transform(self, input_tensors):
def _apply_transform(self, input_tensors, **kwargs):
# TODO(jamieas): consider supporting sparse inputs.
if isinstance(input_tensors[0], ops.SparseTensor) or isinstance(
input_tensors[1], ops.SparseTensor):
@ -87,7 +87,7 @@ class ScalarBinaryTransform(transform.Transform):
def _output_names(self):
return "output",
def _apply_transform(self, input_tensors):
def _apply_transform(self, input_tensors, **kwargs):
input_tensor = input_tensors[0]
if isinstance(input_tensor, ops.SparseTensor):
result = ops.SparseTensor(input_tensor.indices,

View File

@ -77,12 +77,13 @@ class BooleanMask(transform.Transform):
def _output_names(self):
return "output",
def _apply_transform(self, input_tensors):
def _apply_transform(self, input_tensors, **kwargs):
"""Applies the transformation to the `transform_input`.
Args:
input_tensors: a list of Tensors representing the input to
the Transform.
**kwargs: Additional keyword arguments, unused here.
Returns:
A namedtuple of Tensors representing the transformed output.

View File

@ -58,7 +58,7 @@ class CSVParser(transform.Transform):
def default_values(self):
return self._default_values
def _apply_transform(self, input_tensors):
def _apply_transform(self, input_tensors, **kwargs):
default_consts = [constant_op.constant(d, shape=[1])
for d in self._default_values]
parsed_values = parsing_ops.decode_csv(input_tensors[0],

View File

@ -47,12 +47,13 @@ class Densify(transform.Transform):
def _output_names(self):
return "output",
def _apply_transform(self, input_tensors):
def _apply_transform(self, input_tensors, **kwargs):
"""Applies the transformation to the `transform_input`.
Args:
input_tensors: a list of Tensors representing the input to
the Transform.
**kwargs: Additional keyword arguments, unused here.
Returns:
A namedtuple of Tensors representing the transformed output.

View File

@ -1,4 +1,4 @@
# Copyright 2016 Google Inc. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -50,7 +50,7 @@ class Difference(transform.Transform):
def _output_names(self):
return "output",
def _apply_transform(self, input_tensors):
def _apply_transform(self, input_tensors, **kwargs):
pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
isinstance(input_tensors[1], ops.SparseTensor))

View File

@ -61,7 +61,7 @@ class ExampleParser(transform.Transform):
def feature_definitions(self):
return self._ordered_features
def _apply_transform(self, input_tensors):
def _apply_transform(self, input_tensors, **kwargs):
parsed_values = parsing_ops.parse_example(input_tensors[0],
features=self._ordered_features)
# pylint: disable=not-callable

View File

@ -89,7 +89,7 @@ class BaseInMemorySource(transform.Transform):
def input_valency(self):
return 0
def _apply_transform(self, transform_input):
def _apply_transform(self, transform_input, **kwargs):
queue = feeding_functions.enqueue_data(self.data,
self.queue_capacity,
self.shuffle,

View File

@ -32,7 +32,6 @@ class ReaderSource(transform.Transform):
reader_kwargs=None,
enqueue_size=None,
batch_size=1,
num_epochs=None,
queue_capacity=None,
shuffle=False,
min_after_dequeue=None,
@ -49,9 +48,6 @@ class ReaderSource(transform.Transform):
is constructed.
enqueue_size: block size for each read operation.
batch_size: The desired batch size of output. Defaults to 1.
num_epochs: the number of times that the reader should loop through all
the file names. If set to `None`, then the reader will continue
indefinitely.
queue_capacity: Capacity of the queue. Defaults to 10 * `batch_size`.
shuffle: Whether records will be shuffled before returning. Defaults to
false.
@ -73,7 +69,6 @@ class ReaderSource(transform.Transform):
self._batch_size = batch_size
self._queue_capacity = (batch_size * 10 if queue_capacity is None else
queue_capacity)
self._num_epochs = num_epochs
self._shuffle = shuffle
self._min_after_dequeue = int(self.queue_capacity / 4 if min_after_dequeue
is None else min_after_dequeue)
@ -100,10 +95,6 @@ class ReaderSource(transform.Transform):
def batch_size(self):
return self._batch_size
@transform.parameter
def num_epochs(self):
return self._num_epochs
@transform.parameter
def queue_capacity(self):
return self._queue_capacity
@ -136,11 +127,12 @@ class ReaderSource(transform.Transform):
def _output_names(self):
return ("index", "value")
def _apply_transform(self, transform_input):
filename_queue = input_ops.string_input_producer(self.work_units,
num_epochs=self.num_epochs,
shuffle=self.shuffle,
seed=self.seed)
def _apply_transform(self, transform_input, **kwargs):
filename_queue = input_ops.string_input_producer(
self.work_units,
num_epochs=kwargs.get("num_epochs"),
shuffle=self.shuffle,
seed=self.seed)
reader_ops = []
for _ in range(self.num_threads):
reader = self._reader_cls(**self._reader_kwargs)
@ -174,7 +166,6 @@ def TextFileSource(file_names,
reader_kwargs=None,
enqueue_size=1,
batch_size=1,
num_epochs=None,
queue_capacity=None,
shuffle=False,
min_after_dequeue=None,
@ -185,7 +176,6 @@ def TextFileSource(file_names,
reader_kwargs=reader_kwargs,
enqueue_size=enqueue_size,
batch_size=batch_size,
num_epochs=num_epochs,
queue_capacity=queue_capacity,
shuffle=shuffle,
min_after_dequeue=min_after_dequeue,
@ -197,7 +187,6 @@ def TFRecordSource(file_names,
reader_kwargs=None,
enqueue_size=1,
batch_size=1,
num_epochs=None,
queue_capacity=None,
shuffle=False,
min_after_dequeue=None,
@ -208,7 +197,6 @@ def TFRecordSource(file_names,
reader_kwargs=reader_kwargs,
enqueue_size=enqueue_size,
batch_size=batch_size,
num_epochs=num_epochs,
queue_capacity=queue_capacity,
shuffle=shuffle,
min_after_dequeue=min_after_dequeue,

View File

@ -52,12 +52,13 @@ class Sparsify(transform.Transform):
def _output_names(self):
return "output",
def _apply_transform(self, input_tensors):
def _apply_transform(self, input_tensors, **kwargs):
"""Applies the transformation to the `transform_input`.
Args:
input_tensors: a list of Tensors representing the input to
the Transform.
**kwargs: Additional keyword arguments, unused here.
Returns:
A namedtuple of Tensors representing the transformed output.

View File

@ -1,4 +1,4 @@
# Copyright 2016 Google Inc. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -44,7 +44,7 @@ class Sum(transform.Transform):
def _output_names(self):
return "output",
def _apply_transform(self, input_tensors):
def _apply_transform(self, input_tensors, **kwargs):
pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
isinstance(input_tensors[1], ops.SparseTensor))

View File

@ -1,4 +1,4 @@
# Copyright 2016 Google Inc. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -78,7 +78,7 @@ def register_unary_op(registered_name, operation):
def _output_names(self):
return "output"
def _apply_transform(self, input_tensors):
def _apply_transform(self, input_tensors, **kwargs):
input_tensor = input_tensors[0]
if isinstance(input_tensor, ops.SparseTensor):
result = ops.SparseTensor(input_tensor.indices,

View File

@ -208,10 +208,9 @@ class TensorFlowDataFrameTestCase(tf.test.TestCase):
tensorflow_df = df.TensorFlowDataFrame.from_csv(
[data_path],
batch_size=batch_size,
num_epochs=num_epochs,
shuffle=False,
default_values=default_values)
actual_num_batches = len(list(tensorflow_df.run()))
actual_num_batches = len(list(tensorflow_df.run(num_epochs=num_epochs)))
self.assertEqual(expected_num_batches, actual_num_batches)
def testFromCSVWithFeatureSpec(self):