Add kwargs to Transform.apply for num_epochs

Change: 128989804
This commit is contained in:
David Soergel 2016-08-01 07:57:35 -08:00 committed by TensorFlower Gardener
parent 52c0418614
commit 48e869f0e3
18 changed files with 57 additions and 75 deletions

View File

@ -117,10 +117,11 @@ class DataFrame(object):
value = [value] value = [value]
self.assign(**dict(zip(key, value))) self.assign(**dict(zip(key, value)))
def build(self): def build(self, **kwargs):
# We do not allow passing a cache here, because that would encourage # We do not allow passing a cache here, because that would encourage
# working around the rule that DataFrames cannot be expected to be # working around the rule that DataFrames cannot be expected to be
# synced with each other (e.g., they shuffle independently). # synced with each other (e.g., they shuffle independently).
cache = {} cache = {}
tensors = {name: c.build(cache) for name, c in self._columns.items()} tensors = {name: c.build(cache, **kwargs)
for name, c in self._columns.items()}
return tensors return tensors

View File

@ -91,7 +91,8 @@ def _build_alternate_universe(
def to_feature_columns_and_input_fn(dataframe, def to_feature_columns_and_input_fn(dataframe,
base_input_keys_with_defaults, base_input_keys_with_defaults,
feature_keys, feature_keys,
target_keys=None): target_keys=None,
**kwargs):
"""Build a list of FeatureColumns and an input_fn for use with Estimator. """Build a list of FeatureColumns and an input_fn for use with Estimator.
Args: Args:
@ -103,6 +104,7 @@ def to_feature_columns_and_input_fn(dataframe,
These may include base features and/or derived features. These may include base features and/or derived features.
target_keys: the names of columns to be used as targets. None is target_keys: the names of columns to be used as targets. None is
acceptable for unsupervised learning. acceptable for unsupervised learning.
**kwargs: Additional keyword arguments, forwarded to `DataFrame.build()` when the returned `input_fn` is called.
Returns: Returns:
A tuple of two elements: A tuple of two elements:
@ -155,10 +157,11 @@ def to_feature_columns_and_input_fn(dataframe,
# Build an input_fn suitable for use with Estimator. # Build an input_fn suitable for use with Estimator.
def input_fn(): def input_fn():
"""An input_fn() for feeding the given set of DataFrameColumns."""
# It's important to build all the tensors together in one DataFrame. # It's important to build all the tensors together in one DataFrame.
# If we did df.select() for both key sets and then build those, the two # If we did df.select() for both key sets and then build those, the two
# resulting DataFrames would be shuffled independently. # resulting DataFrames would be shuffled independently.
tensors = limited_dataframe.build() tensors = limited_dataframe.build(**kwargs)
base_input_features = {key: tensors[key] for key in base_input_keys} base_input_features = {key: tensors[key] for key in base_input_keys}
targets = {key: tensors[key] for key in target_keys} targets = {key: tensors[key] for key in target_keys}

View File

@ -98,7 +98,7 @@ class Series(object):
return transform_cls return transform_cls
return register return register
def build(self, cache): def build(self, cache, **kwargs):
"""Returns a Tensor.""" """Returns a Tensor."""
raise NotImplementedError() raise NotImplementedError()
@ -122,7 +122,7 @@ class PredefinedSeries(Series):
def required_base_features(self): def required_base_features(self):
return {self.name: self.feature_spec} return {self.name: self.feature_spec}
def build(self, cache): def build(self, cache, **kwargs):
try: try:
return cache[self.name] return cache[self.name]
except KeyError: except KeyError:
@ -171,10 +171,11 @@ class TransformedSeries(Series):
result.update(s.required_base_features) result.update(s.required_base_features)
return result return result
def build(self, cache=None): def build(self, cache=None, **kwargs):
if cache is None: if cache is None:
cache = {} cache = {}
all_outputs = self._transform.build_transitive(self._input_series, cache) all_outputs = self._transform.build_transitive(
self._input_series, cache, **kwargs)
return getattr(all_outputs, self._output_name) return getattr(all_outputs, self._output_name)
def __repr__(self): def __repr__(self):

View File

@ -83,7 +83,8 @@ class TensorFlowDataFrame(df.DataFrame):
graph=None, graph=None,
session=None, session=None,
start_queues=True, start_queues=True,
initialize_variables=True): initialize_variables=True,
**kwargs):
"""Builds and runs the columns of the `DataFrame` and yields batches. """Builds and runs the columns of the `DataFrame` and yields batches.
This is a generator that yields a dictionary mapping column names to This is a generator that yields a dictionary mapping column names to
@ -97,6 +98,7 @@ class TensorFlowDataFrame(df.DataFrame):
start_queues: if true, queues will be started before running and halted start_queues: if true, queues will be started before running and halted
after producing `n` batches. after producing `n` batches.
initialize_variables: if true, variables will be initialized. initialize_variables: if true, variables will be initialized.
**kwargs: Additional keyword arguments, passed through to `self.build()`.
Yields: Yields:
A dictionary, mapping column names to the values resulting from running A dictionary, mapping column names to the values resulting from running
@ -107,7 +109,7 @@ class TensorFlowDataFrame(df.DataFrame):
with graph.as_default(): with graph.as_default():
if session is None: if session is None:
session = sess.Session() session = sess.Session()
self_built = self.build() self_built = self.build(**kwargs)
keys = list(self_built.keys()) keys = list(self_built.keys())
cols = list(self_built.values()) cols = list(self_built.values())
if initialize_variables: if initialize_variables:
@ -208,7 +210,7 @@ class TensorFlowDataFrame(df.DataFrame):
@classmethod @classmethod
def _from_csv_base(cls, filepatterns, get_default_values, has_header, def _from_csv_base(cls, filepatterns, get_default_values, has_header,
column_names, num_epochs, num_threads, enqueue_size, column_names, num_threads, enqueue_size,
batch_size, queue_capacity, min_after_dequeue, shuffle, batch_size, queue_capacity, min_after_dequeue, shuffle,
seed): seed):
"""Create a `DataFrame` from CSV files. """Create a `DataFrame` from CSV files.
@ -223,9 +225,6 @@ class TensorFlowDataFrame(df.DataFrame):
each column, given the column names. each column, given the column names.
has_header: whether or not the CSV files have headers. has_header: whether or not the CSV files have headers.
column_names: a list of names for the columns in the CSV files. column_names: a list of names for the columns in the CSV files.
num_epochs: the number of times that the reader should loop through all
the file names. If set to `None`, then the reader will continue
indefinitely.
num_threads: the number of readers that will work in parallel. num_threads: the number of readers that will work in parallel.
enqueue_size: block size for each read operation. enqueue_size: block size for each read operation.
batch_size: desired batch size. batch_size: desired batch size.
@ -265,7 +264,6 @@ class TensorFlowDataFrame(df.DataFrame):
reader_kwargs=reader_kwargs, reader_kwargs=reader_kwargs,
enqueue_size=enqueue_size, enqueue_size=enqueue_size,
batch_size=batch_size, batch_size=batch_size,
num_epochs=num_epochs,
queue_capacity=queue_capacity, queue_capacity=queue_capacity,
shuffle=shuffle, shuffle=shuffle,
min_after_dequeue=min_after_dequeue, min_after_dequeue=min_after_dequeue,
@ -287,7 +285,6 @@ class TensorFlowDataFrame(df.DataFrame):
default_values, default_values,
has_header=True, has_header=True,
column_names=None, column_names=None,
num_epochs=None,
num_threads=1, num_threads=1,
enqueue_size=None, enqueue_size=None,
batch_size=32, batch_size=32,
@ -306,9 +303,6 @@ class TensorFlowDataFrame(df.DataFrame):
default_values: a list of default values for each column. default_values: a list of default values for each column.
has_header: whether or not the CSV files have headers. has_header: whether or not the CSV files have headers.
column_names: a list of names for the columns in the CSV files. column_names: a list of names for the columns in the CSV files.
num_epochs: the number of times that the reader should loop through all
the file names. If set to `None`, then the reader will continue
indefinitely.
num_threads: the number of readers that will work in parallel. num_threads: the number of readers that will work in parallel.
enqueue_size: block size for each read operation. enqueue_size: block size for each read operation.
batch_size: desired batch size. batch_size: desired batch size.
@ -332,7 +326,7 @@ class TensorFlowDataFrame(df.DataFrame):
return default_values return default_values
return cls._from_csv_base(filepatterns, get_default_values, has_header, return cls._from_csv_base(filepatterns, get_default_values, has_header,
column_names, num_epochs, num_threads, column_names, num_threads,
enqueue_size, batch_size, queue_capacity, enqueue_size, batch_size, queue_capacity,
min_after_dequeue, shuffle, seed) min_after_dequeue, shuffle, seed)
@ -342,7 +336,6 @@ class TensorFlowDataFrame(df.DataFrame):
feature_spec, feature_spec,
has_header=True, has_header=True,
column_names=None, column_names=None,
num_epochs=None,
num_threads=1, num_threads=1,
enqueue_size=None, enqueue_size=None,
batch_size=32, batch_size=32,
@ -362,9 +355,6 @@ class TensorFlowDataFrame(df.DataFrame):
`VarLenFeature`. `VarLenFeature`.
has_header: whether or not the CSV files have headers. has_header: whether or not the CSV files have headers.
column_names: a list of names for the columns in the CSV files. column_names: a list of names for the columns in the CSV files.
num_epochs: the number of times that the reader should loop through all
the file names. If set to `None`, then the reader will continue
indefinitely.
num_threads: the number of readers that will work in parallel. num_threads: the number of readers that will work in parallel.
enqueue_size: block size for each read operation. enqueue_size: block size for each read operation.
batch_size: desired batch size. batch_size: desired batch size.
@ -387,7 +377,7 @@ class TensorFlowDataFrame(df.DataFrame):
return [_get_default_value(feature_spec[name]) for name in column_names] return [_get_default_value(feature_spec[name]) for name in column_names]
dataframe = cls._from_csv_base(filepatterns, get_default_values, has_header, dataframe = cls._from_csv_base(filepatterns, get_default_values, has_header,
column_names, num_epochs, num_threads, column_names, num_threads,
enqueue_size, batch_size, queue_capacity, enqueue_size, batch_size, queue_capacity,
min_after_dequeue, shuffle, seed) min_after_dequeue, shuffle, seed)
@ -405,7 +395,6 @@ class TensorFlowDataFrame(df.DataFrame):
filepatterns, filepatterns,
features, features,
reader_cls=io_ops.TFRecordReader, reader_cls=io_ops.TFRecordReader,
num_epochs=None,
num_threads=1, num_threads=1,
enqueue_size=None, enqueue_size=None,
batch_size=32, batch_size=32,
@ -421,9 +410,6 @@ class TensorFlowDataFrame(df.DataFrame):
`FixedLenFeature`. `FixedLenFeature`.
reader_cls: a subclass of `tensorflow.ReaderBase` that will be used to reader_cls: a subclass of `tensorflow.ReaderBase` that will be used to
read the `Example`s. read the `Example`s.
num_epochs: the number of times that the reader should loop through all
the file names. If set to `None`, then the reader will continue
indefinitely.
num_threads: the number of readers that will work in parallel. num_threads: the number of readers that will work in parallel.
enqueue_size: block size for each read operation. enqueue_size: block size for each read operation.
batch_size: desired batch size. batch_size: desired batch size.
@ -454,7 +440,6 @@ class TensorFlowDataFrame(df.DataFrame):
filenames, filenames,
enqueue_size=enqueue_size, enqueue_size=enqueue_size,
batch_size=batch_size, batch_size=batch_size,
num_epochs=num_epochs,
queue_capacity=queue_capacity, queue_capacity=queue_capacity,
shuffle=shuffle, shuffle=shuffle,
min_after_dequeue=min_after_dequeue, min_after_dequeue=min_after_dequeue,

View File

@ -223,13 +223,14 @@ class Transform(object):
# pylint: disable=not-callable # pylint: disable=not-callable
return self.return_type(*output_series) return self.return_type(*output_series)
def build_transitive(self, input_series, cache=None): def build_transitive(self, input_series, cache=None, **kwargs):
"""Apply this `Transform` to the provided `Series`, producing 'Tensor's. """Apply this `Transform` to the provided `Series`, producing 'Tensor's.
Args: Args:
input_series: None, a `Series`, or a list of input `Series`, acting as input_series: None, a `Series`, or a list of input `Series`, acting as
positional arguments. positional arguments.
cache: a dict from Series reprs to Tensors. cache: a dict from Series reprs to Tensors.
**kwargs: Additional keyword arguments, forwarded to each input `Series.build()` and to `_apply_transform()`.
Returns: Returns:
A namedtuple of the output Tensors. A namedtuple of the output Tensors.
@ -244,7 +245,7 @@ class Transform(object):
if len(input_series) != self.input_valency: if len(input_series) != self.input_valency:
raise ValueError("Expected %s input Series but received %s." % raise ValueError("Expected %s input Series but received %s." %
(self.input_valency, len(input_series))) (self.input_valency, len(input_series)))
input_tensors = [series.build(cache) for series in input_series] input_tensors = [series.build(cache, **kwargs) for series in input_series]
# Note we cache each output individually, not just the entire output # Note we cache each output individually, not just the entire output
# tuple. This allows using the graph as the cache, since it can sensibly # tuple. This allows using the graph as the cache, since it can sensibly
@ -254,7 +255,7 @@ class Transform(object):
output_tensors = [cache.get(output_repr) for output_repr in output_reprs] output_tensors = [cache.get(output_repr) for output_repr in output_reprs]
if None in output_tensors: if None in output_tensors:
result = self._apply_transform(input_tensors) result = self._apply_transform(input_tensors, **kwargs)
for output_name, output_repr in zip(self.output_names, output_reprs): for output_name, output_repr in zip(self.output_names, output_reprs):
cache[output_repr] = getattr(result, output_name) cache[output_repr] = getattr(result, output_name)
else: else:
@ -264,12 +265,13 @@ class Transform(object):
return result return result
@abstractmethod @abstractmethod
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
"""Applies the transformation to the `transform_input`. """Applies the transformation to the `transform_input`.
Args: Args:
input_tensors: a list of Tensors representing the input to input_tensors: a list of Tensors representing the input to
the Transform. the Transform.
**kwargs: Additional keyword arguments, unused here.
Returns: Returns:
A namedtuple of Tensors representing the transformed output. A namedtuple of Tensors representing the transformed output.

View File

@ -72,7 +72,7 @@ class Batch(AbstractBatchTransform):
def name(self): def name(self):
return "Batch" return "Batch"
def _apply_transform(self, transform_input): def _apply_transform(self, transform_input, **kwargs):
batched = input_ops.batch(transform_input, batched = input_ops.batch(transform_input,
batch_size=self.batch_size, batch_size=self.batch_size,
num_threads=self.num_threads, num_threads=self.num_threads,
@ -121,7 +121,7 @@ class ShuffleBatch(AbstractBatchTransform):
def seed(self): def seed(self):
return self._seed return self._seed
def _apply_transform(self, transform_input): def _apply_transform(self, transform_input, **kwargs):
batched = input_ops.shuffle_batch(transform_input, batched = input_ops.shuffle_batch(transform_input,
batch_size=self.batch_size, batch_size=self.batch_size,
capacity=self.queue_capacity, capacity=self.queue_capacity,

View File

@ -1,4 +1,4 @@
# Copyright 2016 Google Inc. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -53,7 +53,7 @@ class SeriesBinaryTransform(transform.Transform):
def _output_names(self): def _output_names(self):
return "output", return "output",
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
# TODO(jamieas): consider supporting sparse inputs. # TODO(jamieas): consider supporting sparse inputs.
if isinstance(input_tensors[0], ops.SparseTensor) or isinstance( if isinstance(input_tensors[0], ops.SparseTensor) or isinstance(
input_tensors[1], ops.SparseTensor): input_tensors[1], ops.SparseTensor):
@ -87,7 +87,7 @@ class ScalarBinaryTransform(transform.Transform):
def _output_names(self): def _output_names(self):
return "output", return "output",
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
if isinstance(input_tensor, ops.SparseTensor): if isinstance(input_tensor, ops.SparseTensor):
result = ops.SparseTensor(input_tensor.indices, result = ops.SparseTensor(input_tensor.indices,

View File

@ -77,12 +77,13 @@ class BooleanMask(transform.Transform):
def _output_names(self): def _output_names(self):
return "output", return "output",
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
"""Applies the transformation to the `transform_input`. """Applies the transformation to the `transform_input`.
Args: Args:
input_tensors: a list of Tensors representing the input to input_tensors: a list of Tensors representing the input to
the Transform. the Transform.
**kwargs: Additional keyword arguments, unused here.
Returns: Returns:
A namedtuple of Tensors representing the transformed output. A namedtuple of Tensors representing the transformed output.

View File

@ -58,7 +58,7 @@ class CSVParser(transform.Transform):
def default_values(self): def default_values(self):
return self._default_values return self._default_values
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
default_consts = [constant_op.constant(d, shape=[1]) default_consts = [constant_op.constant(d, shape=[1])
for d in self._default_values] for d in self._default_values]
parsed_values = parsing_ops.decode_csv(input_tensors[0], parsed_values = parsing_ops.decode_csv(input_tensors[0],

View File

@ -47,12 +47,13 @@ class Densify(transform.Transform):
def _output_names(self): def _output_names(self):
return "output", return "output",
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
"""Applies the transformation to the `transform_input`. """Applies the transformation to the `transform_input`.
Args: Args:
input_tensors: a list of Tensors representing the input to input_tensors: a list of Tensors representing the input to
the Transform. the Transform.
**kwargs: Additional keyword arguments, unused here.
Returns: Returns:
A namedtuple of Tensors representing the transformed output. A namedtuple of Tensors representing the transformed output.

View File

@ -1,4 +1,4 @@
# Copyright 2016 Google Inc. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -50,7 +50,7 @@ class Difference(transform.Transform):
def _output_names(self): def _output_names(self):
return "output", return "output",
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor), pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
isinstance(input_tensors[1], ops.SparseTensor)) isinstance(input_tensors[1], ops.SparseTensor))

View File

@ -61,7 +61,7 @@ class ExampleParser(transform.Transform):
def feature_definitions(self): def feature_definitions(self):
return self._ordered_features return self._ordered_features
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
parsed_values = parsing_ops.parse_example(input_tensors[0], parsed_values = parsing_ops.parse_example(input_tensors[0],
features=self._ordered_features) features=self._ordered_features)
# pylint: disable=not-callable # pylint: disable=not-callable

View File

@ -89,7 +89,7 @@ class BaseInMemorySource(transform.Transform):
def input_valency(self): def input_valency(self):
return 0 return 0
def _apply_transform(self, transform_input): def _apply_transform(self, transform_input, **kwargs):
queue = feeding_functions.enqueue_data(self.data, queue = feeding_functions.enqueue_data(self.data,
self.queue_capacity, self.queue_capacity,
self.shuffle, self.shuffle,

View File

@ -32,7 +32,6 @@ class ReaderSource(transform.Transform):
reader_kwargs=None, reader_kwargs=None,
enqueue_size=None, enqueue_size=None,
batch_size=1, batch_size=1,
num_epochs=None,
queue_capacity=None, queue_capacity=None,
shuffle=False, shuffle=False,
min_after_dequeue=None, min_after_dequeue=None,
@ -49,9 +48,6 @@ class ReaderSource(transform.Transform):
is constructed. is constructed.
enqueue_size: block size for each read operation. enqueue_size: block size for each read operation.
batch_size: The desired batch size of output. Defaults to 1. batch_size: The desired batch size of output. Defaults to 1.
num_epochs: the number of times that the reader should loop through all
the file names. If set to `None`, then the reader will continue
indefinitely.
queue_capacity: Capacity of the queue. Defaults to 10 * `batch_size`. queue_capacity: Capacity of the queue. Defaults to 10 * `batch_size`.
shuffle: Whether records will be shuffled before returning. Defaults to shuffle: Whether records will be shuffled before returning. Defaults to
false. false.
@ -73,7 +69,6 @@ class ReaderSource(transform.Transform):
self._batch_size = batch_size self._batch_size = batch_size
self._queue_capacity = (batch_size * 10 if queue_capacity is None else self._queue_capacity = (batch_size * 10 if queue_capacity is None else
queue_capacity) queue_capacity)
self._num_epochs = num_epochs
self._shuffle = shuffle self._shuffle = shuffle
self._min_after_dequeue = int(self.queue_capacity / 4 if min_after_dequeue self._min_after_dequeue = int(self.queue_capacity / 4 if min_after_dequeue
is None else min_after_dequeue) is None else min_after_dequeue)
@ -100,10 +95,6 @@ class ReaderSource(transform.Transform):
def batch_size(self): def batch_size(self):
return self._batch_size return self._batch_size
@transform.parameter
def num_epochs(self):
return self._num_epochs
@transform.parameter @transform.parameter
def queue_capacity(self): def queue_capacity(self):
return self._queue_capacity return self._queue_capacity
@ -136,11 +127,12 @@ class ReaderSource(transform.Transform):
def _output_names(self): def _output_names(self):
return ("index", "value") return ("index", "value")
def _apply_transform(self, transform_input): def _apply_transform(self, transform_input, **kwargs):
filename_queue = input_ops.string_input_producer(self.work_units, filename_queue = input_ops.string_input_producer(
num_epochs=self.num_epochs, self.work_units,
shuffle=self.shuffle, num_epochs=kwargs.get("num_epochs"),
seed=self.seed) shuffle=self.shuffle,
seed=self.seed)
reader_ops = [] reader_ops = []
for _ in range(self.num_threads): for _ in range(self.num_threads):
reader = self._reader_cls(**self._reader_kwargs) reader = self._reader_cls(**self._reader_kwargs)
@ -174,7 +166,6 @@ def TextFileSource(file_names,
reader_kwargs=None, reader_kwargs=None,
enqueue_size=1, enqueue_size=1,
batch_size=1, batch_size=1,
num_epochs=None,
queue_capacity=None, queue_capacity=None,
shuffle=False, shuffle=False,
min_after_dequeue=None, min_after_dequeue=None,
@ -185,7 +176,6 @@ def TextFileSource(file_names,
reader_kwargs=reader_kwargs, reader_kwargs=reader_kwargs,
enqueue_size=enqueue_size, enqueue_size=enqueue_size,
batch_size=batch_size, batch_size=batch_size,
num_epochs=num_epochs,
queue_capacity=queue_capacity, queue_capacity=queue_capacity,
shuffle=shuffle, shuffle=shuffle,
min_after_dequeue=min_after_dequeue, min_after_dequeue=min_after_dequeue,
@ -197,7 +187,6 @@ def TFRecordSource(file_names,
reader_kwargs=None, reader_kwargs=None,
enqueue_size=1, enqueue_size=1,
batch_size=1, batch_size=1,
num_epochs=None,
queue_capacity=None, queue_capacity=None,
shuffle=False, shuffle=False,
min_after_dequeue=None, min_after_dequeue=None,
@ -208,7 +197,6 @@ def TFRecordSource(file_names,
reader_kwargs=reader_kwargs, reader_kwargs=reader_kwargs,
enqueue_size=enqueue_size, enqueue_size=enqueue_size,
batch_size=batch_size, batch_size=batch_size,
num_epochs=num_epochs,
queue_capacity=queue_capacity, queue_capacity=queue_capacity,
shuffle=shuffle, shuffle=shuffle,
min_after_dequeue=min_after_dequeue, min_after_dequeue=min_after_dequeue,

View File

@ -52,12 +52,13 @@ class Sparsify(transform.Transform):
def _output_names(self): def _output_names(self):
return "output", return "output",
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
"""Applies the transformation to the `transform_input`. """Applies the transformation to the `transform_input`.
Args: Args:
input_tensors: a list of Tensors representing the input to input_tensors: a list of Tensors representing the input to
the Transform. the Transform.
**kwargs: Additional keyword arguments, unused here.
Returns: Returns:
A namedtuple of Tensors representing the transformed output. A namedtuple of Tensors representing the transformed output.

View File

@ -1,4 +1,4 @@
# Copyright 2016 Google Inc. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -44,7 +44,7 @@ class Sum(transform.Transform):
def _output_names(self): def _output_names(self):
return "output", return "output",
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor), pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
isinstance(input_tensors[1], ops.SparseTensor)) isinstance(input_tensors[1], ops.SparseTensor))

View File

@ -1,4 +1,4 @@
# Copyright 2016 Google Inc. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -78,7 +78,7 @@ def register_unary_op(registered_name, operation):
def _output_names(self): def _output_names(self):
return "output" return "output"
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
if isinstance(input_tensor, ops.SparseTensor): if isinstance(input_tensor, ops.SparseTensor):
result = ops.SparseTensor(input_tensor.indices, result = ops.SparseTensor(input_tensor.indices,

View File

@ -208,10 +208,9 @@ class TensorFlowDataFrameTestCase(tf.test.TestCase):
tensorflow_df = df.TensorFlowDataFrame.from_csv( tensorflow_df = df.TensorFlowDataFrame.from_csv(
[data_path], [data_path],
batch_size=batch_size, batch_size=batch_size,
num_epochs=num_epochs,
shuffle=False, shuffle=False,
default_values=default_values) default_values=default_values)
actual_num_batches = len(list(tensorflow_df.run())) actual_num_batches = len(list(tensorflow_df.run(num_epochs=num_epochs)))
self.assertEqual(expected_num_batches, actual_num_batches) self.assertEqual(expected_num_batches, actual_num_batches)
def testFromCSVWithFeatureSpec(self): def testFromCSVWithFeatureSpec(self):